/*
 * Copyright 2010-2022 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package org.jetbrains.kotlin.allopen.fir

import org.jetbrains.kotlin.descriptors.ClassKind
import org.jetbrains.kotlin.descriptors.Modality
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.copyWithNewDefaults
import org.jetbrains.kotlin.fir.declarations.*
import org.jetbrains.kotlin.fir.declarations.utils.isLocal
import org.jetbrains.kotlin.fir.extensions.FirStatusTransformerExtension
import org.jetbrains.kotlin.fir.extensions.predicate.DeclarationPredicate
import org.jetbrains.kotlin.fir.extensions.utils.AbstractSimpleClassPredicateMatchingService
import org.jetbrains.kotlin.fir.resolve.getContainingClassSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirRegularClassSymbol
import org.jetbrains.kotlin.name.FqName

/**
 * Status transformer of the all-open plugin.
 *
 * [needTransformStatus] selects the declarations to process; [transformStatus] then switches their
 * modality defaults to [Modality.OPEN].
 */
class FirAllOpenStatusTransformer(session: FirSession) : FirStatusTransformerExtension(session) {
    /**
     * Returns `true` when [declaration] is either
     *  - a regular class of kind [ClassKind.CLASS] accepted by the session's [allOpenPredicateMatcher], or
     *  - a callable whose containing class is a non-local regular class of kind [ClassKind.CLASS]
     *    accepted by the same matcher.
     *
     * Declarations flagged as `isJavaOrEnhancement`, and every other kind of declaration, are rejected.
     */
    override fun needTransformStatus(declaration: FirDeclaration): Boolean {
        if (declaration.isJavaOrEnhancement) return false

        if (declaration is FirRegularClass) {
            return declaration.classKind == ClassKind.CLASS &&
                    session.allOpenPredicateMatcher.isAnnotated(declaration.symbol)
        }
        if (declaration !is FirCallableDeclaration) return false

        // Callables qualify only through their containing class; callables without a
        // regular containing class (the safe cast yields null) never qualify.
        val owner = declaration.symbol.getContainingClassSymbol() as? FirRegularClassSymbol
        return owner != null &&
                !owner.isLocal &&
                owner.classKind == ClassKind.CLASS &&
                session.allOpenPredicateMatcher.isAnnotated(owner)
    }

    /**
     * Produces a copy of [status] whose default modality is [Modality.OPEN].
     *
     * When [status] carries no modality (`status.modality == null` — presumably meaning no explicit
     * modality modifier was written; confirm against `FirDeclarationStatus`), the modality itself is
     * set to [Modality.OPEN] as well. An already present modality is not passed to the copy call.
     */
    override fun transformStatus(status: FirDeclarationStatus, declaration: FirDeclaration): FirDeclarationStatus =
        if (status.modality == null) {
            status.copyWithNewDefaults(modality = Modality.OPEN, defaultModality = Modality.OPEN)
        } else {
            status.copyWithNewDefaults(defaultModality = Modality.OPEN)
        }
}

/**
 * Predicate matching service of the all-open plugin.
 *
 * The [predicate] accepts declarations that are annotated with one of the configured annotations, or
 * meta-annotated with one of them (`includeItself = true` is passed to `metaAnnotated`).
 *
 * @param allOpenAnnotationFqNames fully qualified names of the configured annotations, as strings;
 *        each one is converted to an [FqName] when the predicate is built.
 */
class FirAllOpenPredicateMatcher(
    session: FirSession,
    allOpenAnnotationFqNames: List<String>
) : AbstractSimpleClassPredicateMatchingService(session) {
    override val predicate = DeclarationPredicate.create {
        val fqNames = allOpenAnnotationFqNames.map { name -> FqName(name) }
        annotated(fqNames) or metaAnnotated(fqNames, includeItself = true)
    }

    companion object {
        /**
         * Creates a [Factory] that builds a [FirAllOpenPredicateMatcher] for the session it is given,
         * using [allOpenAnnotationFqNames] as the configured annotation names.
         */
        fun getFactory(allOpenAnnotationFqNames: List<String>): Factory =
            Factory { firSession -> FirAllOpenPredicateMatcher(firSession, allOpenAnnotationFqNames) }
    }
}

/**
 * The [FirAllOpenPredicateMatcher] of this session, obtained through the delegate returned by
 * `FirSession.sessionComponentAccessor()`.
 *
 * NOTE(review): presumably the matcher must first be registered in the session (e.g. with the factory
 * from [FirAllOpenPredicateMatcher.getFactory]) — confirm against the plugin's registrar.
 */
val FirSession.allOpenPredicateMatcher: FirAllOpenPredicateMatcher by FirSession.sessionComponentAccessor()
