/*
 * Decompiled with CFR 0.152.
 */
package org.keycloak.jose.jws;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.time.Duration;
import java.util.Comparator;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.stream.Stream;
import org.jboss.logging.Logger;
import org.keycloak.Token;
import org.keycloak.TokenCategory;
import org.keycloak.common.util.Time;
import org.keycloak.crypto.CekManagementProvider;
import org.keycloak.crypto.ClientSignatureVerifierProvider;
import org.keycloak.crypto.ContentEncryptionProvider;
import org.keycloak.crypto.KeyUse;
import org.keycloak.crypto.KeyWrapper;
import org.keycloak.crypto.SignatureProvider;
import org.keycloak.crypto.SignatureSignerContext;
import org.keycloak.jose.JOSE;
import org.keycloak.jose.JOSEParser;
import org.keycloak.jose.jwe.JWE;
import org.keycloak.jose.jwe.JWEException;
import org.keycloak.jose.jwe.alg.JWEAlgorithmProvider;
import org.keycloak.jose.jwe.enc.JWEEncryptionProvider;
import org.keycloak.jose.jwk.JWK;
import org.keycloak.jose.jws.Algorithm;
import org.keycloak.jose.jws.JWSBuilder;
import org.keycloak.jose.jws.JWSInput;
import org.keycloak.keys.loader.PublicKeyStorageManager;
import org.keycloak.models.AuthenticatedClientSessionModel;
import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.TokenManager;
import org.keycloak.models.UserModel;
import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.protocol.oidc.OIDCAdvancedConfigWrapper;
import org.keycloak.representations.LogoutToken;
import org.keycloak.util.JsonSerialization;
import org.keycloak.util.TokenUtil;

/**
 * Default {@link TokenManager}: signs (and, when the client is configured for it, encrypts) tokens
 * issued by the server, verifies tokens signed with realm keys, and decodes JWTs sent by clients.
 *
 * <p>Signature algorithms are resolved from the current client's attributes first, then the realm
 * default, then {@code RS256}. Key-management and content-encryption algorithms come only from the
 * current client's attributes (with a fixed default for some token categories).
 */
public class DefaultTokenManager implements TokenManager {
    private static final Logger logger = Logger.getLogger(DefaultTokenManager.class);

    /** Fallback JWS algorithm when neither the client nor the realm configures one. */
    private static final String DEFAULT_SIGNATURE_ALGORITHM = "RS256";

    /** Algorithm always used for {@link TokenCategory#INTERNAL} tokens. */
    private static final String INTERNAL_SIGNATURE_ALGORITHM = "HS256";

    /** Content-encryption ("enc") algorithm used for ID and userinfo tokens when the client sets none. */
    private static final String DEFAULT_CONTENT_ENCRYPTION_ALGORITHM = "A128CBC-HS256";

    /** Validity period of a back-channel logout token, counted from issuance. */
    private static final Duration LOGOUT_TOKEN_LIFESPAN = Duration.ofMinutes(2L);

    private final KeycloakSession session;

    public DefaultTokenManager(KeycloakSession session) {
        this.session = session;
    }

    /**
     * Serializes {@code token} to JSON and signs it as a compact JWS with header type {@code JWT}.
     *
     * @param token the token to sign; its category selects the signature algorithm
     * @return the compact-serialized JWS
     * @throws RuntimeException if no {@link SignatureProvider} is registered for the resolved algorithm
     */
    public String encode(Token token) {
        String signatureAlgorithm = signatureAlgorithm(token.getCategory());
        SignatureProvider signatureProvider = session.getProvider(SignatureProvider.class, signatureAlgorithm);
        if (signatureProvider == null) {
            throw new RuntimeException("No signature provider found for algorithm " + signatureAlgorithm);
        }
        SignatureSignerContext signer = signatureProvider.signer();
        return new JWSBuilder().type("JWT").jsonContent(token).sign(signer);
    }

    /**
     * Verifies a compact JWS signed with a realm key and deserializes its payload.
     *
     * <p>When the header carries no {@code kid}, the realm's active signing key for the header's
     * algorithm is used.
     *
     * @param token the compact JWS; may be {@code null}
     * @param clazz the payload type
     * @return the payload, or {@code null} if {@code token} is null, no provider supports the
     *     algorithm, no key is available, the signature is invalid, or any parsing step fails
     */
    public <T extends Token> T decode(String token, Class<T> clazz) {
        if (token == null) {
            return null;
        }
        try {
            JWSInput jws = new JWSInput(token);
            String signatureAlgorithm = jws.getHeader().getAlgorithm().name();
            SignatureProvider signatureProvider = session.getProvider(SignatureProvider.class, signatureAlgorithm);
            if (signatureProvider == null) {
                return null;
            }
            String kid = jws.getHeader().getKeyId();
            if (kid == null) {
                logger.debug("KID is null in token. Using the realm active key to verify token signature.");
                KeyWrapper activeKey = session.keys().getActiveKey(session.getContext().getRealm(), KeyUse.SIG, signatureAlgorithm);
                if (activeKey == null) {
                    logger.debugf("No active signing key for algorithm %s; cannot verify token", signatureAlgorithm);
                    return null;
                }
                kid = activeKey.getKid();
            }
            byte[] signedContent = jws.getEncodedSignatureInput().getBytes(StandardCharsets.UTF_8);
            boolean valid = signatureProvider.verifier(kid).verify(signedContent, jws.getSignature());
            return valid ? jws.readJsonContent(clazz) : null;
        } catch (Exception e) {
            logger.debug("Failed to decode token", e);
            return null;
        }
    }

    /**
     * Decodes a JWT sent by {@code client}, which may be a JWS or a JWE.
     *
     * <p>{@code jwtValidator} is invoked on the outer token. For a JWE, the payload is decrypted with
     * a realm encryption key. If the plaintext parses as a JWS, the validator is invoked again on
     * that inner JWS and its signature is verified. Otherwise the plaintext is read as plain JSON.
     * Exceptions thrown by {@code jwtValidator} propagate to the caller.
     *
     * @return the decoded payload, or {@code null} if {@code jwt} is null or signature verification
     *     fails
     * @throws RuntimeException if no decryption key is found, decryption fails, or the decrypted
     *     JSON cannot be deserialized
     */
    public <T> T decodeClientJWT(String jwt, ClientModel client, BiConsumer<JOSE, ClientModel> jwtValidator, Class<T> clazz) {
        if (jwt == null) {
            return null;
        }
        JOSE joseToken = JOSEParser.parse(jwt);
        jwtValidator.accept(joseToken, client);
        if (joseToken instanceof JWE) {
            try {
                Key privateKey = findDecryptionKey(joseToken.getHeader().getKeyId())
                        .map(KeyWrapper::getPrivateKey)
                        .orElseThrow(() -> new RuntimeException("Could not find private key for decrypting token"));
                JWE jwe = (JWE) joseToken;
                jwe.getKeyStorage().setDecryptionKey(privateKey);
                byte[] content = jwe.verifyAndDecodeJwe().getContent();

                JOSE nested = parseNestedJose(content);
                if (nested instanceof JWSInput) {
                    // Nested signed-then-encrypted token: the inner JWS is validated and verified too.
                    jwtValidator.accept(nested, client);
                    return verifyJWS(client, clazz, (JWSInput) nested);
                }
                return JsonSerialization.readValue(content, clazz);
            } catch (IOException cause) {
                throw new RuntimeException("Failed to deserialize JWT", cause);
            } catch (JWEException cause) {
                throw new RuntimeException("Failed to decrypt JWT", cause);
            }
        }
        return verifyJWS(client, clazz, (JWSInput) joseToken);
    }

    /**
     * Picks the realm encryption key used to decrypt a client JWE.
     *
     * <p>With a {@code kid}, it returns any ENC key whose kid matches. Without one, it returns the
     * highest-priority ENC key that has a public key.
     */
    private Optional<KeyWrapper> findDecryptionKey(String kid) {
        Stream<KeyWrapper> encryptionKeys = session.keys()
                .getKeysStream(session.getContext().getRealm())
                .filter(k -> KeyUse.ENC.equals(k.getUse()));
        if (kid == null) {
            return encryptionKeys
                    .filter(k -> k.getPublicKey() != null)
                    .sorted(Comparator.comparingLong(KeyWrapper::getProviderPriority).reversed())
                    .findFirst();
        }
        // Compare from the non-null side: stored keys are not guaranteed to carry a kid.
        return encryptionKeys.filter(k -> kid.equals(k.getKid())).findAny();
    }

    /**
     * Tries to parse decrypted JWE content as a JOSE compact serialization.
     *
     * @return the parsed object, or {@code null} when the content is not JOSE (e.g. plain JSON)
     */
    private static JOSE parseNestedJose(byte[] content) {
        try {
            return JOSEParser.parse(new String(content, StandardCharsets.UTF_8));
        } catch (Exception ignored) {
            // Not a nested JOSE object; the caller falls back to reading the content as plain JSON.
            return null;
        }
    }

    /**
     * Verifies a client-signed JWS with the {@link ClientSignatureVerifierProvider} for its algorithm.
     *
     * <p>NOTE(review): when no provider exists and the algorithm is {@code none}, the payload is
     * returned unverified. Any policy restricting unsigned tokens must be enforced by the
     * {@code jwtValidator} passed to {@link #decodeClientJWT}.
     *
     * @return the payload, or {@code null} on an unsupported algorithm, invalid signature, or any
     *     failure (logged at debug level)
     */
    private <T> T verifyJWS(ClientModel client, Class<T> clazz, JWSInput jws) {
        try {
            Algorithm algorithm = jws.getHeader().getAlgorithm();
            ClientSignatureVerifierProvider signatureProvider =
                    session.getProvider(ClientSignatureVerifierProvider.class, algorithm.name());
            if (signatureProvider == null) {
                return algorithm == Algorithm.none ? jws.readJsonContent(clazz) : null;
            }
            byte[] signedContent = jws.getEncodedSignatureInput().getBytes(StandardCharsets.UTF_8);
            boolean valid = signatureProvider.verifier(client, jws).verify(signedContent, jws.getSignature());
            return valid ? jws.readJsonContent(clazz) : null;
        } catch (Exception e) {
            logger.debug("Failed to decode token", e);
            return null;
        }
    }

    /**
     * Resolves the JWS algorithm for a token category.
     *
     * @throws RuntimeException for a category not handled here
     */
    public String signatureAlgorithm(TokenCategory category) {
        switch (category) {
            case INTERNAL:
                return INTERNAL_SIGNATURE_ALGORITHM;
            case ADMIN:
                return getSignatureAlgorithm(null);
            case ACCESS:
                return getSignatureAlgorithm("access.token.signed.response.alg");
            case ID:
            case LOGOUT:
                return getSignatureAlgorithm("id.token.signed.response.alg");
            case USERINFO:
                return getSignatureAlgorithm("user.info.response.signature.alg");
            case AUTHORIZATION_RESPONSE:
                return getSignatureAlgorithm("authorization.signed.response.alg");
            default:
                throw new RuntimeException("Unknown token type");
        }
    }

    /**
     * Returns the client's value for {@code clientAttribute} if non-empty, else the realm's default
     * signature algorithm if non-empty, else {@code RS256}.
     */
    private String getSignatureAlgorithm(String clientAttribute) {
        String algorithm = getNonEmptyClientAttribute(clientAttribute);
        if (algorithm != null) {
            return algorithm;
        }
        RealmModel realm = session.getContext().getRealm();
        algorithm = realm.getDefaultSignatureAlgorithm();
        return algorithm == null || algorithm.isEmpty() ? DEFAULT_SIGNATURE_ALGORITHM : algorithm;
    }

    /**
     * Returns the current client's value for {@code clientAttribute}, or {@code null} when there is
     * no client in context, no attribute name, or the value is null or empty.
     */
    private String getNonEmptyClientAttribute(String clientAttribute) {
        ClientModel client = session.getContext().getClient();
        if (client == null || clientAttribute == null) {
            return null;
        }
        String value = client.getAttribute(clientAttribute);
        return value == null || value.isEmpty() ? null : value;
    }

    /**
     * Signs {@code token} and, if the client configures both a key-management and a
     * content-encryption algorithm for its category, wraps the JWS in a JWE.
     */
    public String encodeAndEncrypt(Token token) {
        String encodedToken = encode(token);
        if (isTokenEncryptRequired(token.getCategory())) {
            encodedToken = getEncryptedToken(token.getCategory(), encodedToken);
        }
        return encodedToken;
    }

    /** Encryption applies only when both the "alg" and the "enc" algorithm resolve to non-null. */
    private boolean isTokenEncryptRequired(TokenCategory category) {
        return cekManagementAlgorithm(category) != null && encryptAlgorithm(category) != null;
    }

    /**
     * Encrypts an already-signed token for the current client using the client's public encryption
     * key.
     *
     * @throws RuntimeException if a configured algorithm has no provider, the client's encryption
     *     key cannot be obtained, or encryption fails
     */
    private String getEncryptedToken(TokenCategory category, String encodedToken) {
        String algAlgorithm = cekManagementAlgorithm(category);
        String encAlgorithm = encryptAlgorithm(category);

        CekManagementProvider cekManagementProvider = session.getProvider(CekManagementProvider.class, algAlgorithm);
        if (cekManagementProvider == null) {
            throw new RuntimeException("No key management provider found for algorithm " + algAlgorithm);
        }
        JWEAlgorithmProvider jweAlgorithmProvider = cekManagementProvider.jweAlgorithmProvider();

        ContentEncryptionProvider contentEncryptionProvider = session.getProvider(ContentEncryptionProvider.class, encAlgorithm);
        if (contentEncryptionProvider == null) {
            throw new RuntimeException("No content encryption provider found for algorithm " + encAlgorithm);
        }
        JWEEncryptionProvider jweEncryptionProvider = contentEncryptionProvider.jweEncryptionProvider();

        ClientModel client = session.getContext().getClient();
        KeyWrapper keyWrapper = PublicKeyStorageManager.getClientPublicKeyWrapper(session, client, JWK.Use.ENCRYPTION, algAlgorithm);
        if (keyWrapper == null) {
            throw new RuntimeException("can not get encryption KEK");
        }
        Key encryptionKek = keyWrapper.getPublicKey();
        String encryptionKekId = keyWrapper.getKid();
        try {
            return TokenUtil.jweKeyEncryptionEncode(encryptionKek, encodedToken.getBytes(StandardCharsets.UTF_8),
                    algAlgorithm, encAlgorithm, encryptionKekId, jweAlgorithmProvider, jweEncryptionProvider);
        } catch (JWEException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the client-configured JWE key-management ("alg") algorithm for {@code category}.
     *
     * @return the algorithm, or {@code null} if the category is null or unsupported, or the client
     *     attribute is unset or empty
     */
    public String cekManagementAlgorithm(TokenCategory category) {
        if (category == null) {
            return null;
        }
        switch (category) {
            case ID:
            case LOGOUT:
                return getNonEmptyClientAttribute("id.token.encrypted.response.alg");
            case AUTHORIZATION_RESPONSE:
                return getNonEmptyClientAttribute("authorization.encrypted.response.alg");
            case USERINFO:
                return getNonEmptyClientAttribute("user.info.encrypted.response.alg");
            default:
                return null;
        }
    }

    /**
     * Returns the JWE content-encryption ("enc") algorithm for {@code category}.
     *
     * <p>ID and USERINFO fall back to {@code A128CBC-HS256}. LOGOUT and AUTHORIZATION_RESPONSE have
     * no fallback, so they are encrypted only when the enc attribute is set explicitly.
     * NOTE(review): JARM specifies A128CBC-HS256 as the default for
     * authorization_encrypted_response_enc. Confirm whether omitting it here is intentional.
     *
     * @return the algorithm, or {@code null} if the category is null or unsupported, or no value
     *     or default applies
     */
    public String encryptAlgorithm(TokenCategory category) {
        if (category == null) {
            return null;
        }
        switch (category) {
            case ID:
                return getEncryptAlgorithm("id.token.encrypted.response.enc", DEFAULT_CONTENT_ENCRYPTION_ALGORITHM);
            case LOGOUT:
                return getNonEmptyClientAttribute("id.token.encrypted.response.enc");
            case AUTHORIZATION_RESPONSE:
                return getNonEmptyClientAttribute("authorization.encrypted.response.enc");
            case USERINFO:
                return getEncryptAlgorithm("user.info.encrypted.response.enc", DEFAULT_CONTENT_ENCRYPTION_ALGORITHM);
            default:
                return null;
        }
    }

    /** Returns the non-empty client attribute value, or {@code defaultValue} when absent. */
    private String getEncryptAlgorithm(String clientAttribute, String defaultValue) {
        String algorithm = getNonEmptyClientAttribute(clientAttribute);
        return algorithm != null ? algorithm : defaultValue;
    }

    /**
     * Builds an unsigned back-channel logout token for {@code client}.
     *
     * <p>The token expires two minutes after issuance. Its issuer is taken from the client session's
     * {@code iss} note. It carries a {@code sid} only if the client requires it, and a
     * {@code revoke_offline_access} event only if the client asks for offline-token revocation.
     */
    public LogoutToken initLogoutToken(ClientModel client, UserModel user, AuthenticatedClientSessionModel clientSession) {
        LogoutToken token = new LogoutToken();
        token.id(KeycloakModelUtils.generateId());
        token.issuedNow();
        token.exp(Long.valueOf(Time.currentTime() + LOGOUT_TOKEN_LIFESPAN.getSeconds()));
        token.issuer(clientSession.getNote("iss"));
        token.putEvents("http://schemas.openid.net/event/backchannel-logout", JsonSerialization.createObjectNode());
        token.addAudience(client.getClientId());

        OIDCAdvancedConfigWrapper oidcAdvancedConfigWrapper = OIDCAdvancedConfigWrapper.fromClientModel(client);
        if (oidcAdvancedConfigWrapper.isBackchannelLogoutSessionRequired()) {
            token.setSid(clientSession.getUserSession().getId());
        }
        if (oidcAdvancedConfigWrapper.getBackchannelLogoutRevokeOfflineTokens()) {
            token.putEvents("revoke_offline_access", true);
        }
        token.setSubject(user.getId());
        return token;
    }
}

