// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package ksp.com.intellij.psi.impl;

import ksp.com.intellij.openapi.project.Project;
import ksp.com.intellij.openapi.util.Factory;
import ksp.com.intellij.openapi.util.Key;
import ksp.com.intellij.psi.*;
import ksp.com.intellij.psi.util.CachedValue;
import ksp.com.intellij.psi.util.CachedValueProvider;
import ksp.com.intellij.psi.util.CachedValuesManager;
import ksp.com.intellij.psi.util.PsiModificationTracker;
import ksp.com.intellij.util.ConcurrencyUtil;
import ksp.com.intellij.util.ObjectUtils;
import ksp.com.intellij.util.containers.CollectionFactory;
import ksp.org.jetbrains.annotations.NotNull;
import ksp.org.jetbrains.annotations.Nullable;

import java.util.Set;
import java.util.concurrent.ConcurrentMap;

/**
 * Computes constant values of Java expressions by walking the PSI tree bottom-up.
 * <p>
 * Every {@link PsiExpression} that has no cached value yet is evaluated by {@link ConstantExpressionVisitor}
 * after its children have been visited. The result is recorded in a concurrent per-expression cache.
 * <p>
 * Without a {@link PsiConstantEvaluationHelper.AuxEvaluator}, the cache is a project-level {@link CachedValue}
 * that is dropped on {@link PsiModificationTracker#MODIFICATION_COUNT} changes. Otherwise, the auxiliary
 * evaluator supplies the cache map.
 */
public final class JavaConstantExpressionEvaluator extends JavaRecursiveElementWalkingVisitor {
  private final Factory<ConcurrentMap<PsiElement, Object>> myMapFactory;

  private static final Key<CachedValue<ConcurrentMap<PsiElement,Object>>> CONSTANT_VALUE_WO_OVERFLOW_MAP_KEY = Key.create("CONSTANT_VALUE_WO_OVERFLOW_MAP_KEY");
  private static final Key<CachedValue<ConcurrentMap<PsiElement,Object>>> CONSTANT_VALUE_WITH_OVERFLOW_MAP_KEY = Key.create("CONSTANT_VALUE_WITH_OVERFLOW_MAP_KEY");
  /**
   * Cache marker meaning "evaluated, but produced {@code null}".
   * A plain {@code null} cannot be used because {@link #getCached} returns {@code null} for "not evaluated yet".
   */
  private static final Object NO_VALUE = ObjectUtils.NULL;
  private final ConstantExpressionVisitor myConstantExpressionVisitor;

  private JavaConstantExpressionEvaluator(Set<PsiVariable> visitedVars,
                                          final boolean throwExceptionOnOverflow,
                                          @NotNull Project project,
                                          final PsiConstantEvaluationHelper.AuxEvaluator auxEvaluator) {
    if (auxEvaluator == null) {
      // Results computed with and without overflow checking are cached under separate keys.
      // The key never changes for this evaluator, so it is chosen once here.
      final Key<CachedValue<ConcurrentMap<PsiElement, Object>>> key =
        throwExceptionOnOverflow ? CONSTANT_VALUE_WITH_OVERFLOW_MAP_KEY : CONSTANT_VALUE_WO_OVERFLOW_MAP_KEY;
      // Capture the parameter rather than reading an instance field that might not be assigned yet.
      myMapFactory = () -> CachedValuesManager.getManager(project).getCachedValue(project, key, PROVIDER, false);
    }
    else {
      myMapFactory = () -> auxEvaluator.getCacheMap(throwExceptionOnOverflow);
    }
    myConstantExpressionVisitor = new ConstantExpressionVisitor(visitedVars, throwExceptionOnOverflow, auxEvaluator);
  }

  /**
   * Called once the children of {@code element} have been walked.
   * For an expression without a cached value, evaluates it and caches the result.
   * For an expression with a cached value, hands that value to the visitor instead.
   */
  @Override
  protected void elementFinished(@NotNull PsiElement element) {
    if (!(element instanceof PsiExpression)) return;
    PsiExpression expression = (PsiExpression)element;

    Object cached = getCached(expression);
    if (cached != null) {
      myConstantExpressionVisitor.store(element, decode(cached));
      return;
    }
    cache(expression, myConstantExpressionVisitor.handle(element));
  }

  /**
   * Descends into {@code element} unless it is an expression whose value is already cached.
   * In that case the subtree is skipped and the cached value is handed to the visitor directly.
   */
  @Override
  public void visitElement(@NotNull PsiElement element) {
    Object cached = element instanceof PsiExpression ? getCached((PsiExpression)element) : null;
    if (cached == null) {
      // Not an expression, or not evaluated yet: walk it; expressions are cached back in elementFinished().
      super.visitElement(element);
    }
    else {
      myConstantExpressionVisitor.store(element, decode(cached));
    }
  }

  /** Creates a fresh weak-keyed concurrent map, invalidated on any PSI modification. */
  private static final CachedValueProvider<ConcurrentMap<PsiElement,Object>> PROVIDER = () -> {
    ConcurrentMap<PsiElement, Object> value = CollectionFactory.createConcurrentWeakMap();
    return CachedValueProvider.Result.create(value, PsiModificationTracker.MODIFICATION_COUNT);
  };

  /**
   * Returns the raw cache entry for {@code element}.
   * This is {@code null} if the element has not been evaluated yet, and {@link #NO_VALUE} if it evaluated to {@code null}.
   */
  private Object getCached(@NotNull PsiExpression element) {
    return map().get(element);
  }

  /** Records {@code value} for {@code element}, encoding {@code null} as {@link #NO_VALUE}; an existing entry wins. */
  private void cache(@NotNull PsiExpression element, @Nullable Object value) {
    ConcurrencyUtil.cacheOrGet(map(), element, value == null ? NO_VALUE : value);
  }

  /** Translates a raw cache entry into the value seen by callers ({@link #NO_VALUE} becomes {@code null}). */
  @Nullable
  private static Object decode(@Nullable Object cached) {
    return cached == NO_VALUE ? null : cached;
  }

  @NotNull
  private ConcurrentMap<PsiElement, Object> map() {
    return myMapFactory.create();
  }

  /**
   * Same as {@link #computeConstantExpression(PsiExpression, Set, boolean, PsiConstantEvaluationHelper.AuxEvaluator)}
   * without an auxiliary evaluator.
   */
  public static Object computeConstantExpression(@Nullable PsiExpression expression, @Nullable Set<PsiVariable> visitedVars, boolean throwExceptionOnOverflow) {
    return computeConstantExpression(expression, visitedVars, throwExceptionOnOverflow, null);
  }

  /**
   * Computes the constant value of {@code expression}.
   *
   * @param expression               expression to evaluate; {@code null} yields {@code null}
   * @param visitedVars              forwarded to {@link ConstantExpressionVisitor}
   *                                 (presumably the variables already under evaluation — confirm)
   * @param throwExceptionOnOverflow forwarded to {@link ConstantExpressionVisitor}; also selects which shared cache is used
   * @param auxEvaluator             optional evaluator that also supplies the cache map; {@code null} uses the project cache
   * @return the computed value, or {@code null} if {@code expression} is {@code null} or no value was produced
   */
  public static Object computeConstantExpression(@Nullable PsiExpression expression,
                                                 @Nullable Set<PsiVariable> visitedVars,
                                                 boolean throwExceptionOnOverflow,
                                                 final PsiConstantEvaluationHelper.AuxEvaluator auxEvaluator) {
    if (expression == null) return null;

    if (expression instanceof PsiLiteralExpression) {
      return ((PsiLiteralExpression)expression).getValue(); // don't bother with caching etc
    }

    JavaConstantExpressionEvaluator evaluator = new JavaConstantExpressionEvaluator(visitedVars, throwExceptionOnOverflow, expression.getProject(), auxEvaluator);

    if (expression instanceof PsiCompiledElement) {
      // in case of compiled elements we are not allowed to use PSI walking
      // but really in Cls there are only so many cases to handle
      if (expression instanceof PsiPrefixExpression) {
        PsiExpression operand = ((PsiPrefixExpression)expression).getOperand();
        if (operand == null) return null;
        Object value = evaluator.myConstantExpressionVisitor.handle(operand);
        evaluator.myConstantExpressionVisitor.store(operand, value);
      }
      return evaluator.myConstantExpressionVisitor.handle(expression);
    }
    expression.accept(evaluator);
    return decode(evaluator.getCached(expression));
  }

  /** Same as {@link #computeConstantExpression(PsiExpression, Set, boolean)} with no visited variables. */
  public static Object computeConstantExpression(@Nullable PsiExpression expression, boolean throwExceptionOnOverflow) {
    return computeConstantExpression(expression, null, throwExceptionOnOverflow);
  }
}
