import JSBI from 'jsbi';
import { SwapResult, toDecimal, ZERO, ceilingDivision } from '../utils';
import Decimal from 'decimal.js';
import { Fraction } from '..';
import { calculateFeeAmount } from './fees';

/**
 * Quote engine for a two-token constant-product (x * y = k) swap pool.
 *
 * The total fee is the sum of `calculateFeeAmount(amount, traderFee)` and
 * `calculateFeeAmount(amount, ownerFee)`. It is charged either on the input amount (default)
 * or on the computed output amount, depending on `feesOnInput`.
 */
export class TokenSwapConstantProduct {
  /**
   * @param traderFee - Fee fraction passed to `calculateFeeAmount` for the trader fee.
   * @param ownerFee - Fee fraction passed to `calculateFeeAmount` for the owner fee.
   * @param feesOnInput - When true (default), fees are deducted from the input amount before
   *   quoting; when false, they are deducted from the quoted output amount.
   */
  constructor(
    private readonly traderFee: Fraction,
    private readonly ownerFee: Fraction,
    private readonly feesOnInput: boolean = true,
  ) {}

  /**
   * Quotes a swap of `inputTradeAmount` into the token at `outputIndex`.
   *
   * @param tokenAmounts - Current pool reserves; index 0 and index 1 are the two tokens.
   * @param inputTradeAmount - Amount of the input token offered, before fees.
   * @param outputIndex - Index (0 or 1) of the token received; the other index is the input.
   * @returns The expected output amount (net of fees), the total fee amount (computed on the
   *   input amount when `feesOnInput`, otherwise on the gross output amount), and the price
   *   impact as a fraction of the no-slippage output.
   */
  public exchange(tokenAmounts: JSBI[], inputTradeAmount: JSBI, outputIndex: number): SwapResult {
    const inputIndex = outputIndex === 0 ? 1 : 0;
    const netInputAmount = this.feesOnInput ? this.getAmountLessFees(inputTradeAmount) : inputTradeAmount;

    const grossOutputAmount = this.getExpectedOutputAmount(tokenAmounts, netInputAmount, inputIndex, outputIndex);
    const fees = this.getFees(this.feesOnInput ? inputTradeAmount : grossOutputAmount);
    const expectedOutputAmount = this.feesOnInput ? grossOutputAmount : this.getAmountLessFees(grossOutputAmount);

    return {
      priceImpact: this.getPriceImpact(tokenAmounts, netInputAmount, expectedOutputAmount, inputIndex, outputIndex),
      fees,
      expectedOutputAmount,
    };
  }

  /**
   * Relative shortfall of the curve quote versus the linear (no-slippage) quote:
   * `(noSlippage - expected) / noSlippage`. Returns 0 when it cannot be measured.
   */
  private getPriceImpact(
    tokenAmounts: JSBI[],
    inputTradeAmountJSBI: JSBI,
    expectedOutputAmountJSBI: JSBI,
    inputIndex: number,
    outputIndex: number,
  ): Decimal {
    if (
      JSBI.equal(inputTradeAmountJSBI, ZERO) ||
      JSBI.equal(tokenAmounts[inputIndex], ZERO) ||
      JSBI.equal(tokenAmounts[outputIndex], ZERO)
    ) {
      return new Decimal(0);
    }

    const noSlippageOutputAmountJSBI = this.getExpectedOutputAmountWithNoSlippage(
      tokenAmounts,
      inputTradeAmountJSBI,
      inputIndex,
      outputIndex,
    );

    // The no-slippage quote is a floor division and can round to zero for dust-sized trades;
    // dividing by it would produce NaN (0/0) or Infinity, so report no measurable impact instead.
    if (JSBI.lessThanOrEqual(noSlippageOutputAmountJSBI, ZERO)) {
      return new Decimal(0);
    }

    const noSlippageOutputAmount = toDecimal(noSlippageOutputAmountJSBI);
    const expectedOutputAmount = toDecimal(expectedOutputAmountJSBI);

    return noSlippageOutputAmount.sub(expectedOutputAmount).div(noSlippageOutputAmount);
  }

  /** Total fee (trader fee + owner fee) charged on `tradeAmount`. */
  private getFees(tradeAmount: JSBI): JSBI {
    const tradingFee = calculateFeeAmount(tradeAmount, this.traderFee);
    const ownerFee = calculateFeeAmount(tradeAmount, this.ownerFee);

    return JSBI.add(tradingFee, ownerFee);
  }

  /** Gross curve output for `inputTradeAmount`; fee handling is done by the caller. */
  private getExpectedOutputAmount(
    tokenAmounts: JSBI[],
    inputTradeAmount: JSBI,
    inputIndex: number,
    outputIndex: number,
  ): JSBI {
    return this.getOutputAmount(tokenAmounts, inputTradeAmount, inputIndex, outputIndex);
  }

  /**
   * Output at the current spot price, `floor(input * outputReserve / inputReserve)`, with
   * fees deducted when fees are charged on the output side.
   */
  private getExpectedOutputAmountWithNoSlippage(
    tokenAmounts: JSBI[],
    inputTradeAmount: JSBI,
    inputIndex: number,
    outputIndex: number,
  ): JSBI {
    // Spot price is undefined for an empty input reserve; fall back to the whole output reserve.
    if (JSBI.equal(tokenAmounts[inputIndex], ZERO)) {
      return tokenAmounts[outputIndex];
    }

    const expectedOutputAmountWithNoSlippage = JSBI.divide(
      JSBI.multiply(inputTradeAmount, tokenAmounts[outputIndex]),
      tokenAmounts[inputIndex],
    );

    return this.feesOnInput
      ? expectedOutputAmountWithNoSlippage
      : this.getAmountLessFees(expectedOutputAmountWithNoSlippage);
  }

  /** `tradeAmount` minus the total fee, never below zero. */
  private getAmountLessFees(tradeAmount: JSBI): JSBI {
    const netAmount = JSBI.subtract(tradeAmount, this.getFees(tradeAmount));
    // For dust amounts the combined fees can exceed the amount (e.g. if calculateFeeAmount
    // rounds each non-zero fee up to a minimum — confirm); never report a negative amount.
    return JSBI.lessThan(netAmount, ZERO) ? ZERO : netAmount;
  }

  /**
   * Constant-product output: `outputReserve - ceil(k / (inputReserve + input))`.
   */
  private getOutputAmount(tokenAmounts: JSBI[], inputTradeAmount: JSBI, inputIndex: number, outputIndex: number): JSBI {
    const poolInputAmount = tokenAmounts[inputIndex];
    const poolOutputAmount = tokenAmounts[outputIndex];
    const newPoolInputAmount = JSBI.add(poolInputAmount, inputTradeAmount);

    // An empty input reserve receiving nothing would divide by zero below; nothing in, nothing out.
    if (JSBI.equal(newPoolInputAmount, ZERO)) {
      return ZERO;
    }

    const invariant = this.getInvariant(tokenAmounts);

    // ceilingDivision is expected to round the new output reserve up, so the quoted output
    // never overstates what the pool pays out.
    const [newPoolOutputAmount] = ceilingDivision(invariant, newPoolInputAmount);

    return JSBI.subtract(poolOutputAmount, newPoolOutputAmount);
  }

  /** Pool invariant k = tokenAmounts[0] * tokenAmounts[1]. */
  getInvariant(tokenAmounts: JSBI[]): JSBI {
    return JSBI.multiply(tokenAmounts[0], tokenAmounts[1]);
  }
}
