import { invariant } from '../validation';
import { visitAllTopLevelReferenceNodes } from './ast-common';
import { ASTLambdaNode, ASTLambdaParameterReferenceNode, ASTRootNode, type ASTNode } from './ast-types';
import { ERROR_CALC, ERROR_VALUE, ERROR_NUM } from './constants';
import { unbox, MaybeBoxed } from './ValueBox';
import type { EvaluationContext } from './EvaluationContext';
import { type MaybeBoxedFormulaArgument } from './types';
import { stringTruncate } from '../../lib/utils-stringTruncate';
import { DefaultMap } from '@grid-is/collections';

/**
 * A lambda argument supplied as an unevaluated AST node.
 * `bindArgs` creates these for `Lambda.callWithAST`, pairing each node with the
 * evaluation context that was passed to that call.
 */
type LambdaASTArg = {
  astNode: ASTNode,
  ctx: EvaluationContext,
};

/**
 * A lambda argument supplied as a (possibly boxed) formula value.
 * `bindArgs` creates these for `Lambda.call`.
 */
type LambdaValueArg = {
  value: MaybeBoxedFormulaArgument,
};

/**
 * An argument bound to a lambda parameter: either an AST node plus context
 * (has an `astNode` property) or a value (has a `value` property).
 */
export type LambdaArg = LambdaASTArg | LambdaValueArg;

/**
 * Object standing for one parameter of one lambda. Instances are used as
 * `Map` keys (see `LambdaBindings.args`), so they are looked up by object
 * identity rather than by name.
 */
export type LambdaParameterSymbol = {
  // Parameter name as passed to the symbol factory in the `Lambda` constructor.
  param: string,
  // Root AST node of a lambda body; the `Lambda` constructor populates it from its own `expression`.
  expression: ASTRootNode,
};

/**
 * The lambda scope in effect during evaluation: which parameter names are
 * visible and what each parameter symbol is currently bound to.
 * `evaluateLambda` builds one of these for every lambda invocation.
 */
export type LambdaBindings = {
  /**
   * Mapping from parameter name to lambda-specific parameter symbol.
   * The `Lambda` constructor lower-cases these names before using them as keys
   * in its own lookup table.
   */
  params: Map<string, LambdaParameterSymbol>,
  /**
   * Mapping from lambda-specific parameter symbol to the argument bound to it,
   * which is either an AST node with its context or a value (see `LambdaArg`).
   * This exists in the context of evaluating a lambda call.
   */
  args: Map<LambdaParameterSymbol, LambdaArg>,
};

/**
 * A lambda value: a list of parameter names plus a body expression, optionally
 * carrying the argument bindings of an enclosing lambda call.
 *
 * Constructing a Lambda rewrites the body AST in place: every top-level
 * reference node whose name matches a known parameter (case-insensitively)
 * gets a `param` property pointing at the parameter's symbol and loses its
 * `name` property.
 */
export class Lambda {
  /** Parameter names in declaration order, exactly as passed to the constructor. */
  readonly parameterNames: string[];
  /** Lower-cased parameter name → symbol. Also receives the params of `outer`, when given. */
  readonly parameterSymbols: Map<string, LambdaParameterSymbol>;
  /** The lambda body. */
  readonly expression: ASTRootNode;
  /** Argument bindings taken from `outer.args`; undefined when no `outer` was supplied. */
  readonly boundArgs: undefined | Map<LambdaParameterSymbol, LambdaArg>;
  /** Factory-backed map that produces this lambda's own symbol for a parameter name. */
  readonly ownParameterSymbols: DefaultMap<string, LambdaParameterSymbol>;

  /**
   * @param parameterNames Names of the lambda's parameters, in order.
   * @param expression The lambda body.
   * @param outer Bindings of the enclosing lambda call, if any. Its params are
   *   added to this lambda's symbol table and its args become `boundArgs`.
   * @param parameterSymbols Existing symbol table to use instead of creating a
   *   fresh one (used by `withBoundArgs`).
   */
  constructor (
    parameterNames: string[],
    expression: ASTRootNode,
    outer?: LambdaBindings,
    parameterSymbols?: Map<string, LambdaParameterSymbol>,
  ) {
    this.parameterNames = parameterNames;
    this.expression = expression;
    this.ownParameterSymbols = new DefaultMap((param: string) => ({
      param,
      // Keep the AST node itself (not a serialised copy) so the symbol matches
      // the `ASTRootNode` type declared by LambdaParameterSymbol.
      expression: this.expression,
    }));
    this.parameterSymbols =
      parameterSymbols ||
      new Map(this.parameterNames.map(param => [ param.toLowerCase(), this.ownParameterSymbols.get(param) ]));
    if (outer) {
      // NOTE(review): when `parameterSymbols` was supplied by the caller (see
      // withBoundArgs) this writes into that shared map, so the lambda it came
      // from sees the outer params too. Confirm that this sharing is intended.
      for (const [ paramName, outerSymbol ] of outer.params.entries()) {
        this.parameterSymbols.set(paramName.toLowerCase(), outerSymbol);
      }
    }
    this.boundArgs = outer?.args;
    // Resolve by-name references to parameters into direct symbol references.
    visitAllTopLevelReferenceNodes(this.expression, node => {
      if ('name' in node) {
        const boundParameterSymbol = this.parameterSymbols.get(node.name.toLowerCase());
        if (boundParameterSymbol != null) {
          (node as unknown as ASTLambdaParameterReferenceNode).param = boundParameterSymbol;
          delete (node as any).name;
        }
      }
    });
  }

  /**
   * Invokes the lambda with already-evaluated arguments.
   *
   * @returns The result of evaluating the body, or a detailed `ERROR_VALUE`
   *   when the number of arguments differs from the number of parameters.
   */
  call (ctx: EvaluationContext, args: MaybeBoxedFormulaArgument[]): MaybeBoxedFormulaArgument {
    const arityError = this.checkArity(args.length);
    if (arityError != null) {
      return arityError;
    }
    const argMap = bindArgs(this, ctx, { valueArgs: args });
    return evaluateLambda(this, ctx, argMap);
  }

  /**
   * Invokes the lambda with arguments given as AST nodes; each node is bound
   * together with `ctx` rather than being evaluated here.
   *
   * @returns The result of evaluating the body, or a detailed `ERROR_VALUE`
   *   when the number of arguments differs from the number of parameters.
   */
  callWithAST (ctx: EvaluationContext, args: ASTNode[]): MaybeBoxedFormulaArgument {
    const arityError = this.checkArity(args.length);
    if (arityError != null) {
      return arityError;
    }
    const argMap = bindArgs(this, ctx, { astArgs: args });
    return evaluateLambda(this, ctx, argMap);
  }

  /**
   * Returns the Lambda for an AST lambda node and stores it on `ast.bound`.
   *
   * If `ast.bound` already holds a Lambda it is reused; when `outer` is also
   * given, a copy bound to `outer` is created via `withBoundArgs`. Otherwise a
   * new Lambda is built from `ast.params` (defaulting to none) and `ast.lambda`.
   */
  static fromAST (ast: ASTLambdaNode, outer?: LambdaBindings) {
    let lambda = ast.bound;
    if (lambda) {
      if (outer) {
        lambda = lambda.withBoundArgs(outer);
      }
    }
    else {
      lambda = new Lambda(ast.params || [], ast.lambda, outer);
    }
    ast.bound = lambda;
    return lambda;
  }

  /**
   * Creates a Lambda with the same parameter names, body and symbol table as
   * this one, bound to the given outer bindings.
   */
  withBoundArgs (outer: LambdaBindings): Lambda {
    return new Lambda(this.parameterNames, this.expression, outer, this.parameterSymbols);
  }

  /** Debug representation: the parameter names and the JSON of the body passed through `stringTruncate(…, 100)`. */
  toString () {
    return `LAMBDA(${this.parameterNames.join(', ')}, ${stringTruncate(JSON.stringify(this.expression), 100)})`;
  }

  /** Number of parameters the lambda declares. */
  get numParams () {
    return this.parameterNames.length;
  }

  /**
   * Shared argument-count check for `call` and `callWithAST`.
   *
   * @returns `null` when `argCount` matches the parameter count, otherwise the
   *   detailed `ERROR_VALUE` to hand back to the caller.
   */
  private checkArity (argCount: number) {
    if (argCount === this.numParams) {
      return null;
    }
    return ERROR_VALUE.detailed(
      `Invalid call with ${argCount} arguments to lambda with ${this.numParams} parameters`,
    );
  }
}

/**
 * Type guard: true when `value`, after being passed through `unbox`, is a
 * `Lambda` instance.
 */
export function isLambda (value: any): value is MaybeBoxed<Lambda> {
  const unboxed = unbox(value);
  return unboxed instanceof Lambda;
}

/**
 * Pre-built `ERROR_CALC` error carrying the detail message 'Lambda not allowed',
 * exported so callers can share a single instance.
 */
export const ERROR_CALC_LAMBDA_NOT_ALLOWED = ERROR_CALC.detailed('Lambda not allowed');

/**
 * Builds the symbol → argument map for one invocation of `lambda`.
 *
 * The map is layered, later layers overriding earlier ones:
 *   1. arguments already bound in the calling context (`ctx.lambdaBindings.args`),
 *   2. arguments the lambda captured when it was created (`lambda.boundArgs`),
 *   3. the arguments of this call, matched to parameters by position.
 *
 * @param lambda The lambda being invoked.
 * @param ctx The calling evaluation context; stored alongside each AST argument.
 * @param args Either AST nodes (`astArgs`) or values (`valueArgs`), one per parameter.
 * @returns A new map; neither of the source maps is modified.
 */
function bindArgs (
  lambda: Lambda,
  ctx: EvaluationContext,
  args: { astArgs: ASTNode[] } | { valueArgs: MaybeBoxedFormulaArgument[] },
) {
  const bindings = new Map<LambdaParameterSymbol, LambdaArg>();

  const inherited = ctx.lambdaBindings?.args;
  if (inherited) {
    for (const [ inheritedSymbol, inheritedArg ] of inherited.entries()) {
      bindings.set(inheritedSymbol, inheritedArg);
    }
  }

  const captured = lambda.boundArgs;
  if (captured) {
    for (const [ capturedSymbol, capturedArg ] of captured.entries()) {
      bindings.set(capturedSymbol, capturedArg);
    }
  }

  for (const [ position, name ] of lambda.parameterNames.entries()) {
    const symbol = lambda.parameterSymbols.get(name.toLowerCase());
    // Every declared parameter must have a symbol in the lambda's table.
    invariant(symbol);
    if ('astArgs' in args) {
      bindings.set(symbol, { astNode: args.astArgs[position], ctx });
    }
    else {
      bindings.set(symbol, { value: args.valueArgs[position] });
    }
  }

  return bindings;
}

/**
 * Tells whether a caught value is a JavaScript engine's "call stack exhausted"
 * error. The error type and message differ per engine:
 *   - V8:             RangeError, 'Maximum call stack size exceeded'
 *   - JavaScriptCore: RangeError, 'Maximum call stack size exceeded.'
 *   - SpiderMonkey:   InternalError, 'too much recursion'
 */
function isStackOverflowError (e: unknown): boolean {
  if (e instanceof RangeError) {
    return e.message === 'Maximum call stack size exceeded' ||
      e.message === 'Maximum call stack size exceeded.';
  }
  // InternalError is non-standard and has no global binding in most engines,
  // so it is identified by name instead of by `instanceof`.
  return e instanceof Error && e.name === 'InternalError' && e.message === 'too much recursion';
}

/**
 * Evaluates the body of `lambda` in a copy of `ctx` whose `lambdaBindings`
 * expose the lambda's parameter symbols and the supplied argument map.
 *
 * @param lambda The lambda whose expression is evaluated.
 * @param ctx The calling evaluation context.
 * @param argMap Symbol → argument bindings for this invocation.
 * @returns The evaluation result, or a detailed `ERROR_NUM` when evaluation
 *   exhausts the call stack (treated as infinite lambda recursion).
 * @throws Any other error raised during evaluation is rethrown unchanged.
 */
function evaluateLambda (lambda: Lambda, ctx: EvaluationContext, argMap: Map<LambdaParameterSymbol, LambdaArg>) {
  // NOTE(review): object spread copies own enumerable properties only; this
  // relies on `evaluateASTNode` still being reachable on the copy. Confirm
  // against the EvaluationContext definition.
  const lambdaContext = {
    ...ctx,
    lambdaBindings: {
      // Copy the symbol table so the context does not alias the lambda's own map.
      params: new Map(lambda.parameterSymbols.entries()),
      args: argMap,
    },
  };
  try {
    return lambdaContext.evaluateASTNode(lambda.expression);
  }
  catch (e) {
    if (isStackOverflowError(e)) {
      return ERROR_NUM.detailed('Infinite lambda recursion');
    }
    throw e;
  }
}
