// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package ksp.com.intellij.psi.impl.source.resolve;

import ksp.com.intellij.openapi.project.Project;
import ksp.com.intellij.openapi.util.Key;
import ksp.com.intellij.psi.PsiExpression;
import ksp.com.intellij.psi.PsiType;
import ksp.com.intellij.psi.PsiVariable;
import ksp.com.intellij.psi.impl.AnyPsiChangeListener;
import ksp.com.intellij.psi.impl.PsiManagerImpl;
import ksp.com.intellij.psi.impl.source.resolve.graphInference.PsiPolyExpressionUtil;
import ksp.com.intellij.psi.infos.MethodCandidateInfo;
import ksp.com.intellij.psi.util.CachedValuesManager;
import ksp.com.intellij.util.ConcurrencyUtil;
import ksp.com.intellij.util.Function;
import ksp.com.intellij.util.containers.CollectionFactory;
import ksp.org.jetbrains.annotations.NotNull;
import ksp.org.jetbrains.annotations.Nullable;

import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Per-project cache used during Java resolve: it memoizes expression types (delegating to
 * {@link CachedValuesManager}) and constant values computed for {@link PsiVariable}s.
 * <p>
 * Constant values live in two separately held maps, one for physical and one for non-physical
 * variables, so each can be dropped independently when a PSI change is announced.
 */
public class JavaResolveCache {
  /** Stored in the constant-value maps in place of a {@code null} result, so "computed, no value" differs from "not computed". */
  private static final Object NULL = Key.create("NULL");

  // Each reference holds null until the first lookup lazily installs a map; it is reset to null on invalidation.
  private final AtomicReference<Map<PsiVariable,Object>> myVarToConstValueMapPhysical = new AtomicReference<>();
  private final AtomicReference<Map<PsiVariable,Object>> myVarToConstValueMapNonPhysical = new AtomicReference<>();

  /**
   * Looks up the cache instance registered as a service of the given project.
   *
   * @param project the project whose service is requested
   * @return whatever {@code project.getService(JavaResolveCache.class)} yields
   */
  public static JavaResolveCache getInstance(Project project) {
    return project.getService(JavaResolveCache.class);
  }

  /**
   * Creates the cache and subscribes it to {@link PsiManagerImpl#ANY_PSI_CHANGE_TOPIC} on the
   * project's message bus, so the constant-value maps are dropped before any PSI change.
   *
   * @param project the project whose message bus is subscribed to
   */
  public JavaResolveCache(@NotNull Project project) {
    AnyPsiChangeListener dropCachesOnChange = new AnyPsiChangeListener() {
      @Override
      public void beforePsiChanged(boolean isPhysical) {
        clearCaches(isPhysical);
      }
    };
    project.getMessageBus().connect().subscribe(PsiManagerImpl.ANY_PSI_CHANGE_TOPIC, dropCachesOnChange);
  }

  /**
   * Discards cached constant values.
   *
   * @param isPhysical when {@code true} both maps are discarded; when {@code false} only the
   *                   non-physical map is
   */
  private void clearCaches(boolean isPhysical) {
    if (isPhysical) {
      myVarToConstValueMapPhysical.set(null);
    }
    // The non-physical map is reset on every change notification, physical or not.
    myVarToConstValueMapNonPhysical.set(null);
  }

  /**
   * Returns the type of {@code expr} as computed by {@code f}, going through the
   * PSI-dependent cache of {@link CachedValuesManager} whenever caching is permitted.
   * <p>
   * Caching is bypassed, and {@code f} is invoked directly, when an overload check is in
   * progress and {@code expr} is a poly expression.
   *
   * @param expr the expression whose type is requested
   * @param f    the function that computes the type of the expression
   * @return the computed or cached type; may be {@code null}
   */
  @Nullable
  public <T extends PsiExpression> PsiType getType(@NotNull T expr, @NotNull Function<? super T, ? extends PsiType> f) {
    boolean cachingAllowed = !MethodCandidateInfo.isOverloadCheck() || !PsiPolyExpressionUtil.isPolyExpression(expr);
    if (cachingAllowed) {
      return CachedValuesManager.getProjectPsiDependentCache(expr, e -> f.fun(e));
    }

    return f.fun(expr);
  }

  /**
   * Returns the constant value of {@code variable}, running {@code computer} only when no value
   * has been recorded for it yet in the map matching the variable's physical/non-physical state.
   * <p>
   * A {@code null} computation result is recorded as well (via an internal sentinel), so it is
   * not recomputed on the next call against the same map.
   *
   * @param variable    the variable whose constant value is requested
   * @param computer    the computation to run on a cache miss
   * @param visitedVars passed through to {@code computer} unchanged
   * @return the cached or freshly computed value; {@code null} when the computation produced {@code null}
   */
  @Nullable
  public Object computeConstantValueWithCaching(@NotNull PsiVariable variable, @NotNull ConstValueComputer computer, Set<PsiVariable> visitedVars){
    AtomicReference<Map<PsiVariable, Object>> holder;
    if (variable.isPhysical()) {
      holder = myVarToConstValueMapPhysical;
    }
    else {
      holder = myVarToConstValueMapNonPhysical;
    }

    Map<PsiVariable, Object> cache = holder.get();
    if (cache == null) {
      // Lazily install a map; use whichever instance cacheOrGet hands back.
      cache = ConcurrencyUtil.cacheOrGet(holder, CollectionFactory.createConcurrentWeakMap());
    }

    Object hit = cache.get(variable);
    if (hit != null) {
      // The sentinel stands for a remembered null result.
      return hit == NULL ? null : hit;
    }

    Object computed = computer.execute(variable, visitedVars);
    cache.put(variable, computed != null ? computed : NULL);
    return computed;
  }

  /**
   * Callback that computes the constant value of a variable on a cache miss.
   */
  @FunctionalInterface
  public interface ConstValueComputer{
    /**
     * @param variable    the variable to evaluate
     * @param visitedVars the set handed to {@link #computeConstantValueWithCaching} by its caller
     * @return the computed value; a {@code null} return is cached as "no value"
     */
    Object execute(@NotNull PsiVariable variable, Set<PsiVariable> visitedVars);
  }
}
