/*
 * Copyright 2010-2023 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package ksp.org.jetbrains.kotlin.fir.analysis.checkers.declaration

import ksp.org.jetbrains.kotlin.*
import ksp.org.jetbrains.kotlin.config.LanguageFeature
import ksp.org.jetbrains.kotlin.diagnostics.DiagnosticReporter
import ksp.org.jetbrains.kotlin.diagnostics.findChildByType
import ksp.org.jetbrains.kotlin.diagnostics.findChildrenByType
import ksp.org.jetbrains.kotlin.diagnostics.reportOn
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.MppCheckerKind
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
import ksp.org.jetbrains.kotlin.fir.analysis.checkers.getModifierList
import ksp.org.jetbrains.kotlin.fir.analysis.diagnostics.FirErrors
import ksp.org.jetbrains.kotlin.fir.declarations.*
import ksp.org.jetbrains.kotlin.fir.declarations.impl.FirPrimaryConstructor
import ksp.org.jetbrains.kotlin.fir.declarations.utils.isOperator
import ksp.org.jetbrains.kotlin.fir.declarations.utils.nameOrSpecialName
import ksp.org.jetbrains.kotlin.fir.isEnabled
import ksp.org.jetbrains.kotlin.fir.types.*
import ksp.org.jetbrains.kotlin.psi.KtModifierList
import ksp.org.jetbrains.kotlin.psi.psiUtil.getChildOfType
import ksp.org.jetbrains.kotlin.util.OperatorNameConventions

/**
 * Validates `context(...)` lists written on declarations.
 *
 * Depending on which of [LanguageFeature.ContextReceivers] and [LanguageFeature.ContextParameters] are enabled,
 * this checker reports:
 * - more than one context list on the same declaration;
 * - context lists on declaration kinds where they are not supported;
 * - any context list when neither feature is enabled;
 * - in context-receivers mode: subtyping between context receivers, and named context parameters;
 * - in context-parameters mode: unnamed (receiver-style) entries, modifiers, and `val`/`var` on context parameters.
 */
object FirContextParametersDeclarationChecker : FirBasicDeclarationChecker(MppCheckerKind.Platform) {
    context(context: CheckerContext, reporter: DiagnosticReporter)
    override fun check(declaration: FirDeclaration) {
        // Declarations with a fake (synthetic) source are not checked.
        if (declaration.source?.kind is KtFakeSourceElementKind) return

        // For files, the context list is looked up under the package directive's source.
        val contextListSources = when (declaration) {
            is FirFile -> declaration.packageDirective.source
            else -> declaration.source
        }?.findContextReceiverListSources().orEmpty().ifEmpty { return }

        // Diagnostics about the context list as a whole are reported on the first list.
        val source = contextListSources.first()

        if (contextListSources.size > 1) {
            reporter.reportOn(source, FirErrors.MULTIPLE_CONTEXT_LISTS)
        }

        val contextReceiversEnabled = LanguageFeature.ContextReceivers.isEnabled()
        val contextParametersEnabled = LanguageFeature.ContextParameters.isEnabled()

        // Non-null when a context list is not allowed on this kind of declaration in the current mode.
        // Branch order matters: e.g. FirPrimaryConstructor must be matched before FirConstructor,
        // and the guarded FirProperty branches before the plain FirProperty branch.
        val errorMessage = when (declaration) {
            // Stuff that was never supported
            is FirTypeAlias -> "Context parameters on type aliases are unsupported."
            is FirAnonymousInitializer -> "Context parameters on initializers are unsupported."
            is FirEnumEntry -> "Context parameters on enum entries are unsupported."
            is FirPropertyAccessor -> "Context parameters on property accessors are unsupported."
            is FirBackingField -> "Context parameters on backing fields are unsupported."
            is FirPrimaryConstructor -> "Context parameters on primary constructors are unsupported."
            // Reported by this checker only when context parameters are enabled.
            is FirProperty if declaration.isLocal -> "Context parameters on local properties are unsupported.".takeIf { contextParametersEnabled }
            // Stuff that is unsupported with context parameters
            is FirConstructor -> "Context parameters on constructors are unsupported.".takeIf { contextParametersEnabled }
            is FirClass -> "Context parameters on classes are unsupported.".takeIf { contextParametersEnabled }
            is FirCallableDeclaration if declaration.isDelegationOperator() -> "Context parameters on delegation operators are unsupported.".takeIf { contextParametersEnabled }
            is FirProperty if declaration.delegate != null -> "Context parameters on delegated properties are unsupported.".takeIf { contextParametersEnabled }
            // Only valid positions
            is FirSimpleFunction, is FirProperty, is FirAnonymousFunction -> null
            // Fallback if we forgot something.
            else -> "Context parameters are unsupported in this position."
        }

        if (errorMessage != null) {
            reporter.reportOn(
                source,
                FirErrors.UNSUPPORTED,
                errorMessage
            )
        }

        val contextParameters = declaration.getContextParameters()
        if (contextParameters.isEmpty()) return

        if (!contextReceiversEnabled && !contextParametersEnabled) {
            reporter.reportOn(
                source,
                FirErrors.UNSUPPORTED_FEATURE,
                LanguageFeature.ContextParameters to context.languageVersionSettings
            )
            return
        }

        if (contextReceiversEnabled) {
            if (checkSubTypes(contextParameters.map { it.returnTypeRef.coneType })) {
                reporter.reportOn(
                    source,
                    FirErrors.SUBTYPING_BETWEEN_CONTEXT_RECEIVERS
                )
            }
            // In context-receivers mode, only receiver-style (unnamed) entries are accepted.
            for (parameter in contextParameters) {
                if (!parameter.isLegacyContextReceiver()) {
                    reporter.reportOn(
                        parameter.source,
                        FirErrors.UNSUPPORTED_FEATURE,
                        LanguageFeature.ContextParameters to context.languageVersionSettings
                    )
                }
            }
        }

        if (contextParametersEnabled) {
            for (parameter in contextParameters) {
                // In context-parameters mode, every entry needs a name.
                if (parameter.isLegacyContextReceiver()) {
                    reporter.reportOn(parameter.source, FirErrors.CONTEXT_PARAMETER_WITHOUT_NAME)
                }

                // No modifier is allowed on a context parameter.
                parameter.source?.getModifierList()?.modifiers?.forEach { modifier ->
                    reporter.reportOn(modifier.source, FirErrors.WRONG_MODIFIER_TARGET, modifier.token, "context parameter")
                }

                FirFunctionParameterChecker.checkValOrVar(parameter)
            }
        }
    }

    /** Whether this callable is an `operator` whose name is one of the delegated-property operator names. */
    private fun FirCallableDeclaration.isDelegationOperator(): Boolean {
        return this.isOperator && this.nameOrSpecialName in OperatorNameConventions.DELEGATED_PROPERTY_OPERATORS
    }

    /** Context parameters of callables and regular classes; empty for every other declaration kind. */
    private fun FirDeclaration.getContextParameters(): List<FirValueParameter> {
        return when (this) {
            is FirCallableDeclaration -> contextParameters
            is FirRegularClass -> contextParameters
            else -> emptyList()
        }
    }

    /**
     * Finds the `context(...)` lists inside this element's direct modifier-list child, for both PSI and
     * light-tree sources. Returns an empty list when there is no modifier list.
     */
    private fun KtSourceElement.findContextReceiverListSources(): List<KtSourceElement> {
        return when (this) {
            is KtPsiSourceElement ->
                psi.getChildOfType<KtModifierList>()?.contextReceiverLists?.map { it.toKtPsiSourceElement() }.orEmpty()
            is KtLightSourceElement ->
                treeStructure.findChildByType(lighterASTNode, KtNodeTypes.MODIFIER_LIST)
                    ?.let { treeStructure.findChildrenByType(it, KtNodeTypes.CONTEXT_RECEIVER_LIST) }
                    ?.map { it.toKtLightSourceElement(treeStructure) }
                    .orEmpty()
        }
    }

    /**
     * Simplified check for subtype relations between context receiver types, used in context receiver checkers.
     *
     * Before comparing, each type is normalized:
     * - a top-level type parameter type is replaced by its resolved bounds (each bound becomes a separate entry);
     * - inside a class-like type, type-parameter arguments are replaced by star projections, recursing through
     *   class-like arguments.
     *
     * NOTE(review): a rewritten class-like argument is returned as a bare type, so an `in`/`out` projection on that
     * argument appears to be dropped. Also, several bounds of the same type parameter end up compared against each
     * other. Confirm both are intended.
     *
     * @param types the context receiver types to compare.
     * @return `true` if any two normalized entries are in a subtype relation in either direction.
     */
    context(context: CheckerContext)
    fun checkSubTypes(types: List<ConeKotlinType>): Boolean {
        // Replaces type-parameter arguments by `*`, descending into class-like arguments.
        fun replaceTypeParametersByStarProjections(type: ConeClassLikeType): ConeClassLikeType {
            val newArguments = type.typeArguments.map { argument ->
                // Local val so the type checks below smart-cast without `!!` / `as`.
                val argumentType = argument.type
                when {
                    argument.isStarProjection -> argument
                    argumentType is ConeTypeParameterType -> ConeStarProjection
                    argumentType is ConeClassLikeType -> replaceTypeParametersByStarProjections(argumentType)
                    else -> argument
                }
            }
            return type.withArguments(newArguments.toTypedArray())
        }

        val normalizedTypes = types.flatMap { type ->
            when (type) {
                is ConeTypeParameterType -> type.lookupTag.typeParameterSymbol.resolvedBounds.map { it.coneType }
                is ConeClassLikeType -> listOf(replaceTypeParametersByStarProjections(type))
                else -> listOf(type)
            }
        }

        // Compare every unordered pair once, in both directions.
        for (i in normalizedTypes.indices) {
            for (j in i + 1..<normalizedTypes.size) {
                val first = normalizedTypes[i]
                val second = normalizedTypes[j]
                if (first.isSubtypeOf(second, context.session) || second.isSubtypeOf(first, context.session)) {
                    return true
                }
            }
        }

        return false
    }
}

