import { AccountInfo, Connection, PublicKey } from '@solana/web3.js';
import { isValidRoute, MarketInfo } from './market';
import { MIN_SEGMENT_SIZE_FOR_INTERMEDIATE_MINTS } from '../constants';
import { RaydiumAmm } from './raydium/raydiumAmm';
import fetch from 'cross-fetch';
import { TokenRouteSegments } from './types';
import { Amm, prefetchAmms, SwapMode } from './amm';
import { SerumAmm } from './serum/serumAmm';
import { ammFactory } from './ammFactory';
import { getSaberWrappedDecimalsAmms } from './saber/saberAddDecimalsAmm';
import { SplitTradeAmm } from './split-trade/splitTradeAmm';
import { getTwoPermutations } from '../utils/getTwoPermutations';
import { chunkedGetMultipleAccountInfos } from '../utils/chunkedGetMultipleAccountInfos';
import JSBI from 'jsbi';

/**
 * Breakdown of the fees and deposits associated with a swap transaction.
 *
 * NOTE(review): the unit of these numbers is not stated in this file —
 * presumably lamports; confirm against the code that produces this object.
 */
export interface TransactionFeeInfo {
  /** Signature fee component (presumably the network fee for signing the transaction — confirm). */
  signatureFee: number;
  /** Deposit amounts related to open orders accounts (presumably one entry per account — confirm). */
  openOrdersDeposits: number[];
  /** Deposit amounts related to associated token accounts (presumably one entry per account — confirm). */
  ataDeposits: number[];
  /** Overall amount that will be deducted from the user wallet after the swap. */
  totalFeeAndDeposits: number;
  /** Amount needed for the fee, the deposits and the temporary token accounts. */
  minimumSOLForTransaction: number;
}

/**
 * Description of one candidate swap route: the markets it goes through, the
 * amounts involved and a lazy way to obtain its fee/deposit breakdown.
 */
export interface RouteInfo {
  /** Markets traversed by the route (presumably in execution order — confirm with the producer). */
  marketInfos: MarketInfo[];
  /** Amount going into the route. */
  inAmount: JSBI;
  /** Amount coming out of the route. */
  outAmount: JSBI;
  /** The target amount: the expected input or the expected output (presumably selected by `swapMode`). */
  amount: JSBI; // The target amount, expected input or expected output
  /** Threshold applied to the other (non-target) amount — presumably the slippage bound; confirm. */
  otherAmountThreshold: JSBI;
  /** Swap mode this route was computed for. */
  swapMode: SwapMode;
  /** Price impact of the route, as a number (NOTE(review): fraction vs. percent is not visible here). */
  priceImpactPct: number;
  /** Asynchronously resolves the fee and deposit breakdown; may resolve to `undefined`. */
  getDepositAndFee: () => Promise<TransactionFeeInfo | undefined>;
}

/**
 * JSON-friendly form of a list of on-chain accounts, as returned by
 * `fetchMarketCache`. It mirrors `AccountInfo<Buffer>` except that binary and
 * key fields are strings; `getAllAmms` converts them back to `Buffer` and
 * `PublicKey`.
 */
type MarketsCache = Array<
  Omit<AccountInfo<Buffer>, 'data' | 'owner'> & {
    // Account data as an [encodedPayload, encoding] tuple; the encoding is always 'base64'.
    data: [string, 'base64'];
    // Owner of the account, as a string accepted by the PublicKey constructor.
    owner: string;
    // Address of the account, as a string accepted by the PublicKey constructor.
    pubkey: string;
  }
>;

/**
 * An account's info bundled with the public key it lives at. This is the
 * shape `getAllAmms` hands to `ammFactory`.
 */
type KeyedAccountInfo = AccountInfo<Buffer> & {
  // Address of the account.
  pubkey: PublicKey;
  // Optional extra parameters the API can attach to an entry; forwarded as-is
  // to `ammFactory` as its third argument.
  params?: any;
};

/**
 * Downloads the serialized markets cache from `url`.
 *
 * @param url - Endpoint that serves the markets cache as JSON.
 * @returns The parsed JSON body, cast (not validated) to `MarketsCache`.
 * @throws Error when the server answers with a non-2xx status. Rejections
 *   from `fetch` itself (network failure) or from JSON parsing propagate
 *   unchanged.
 */
export const fetchMarketCache = async (url: string): Promise<MarketsCache> => {
  const response = await fetch(url);

  // Without this check an HTTP error page/body would be parsed and handed to
  // callers as if it were a valid cache, only to blow up later and far away.
  if (!response.ok) {
    throw new Error(`Failed to fetch market cache from ${url}: ${response.status} ${response.statusText}`);
  }

  return (await response.json()) as MarketsCache;
};

/**
 * For testing purposes, when the api does not have the new pools yet.
 *
 * Loads the accounts for `pks` through `chunkedGetMultipleAccountInfos` and
 * pairs every returned account with the public key it was requested for.
 *
 * @param connection - Connection passed through to the account fetcher.
 * @param pks - Addresses of the pools to load.
 * @returns One `{ pubkey, ...accountInfo }` object per entry of `pks`.
 * @throws Error when the fetcher returns a falsy entry for any address.
 */
export async function fetchExtraKeyedAccountInfos(connection: Connection, pks: PublicKey[]) {
  const addresses = pks.map((pk) => pk.toBase58());
  const accountInfos = await chunkedGetMultipleAccountInfos(connection, addresses);

  return accountInfos.map((accountInfo, position) => {
    // Results are positional: entry N belongs to pks[N].
    const pubkey = pks[position];
    if (!accountInfo) {
      throw new Error(`Failed to fetch pool ${pubkey.toBase58()}`);
    }
    return { pubkey, ...accountInfo };
  });
}

/**
 * Turns a serialized markets cache into the list of usable AMMs.
 *
 * Steps: decode every cache entry into a keyed account info, build an AMM from
 * each one via `ammFactory` (entries it does not return an AMM for are
 * dropped), prefetch the AMMs that ask for it, and finally append the Saber
 * wrapped-decimals AMMs.
 *
 * @param connection - Connection used for prefetching (and for the optional
 *   development-only extra keys).
 * @param marketsCache - Serialized accounts, e.g. from `fetchMarketCache`.
 * @returns All constructed AMMs followed by the Saber wrapped-decimals AMMs.
 */
export async function getAllAmms(connection: Connection, marketsCache: MarketsCache): Promise<Amm[]> {
  // Rebuild the runtime representation: base64 payload -> Buffer, string keys -> PublicKey.
  const toKeyedAccountInfo = ({
    data: [encodedData, encoding],
    pubkey,
    ...remaining
  }: MarketsCache[number]): KeyedAccountInfo => ({
    ...remaining,
    pubkey: new PublicKey(pubkey),
    data: Buffer.from(encodedData, encoding),
    owner: new PublicKey(remaining.owner),
  });

  const keyedAccountInfos: Array<KeyedAccountInfo> = marketsCache.map((entry) => toKeyedAccountInfo(entry));

  // this is used for development
  const extraKeys: Array<PublicKey> = [];

  if (extraKeys.length) {
    const extras = await fetchExtraKeyedAccountInfos(connection, extraKeys);
    keyedAccountInfos.push(...extras);
  }

  const amms = new Array<Amm>();
  for (const keyedAccountInfo of keyedAccountInfos) {
    const amm = ammFactory(keyedAccountInfo.pubkey, keyedAccountInfo, keyedAccountInfo.params);
    // The factory may not recognise the account with the current version of
    // the frontend, or the pool may be in a state we don't want: skip those.
    if (amm) {
      amms.push(amm);
    }
  }

  const ammsToPrefetch = amms.filter((amm) => amm.shouldPrefetch);
  await prefetchAmms(ammsToPrefetch, connection);

  // Appended after prefetching, so these never go through prefetchAmms.
  amms.push(...getSaberWrappedDecimalsAmms());

  return amms;
}

/**
 * Calls `callback` once for every unordered pair of AMMs in `arr` whose
 * labels differ. Pairs are visited in index order and the AMM with the lower
 * index is always passed first.
 *
 * @param arr - AMMs to pair up.
 * @param callback - Receives each pair of differently-labelled AMMs.
 */
export function ammCrossProtocolPairs(arr: Amm[], callback: (a: Amm, b: Amm) => void) {
  for (let left = 0; left < arr.length; left += 1) {
    for (let right = left + 1; right < arr.length; right += 1) {
      const first = arr[left];
      const second = arr[right];

      // Two AMMs sharing a label are never paired with each other.
      if (first.label === second.label) {
        continue;
      }

      callback(first, second);
    }
  }
}

// Memoised base58 strings, keyed by the stringified internal `_bn` value of the key.
const mintCache: Record<string, string> = {};

// Mints repeat a lot across AMMs, so the base58 conversion is memoised to
// avoid converting the same pk again.
// This seems to bring getTokenRouteSegments from 100ms => 50ms
function getOrUpdatePublicKeyCache(pk: PublicKey) {
  // `_bn` is not reachable through PublicKey's typings, hence the ts-ignore.
  //@ts-ignore
  const cacheKey = pk._bn.toString();

  const known = mintCache[cacheKey];
  if (known) {
    return known;
  }

  const encoded = pk.toBase58();
  mintCache[cacheKey] = encoded;
  return encoded;
}

/**
 * Indexes the given AMMs by the mint pairs they can trade.
 *
 * For every AMM, each two-mint permutation of its reserve mints (as produced
 * by `getTwoPermutations`) is registered, giving a nested map of
 * first mint (base58) => second mint (base58) => AMMs.
 *
 * @param amms - AMMs to index.
 * @returns The nested mint => mint => Amm[] map.
 */
export function getTokenRouteSegments(amms: Amm[]): TokenRouteSegments {
  const tokenRouteSegments = new Map<string, Map<string, Amm[]>>();

  for (const amm of amms) {
    for (const [fromMint, toMint] of getTwoPermutations(amm.reserveTokenMints)) {
      const fromMintBase58 = getOrUpdatePublicKeyCache(fromMint);
      const toMintBase58 = getOrUpdatePublicKeyCache(toMint);
      addSegment(fromMintBase58, toMintBase58, amm, tokenRouteSegments);
    }
  }

  return tokenRouteSegments;
}

/**
 * Registers `amm` under `inMint` => `outMint` in `tokenRouteSegments`,
 * creating the inner map and/or the AMM list on first use. The map is mutated
 * in place.
 */
function addSegment(inMint: string, outMint: string, amm: Amm, tokenRouteSegments: TokenRouteSegments) {
  let ammsByOutMint = tokenRouteSegments.get(inMint);
  if (ammsByOutMint === undefined) {
    ammsByOutMint = new Map<string, Amm[]>();
    tokenRouteSegments.set(inMint, ammsByOutMint);
  }

  const bucket = ammsByOutMint.get(outMint);
  if (bucket === undefined) {
    ammsByOutMint.set(outMint, [amm]);
  } else {
    bucket.push(amm);
  }
}

/**
 * A route expressed as the AMMs it uses and the mints it touches.
 * NOTE(review): not referenced in this file; presumably `mints` lists the
 * tokens along the path taken through `amms` — confirm with the consumers.
 */
export type Route = {
  // AMMs making up the route.
  amms: Amm[];
  // Mints associated with the route.
  mints: PublicKey[];
};

/*
 * Construct the TokenRouteSegments that is only used for the selected
 * inputMint and outputMint: the direct pair (when it exists) plus every
 * intermediate mint that both sides can reach.
 * Example:
 *   SOL => USDC, the map would consist of
 *     - SOL => USDC => Amm[]
 *     - SOL => USDT => Amm[]
 *     - USDT => SOL => Amm[]
 *
 * Intermediate mints are left out when onlyDirectRoutes is set, when swapMode
 * is ExactOut, or when shouldSkipOutputMint rejects them. In ExactOut mode only
 * AMMs with exactOutputSupported are kept. An empty map is returned when
 * either mint has no entry in tokenRouteSegments.
 */
export function computeInputRouteSegments({
  inputMint,
  outputMint,
  tokenRouteSegments,
  intermediateTokens,
  swapMode,
  onlyDirectRoutes,
}: {
  inputMint: string;
  outputMint: string;
  tokenRouteSegments: TokenRouteSegments;
  intermediateTokens?: string[];
  swapMode: SwapMode;
  onlyDirectRoutes?: boolean;
}): TokenRouteSegments {
  const inputRouteSegments: TokenRouteSegments = new Map();

  const inputSegment = tokenRouteSegments.get(inputMint);
  const outputSegment = tokenRouteSegments.get(outputMint);

  // Nothing to route through unless both mints are known.
  if (!inputSegment || !outputSegment) {
    return inputRouteSegments;
  }

  const minSegmentSize = Math.min(inputSegment.size, outputSegment.size);

  // Walk the smaller side to minimize the looping part:
  // if SOL => MER, SOL has 100 keys but MER has 6 keys, only 6 iterations are needed.
  // On a tie the output side is the one that gets walked.
  const startsFromInput = inputSegment.size < outputSegment.size;
  const startSegment = startsFromInput ? inputSegment : outputSegment;
  const endSegment = startsFromInput ? outputSegment : inputSegment;
  const startMint = startsFromInput ? inputMint : outputMint;
  const endMint = startsFromInput ? outputMint : inputMint;

  // ExactIn accepts every AMM; any other mode keeps only those supporting exact output.
  const keepUsableAmms = (candidates: Amm[]): Amm[] =>
    swapMode === SwapMode.ExactIn ? candidates : candidates.filter((amm) => amm.exactOutputSupported);

  const startInnerMap = new Map<string, Amm[]>();
  const endInnerMap = new Map<string, Amm[]>();

  for (const [mint, amms] of startSegment) {
    const usableAmms = keepUsableAmms(amms);

    // Direct pair: the same AMM list serves both directions.
    if (mint === endMint) {
      startInnerMap.set(mint, usableAmms);
      endInnerMap.set(startMint, usableAmms);
      continue;
    }

    if (
      onlyDirectRoutes ||
      swapMode === SwapMode.ExactOut ||
      shouldSkipOutputMint(intermediateTokens, minSegmentSize, mint)
    ) {
      continue;
    }

    // `mint` only works as an intermediate hop when the other side reaches it too.
    const ammsFromEndSide = endSegment.get(mint);
    if (!ammsFromEndSide) {
      continue;
    }

    const usableAmmsFromEndSide = keepUsableAmms(ammsFromEndSide);

    inputRouteSegments.set(
      mint,
      new Map([
        [startMint, usableAmms],
        [endMint, usableAmmsFromEndSide],
      ]),
    );
    startInnerMap.set(mint, usableAmms);
    endInnerMap.set(mint, usableAmmsFromEndSide);
  }

  inputRouteSegments.set(startMint, startInnerMap);
  inputRouteSegments.set(endMint, endInnerMap);

  return inputRouteSegments;
}

/**
 * Computes, for every mint, the list of mints that can be reached from it
 * with one hop or (unless `onlyDirectRoutes` is set) two hops.
 *
 * A two-hop destination is accepted only when it differs from the source
 * mint, the intermediate mint is not rejected by `shouldSkipOutputMint`, and
 * at least one (first hop AMM, second hop AMM) pair satisfies `isValidRoute`.
 *
 * @param tokenRouteSegments - mint => mint => Amm[] index of all AMMs.
 * @param intermediateTokens - Optional whitelist of intermediate mints,
 *   forwarded to `shouldSkipOutputMint`.
 * @param onlyDirectRoutes - When truthy, only one-hop destinations are listed.
 * @returns Map of source mint => reachable destination mints.
 */
export function computeRouteMap(
  tokenRouteSegments: TokenRouteSegments,
  intermediateTokens?: string[],
  onlyDirectRoutes?: boolean,
): Map<string, string[]> {
  const routeMap = new Map<string, string[]>();

  // True as soon as one AMM pair forms a valid two-hop route.
  const hasValidPair = (firstHopAmms: Amm[], secondHopAmms: Amm[]): boolean =>
    firstHopAmms.some((firstHopAmm: Amm) =>
      secondHopAmms.some((secondHopAmm: Amm) => isValidRoute(firstHopAmm, secondHopAmm)),
    );

  for (const [sourceMint, directOutputs] of tokenRouteSegments) {
    const reachableMints = new Set<string>();

    for (const [hopMint, firstHopAmms] of directOutputs) {
      // Every direct output is reachable by definition.
      reachableMints.add(hopMint);

      if (onlyDirectRoutes) {
        continue;
      }

      const secondHopOutputs = tokenRouteSegments.get(hopMint) ?? new Map<string, Amm[]>();

      for (const [destinationMint, secondHopAmms] of secondHopOutputs) {
        // Prevent output mint == input mint when routing
        if (destinationMint === sourceMint) {
          continue;
        }

        const destinationSize = tokenRouteSegments.get(destinationMint)?.size ?? 0;
        const minSegmentSize = Math.min(directOutputs.size, destinationSize);

        // If intermediateTokens is specified and the hop mint is not part of it, skip it.
        if (shouldSkipOutputMint(intermediateTokens, minSegmentSize, hopMint)) {
          continue;
        }

        if (hasValidPair(firstHopAmms, secondHopAmms)) {
          reachableMints.add(destinationMint);
        }
      }
    }

    routeMap.set(sourceMint, Array.from(reachableMints));
  }

  return routeMap;
}

/** Extra inputs consumed by `isSplitSetupRequired`. */
interface SplitTradeRequiredParams {
  // When true (and no Raydium / double-Serum rule applies first),
  // `isSplitSetupRequired` reports that a setup step is needed but no cleanup.
  hasSerumOpenOrderInstruction: boolean;
}

/**
 * Decides whether `outputMint` must be ignored as an intermediate mint.
 *
 * It is skipped only when a whitelist is given, the segment size exceeds
 * MIN_SEGMENT_SIZE_FOR_INTERMEDIATE_MINTS, and the mint is not whitelisted.
 *
 * @param intermediateTokens - Optional whitelist of intermediate mints.
 * @param minSegmentSize - Size compared against the threshold constant.
 * @param outputMint - Mint being considered.
 * @returns `true` when the mint should be skipped.
 */
function shouldSkipOutputMint(
  intermediateTokens: string[] | undefined,
  minSegmentSize: number,
  outputMint: string,
): boolean {
  // No whitelist: nothing is ever skipped.
  if (!intermediateTokens) {
    return false;
  }

  // The whitelist is not applied to segments at or below the threshold.
  if (!(minSegmentSize > MIN_SEGMENT_SIZE_FOR_INTERMEDIATE_MINTS)) {
    return false;
  }

  return !intermediateTokens.includes(outputMint);
}

/**
 * Tells whether a route needs a separate setup and/or cleanup step.
 *
 * The two AMMs examined are the legs of a `SplitTradeAmm` when the route has a
 * single market, otherwise the AMMs of the first two markets. A single market
 * that is not a `SplitTradeAmm` never needs either step.
 *
 * @param marketInfos - Markets of the route.
 * @param params - `hasSerumOpenOrderInstruction` is used as the fallback rule.
 * @returns Flags telling whether setup and cleanup are needed.
 */
export function isSplitSetupRequired(
  marketInfos: MarketInfo[],
  { hasSerumOpenOrderInstruction }: SplitTradeRequiredParams,
): { needSetup: boolean; needCleanup: boolean } {
  let firstAmm: Amm;
  let secondAmm: Amm;

  if (marketInfos.length !== 1) {
    [firstAmm, secondAmm] = marketInfos.map((marketInfo) => marketInfo.amm);
  } else {
    const onlyAmm = marketInfos[0].amm;
    if (!(onlyAmm instanceof SplitTradeAmm)) {
      return { needSetup: false, needCleanup: false };
    }
    firstAmm = onlyAmm.firstAmm;
    secondAmm = onlyAmm.secondAmm;
  }

  // Any Raydium leg, or Serum on both legs, requires both setup and cleanup.
  if (
    firstAmm instanceof RaydiumAmm ||
    secondAmm instanceof RaydiumAmm ||
    (firstAmm instanceof SerumAmm && secondAmm instanceof SerumAmm)
  ) {
    return { needSetup: true, needCleanup: true };
  }

  // Otherwise a Serum open orders instruction needs setup only.
  if (hasSerumOpenOrderInstruction) {
    return { needSetup: true, needCleanup: false };
  }

  return { needSetup: false, needCleanup: false };
}

// We cannot add platform fee to all possible routing due to transaction size limit.
// Unsupported cases: any ExactOut swap, and routes whose first two AMMs are both Raydium.
export function isPlatformFeeSupported(swapMode: SwapMode, amms: Amm[]): boolean {
  if (swapMode === SwapMode.ExactOut) {
    return false;
  }

  // A single-AMM (or empty) route always supports the platform fee.
  if (!(amms.length > 1)) {
    return true;
  }

  const [firstAmm, secondAmm] = amms;
  const isRaydiumOnBothLegs = firstAmm instanceof RaydiumAmm && secondAmm instanceof RaydiumAmm;
  return !isRaydiumOnBothLegs;
}

/**
 * Builds an identifier for a route by joining, for each of its markets,
 * `<amm id>-<input mint>` with `-` as separator.
 *
 * @param routeInfo - Route to identify.
 * @returns The joined identifier; an empty string when the route has no markets.
 */
export function getRouteInfoUniqueId(routeInfo: RouteInfo) {
  const parts: string[] = [];

  for (const marketInfo of routeInfo.marketInfos) {
    parts.push(`${marketInfo.amm.id}-${marketInfo.inputMint}`);
  }

  return parts.join('-');
}
