/*
 * Decompiled with CFR 0.152.
 */
package org.keycloak.keys;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.Key;
import java.security.KeyPair;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertPath;
import java.security.cert.CertPathValidator;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.PKIXParameters;
import java.security.cert.TrustAnchor;
import java.security.cert.X509Certificate;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.EdECPrivateKey;
import java.security.interfaces.RSAPrivateCrtKey;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.crypto.SecretKey;
import org.keycloak.common.util.KeyUtils;
import org.keycloak.common.util.KeystoreUtil;
import org.keycloak.component.ComponentModel;
import org.keycloak.crypto.JavaAlgorithm;
import org.keycloak.crypto.KeyStatus;
import org.keycloak.crypto.KeyUse;
import org.keycloak.crypto.KeyWrapper;
import org.keycloak.keys.JavaKeystoreKeyProviderFactory;
import org.keycloak.keys.KeyProvider;
import org.keycloak.models.RealmModel;
import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.vault.VaultTranscriber;

/**
 * {@link KeyProvider} that exposes a single key loaded from a Java keystore file configured on the
 * component model.
 *
 * <p>The loaded {@link KeyWrapper} is cached as a note on the {@link ComponentModel} (keyed by the
 * {@code KeyWrapper} class name), so subsequent instances built for the same model reuse it instead
 * of re-reading the keystore.
 */
public class JavaKeystoreKeyProvider implements KeyProvider {
    private final KeyStatus status;
    private final ComponentModel model;
    private final VaultTranscriber vault;
    private final KeyWrapper key;
    private final String algorithm;

    /**
     * Creates the provider and loads (or reuses the cached) key.
     *
     * @param realm the realm the component belongs to (passed through to {@link #loadKey})
     * @param model component configuration; read for "active", "enabled", "keyUse", "algorithm" and
     *              the keystore settings
     * @param vault used to resolve keystore/key passwords that may be vault references
     */
    public JavaKeystoreKeyProvider(RealmModel realm, ComponentModel model, VaultTranscriber vault) {
        this.model = model;
        this.vault = vault;
        this.status = KeyStatus.from(model.get("active", true), model.get("enabled", true));
        // Default algorithm depends on the configured use: encryption keys default to RSA-OAEP.
        String defaultAlgorithmKey = KeyUse.ENC.name().equalsIgnoreCase(model.get("keyUse")) ? "RSA-OAEP" : "RS256";
        this.algorithm = model.get("algorithm", defaultAlgorithmKey);
        if (model.hasNote(KeyWrapper.class.getName())) {
            this.key = (KeyWrapper) model.getNote(KeyWrapper.class.getName());
        } else {
            this.key = this.loadKey(realm, model);
            model.setNote(KeyWrapper.class.getName(), this.key);
        }
    }

    /**
     * Opens the configured keystore file and loads the key matching the configured algorithm.
     *
     * @param realm unused here; kept for subclasses and signature compatibility
     * @param model component configuration providing the keystore path and key alias
     * @return the loaded key
     * @throws RuntimeException wrapping any keystore, I/O, certificate or key error, or when the
     *                          configured algorithm is not supported
     */
    protected KeyWrapper loadKey(RealmModel realm, ComponentModel model) {
        String keystorePath = model.get(JavaKeystoreKeyProviderFactory.KEYSTORE_KEY);
        // Opening the file in the resource clause guarantees it is closed on every path and lets a
        // missing file be reported by the FileNotFoundException handler below.
        try (FileInputStream is = new FileInputStream(keystorePath)) {
            KeyStore keyStore = this.loadKeyStore(is, keystorePath);
            String keyAlias = model.get(JavaKeystoreKeyProviderFactory.KEY_ALIAS_KEY);
            return switch (this.algorithm) {
                case "PS256", "PS384", "PS512", "RS256", "RS384", "RS512" -> this.loadRSAKey(keyStore, keyAlias, KeyUse.SIG);
                case "RSA-OAEP", "RSA1_5", "RSA-OAEP-256" -> this.loadRSAKey(keyStore, keyAlias, KeyUse.ENC);
                case "ES256", "ES384", "ES512" -> this.loadECKey(keyStore, keyAlias, KeyUse.SIG);
                case "ECDH-ES", "ECDH-ES+A128KW", "ECDH-ES+A192KW", "ECDH-ES+A256KW" -> this.loadECKey(keyStore, keyAlias, KeyUse.ENC);
                case "EdDSA" -> this.loadEdDSAKey(keyStore, keyAlias, KeyUse.SIG);
                case "AES" -> this.loadOctKey(keyStore, keyAlias, JavaAlgorithm.getJavaAlgorithm(this.algorithm), KeyUse.ENC);
                case "HS256", "HS384", "HS512" -> this.loadOctKey(keyStore, keyAlias, JavaAlgorithm.getJavaAlgorithm(this.algorithm), KeyUse.SIG);
                default -> throw new RuntimeException(String.format("Keys for algorithm %s are not supported.", this.algorithm));
            };
        }
        // Subclasses are caught before their supertypes (FileNotFoundException < IOException;
        // the specific security exceptions < GeneralSecurityException).
        catch (KeyStoreException kse) {
            throw new RuntimeException("KeyStore error on server. " + kse.getMessage(), kse);
        }
        catch (FileNotFoundException fnfe) {
            throw new RuntimeException("File not found on server. " + fnfe.getMessage(), fnfe);
        }
        catch (IOException ioe) {
            throw new RuntimeException("IO error on server. " + ioe.getMessage(), ioe);
        }
        catch (NoSuchAlgorithmException nsae) {
            throw new RuntimeException("Algorithm not available on server. " + nsae.getMessage(), nsae);
        }
        catch (CertificateException ce) {
            throw new RuntimeException("Certificate error on server. " + ce.getMessage(), ce);
        }
        catch (UnrecoverableKeyException uke) {
            throw new RuntimeException("Key in the keystore cannot be recovered. " + uke.getMessage(), uke);
        }
        catch (GeneralSecurityException gse) {
            // Remaining security failures come from certificate-chain validation.
            throw new RuntimeException("Invalid certificate chain. Check the order of certificates.", gse);
        }
    }

    /**
     * Loads the keystore from the given stream, using the configured (or path-derived, else "JKS")
     * keystore type and the configured password, resolved through the vault when possible.
     */
    private KeyStore loadKeyStore(FileInputStream inputStream, String keystorePath) throws KeyStoreException, CertificateException, IOException, NoSuchAlgorithmException {
        String keystoreType = KeystoreUtil.getKeystoreType(this.model.get(JavaKeystoreKeyProviderFactory.KEYSTORE_TYPE_KEY), keystorePath, "JKS");
        KeyStore keyStore = KeyStore.getInstance(keystoreType);
        String keystorePwd = this.model.get(JavaKeystoreKeyProviderFactory.KEYSTORE_PASSWORD_KEY);
        // Fall back to the literal value when it is not a vault reference.
        keystorePwd = this.vault.getStringSecret(keystorePwd).get().orElse(keystorePwd);
        keyStore.load(inputStream, keystorePwd.toCharArray());
        return keyStore;
    }

    /**
     * Rejects the key when a "keyUse" is configured and differs from the use implied by the algorithm.
     *
     * @throws UnrecoverableKeyException on mismatch
     */
    private void checkUsage(KeyUse keyUse) throws GeneralSecurityException {
        String use = this.model.get("keyUse");
        if (use != null && !keyUse.name().equalsIgnoreCase(use)) {
            throw new UnrecoverableKeyException(String.format("Invalid use %s for algorithm %s.", use, this.algorithm));
        }
    }

    /**
     * Ensures the certificate is an X.509 certificate.
     *
     * @throws UnrecoverableKeyException when it is null or of another type
     */
    private X509Certificate checkCertificate(Certificate cert) throws GeneralSecurityException {
        if (cert instanceof X509Certificate) {
            return (X509Certificate) cert;
        }
        throw new UnrecoverableKeyException(String.format("Invalid %s certificate in the entry.", cert != null ? cert.getType() : null));
    }

    /**
     * Checks the configured use, then fetches the keystore entry for the alias using the configured
     * key password (resolved through the vault) and verifies its entry type.
     *
     * @throws UnrecoverableKeyException when the alias is missing or the entry has the wrong type
     */
    private <K extends KeyStore.Entry> K checkKeyEntry(KeyStore keyStore, String keyAlias, Class<K> clazz, KeyUse use) throws GeneralSecurityException {
        this.checkUsage(use);
        String keyPwd = this.model.get(JavaKeystoreKeyProviderFactory.KEY_PASSWORD_KEY);
        keyPwd = this.vault.getStringSecret(keyPwd).get().orElse(keyPwd);
        KeyStore.Entry keyEntry = keyStore.getEntry(keyAlias, new KeyStore.PasswordProtection(keyPwd.toCharArray()));
        if (keyEntry == null) {
            throw new UnrecoverableKeyException(String.format("Alias %s does not exists in the keystore.", keyAlias));
        }
        if (!clazz.isInstance(keyEntry)) {
            throw new UnrecoverableKeyException(String.format("Invalid %s key for alias %s. Key is not %s.", this.algorithm, keyAlias, clazz.getSimpleName()));
        }
        return clazz.cast(keyEntry);
    }

    /**
     * Verifies the key has the expected Java type and, when {@code javaAlgorithm} is non-null, the
     * expected algorithm name (case-insensitive).
     *
     * @throws NoSuchAlgorithmException on mismatch
     */
    private <K extends Key> K checkKey(Key key, String keyAlias, Class<K> clazz, String javaAlgorithm) throws GeneralSecurityException {
        if (!clazz.isInstance(key) || javaAlgorithm != null && !javaAlgorithm.equalsIgnoreCase(key.getAlgorithm())) {
            // key may be null here (isInstance(null) is false); avoid an NPE while building the message.
            String actual = key != null ? key.getAlgorithm() : null;
            throw new NoSuchAlgorithmException(String.format("Invalid %s key for alias %s. Algorithm is %s.", this.algorithm, keyAlias, actual));
        }
        return clazz.cast(key);
    }

    /** Loads a secret (symmetric) key and wraps it as an "OCT" key. */
    private KeyWrapper loadOctKey(KeyStore keyStore, String keyAlias, String javaAlgorithm, KeyUse keyUse) throws GeneralSecurityException {
        KeyStore.SecretKeyEntry secretKeyEntry = this.checkKeyEntry(keyStore, keyAlias, KeyStore.SecretKeyEntry.class, keyUse);
        SecretKey secretKey = this.checkKey(secretKeyEntry.getSecretKey(), keyAlias, SecretKey.class, javaAlgorithm);
        return this.createKeyWrapper(secretKey, keyUse);
    }

    /**
     * Loads an EdDSA private key plus its certificate and chain, and wraps it as an "OKP" key whose
     * curve is the key's named parameter set.
     *
     * @throws UnrecoverableKeyException when the curve is not supported by {@link JavaAlgorithm}
     */
    private KeyWrapper loadEdDSAKey(KeyStore keyStore, String keyAlias, KeyUse keyUse) throws GeneralSecurityException {
        KeyStore.PrivateKeyEntry privateKeyEntry = this.checkKeyEntry(keyStore, keyAlias, KeyStore.PrivateKeyEntry.class, keyUse);
        EdECPrivateKey privateKey = this.checkKey(privateKeyEntry.getPrivateKey(), keyAlias, EdECPrivateKey.class, null);
        X509Certificate x509Cert = this.checkCertificate(privateKeyEntry.getCertificate());
        String curve = privateKey.getParams().getName();
        try {
            // Used only as a validation step: this call throws for curves it does not support.
            JavaAlgorithm.getJavaAlgorithmForHash("EdDSA", curve);
        }
        catch (RuntimeException e) {
            UnrecoverableKeyException uke = new UnrecoverableKeyException(String.format("Invalid EdDSA curve for alias %s. Curve algorithm is %s.", keyAlias, curve));
            uke.initCause(e);
            throw uke;
        }
        PublicKey publicKey = x509Cert.getPublicKey();
        KeyPair keyPair = new KeyPair(publicKey, privateKey);
        return this.createKeyWrapper(keyPair, x509Cert, this.loadCertificateChain(privateKeyEntry), "OKP", keyUse, curve);
    }

    /** Loads an EC private key plus its certificate and chain, and wraps it as an "EC" key. */
    private KeyWrapper loadECKey(KeyStore keyStore, String keyAlias, KeyUse keyUse) throws GeneralSecurityException {
        KeyStore.PrivateKeyEntry privateKeyEntry = this.checkKeyEntry(keyStore, keyAlias, KeyStore.PrivateKeyEntry.class, keyUse);
        ECPrivateKey privateKey = this.checkKey(privateKeyEntry.getPrivateKey(), keyAlias, ECPrivateKey.class, null);
        X509Certificate x509Cert = this.checkCertificate(privateKeyEntry.getCertificate());
        PublicKey publicKey = x509Cert.getPublicKey();
        KeyPair keyPair = new KeyPair(publicKey, privateKey);
        return this.createKeyWrapper(keyPair, x509Cert, this.loadCertificateChain(privateKeyEntry), "EC", keyUse, null);
    }

    /** Loads an RSA (CRT) private key plus its certificate and chain, and wraps it as an "RSA" key. */
    private KeyWrapper loadRSAKey(KeyStore keyStore, String keyAlias, KeyUse keyUse) throws GeneralSecurityException {
        KeyStore.PrivateKeyEntry privateKeyEntry = this.checkKeyEntry(keyStore, keyAlias, KeyStore.PrivateKeyEntry.class, keyUse);
        RSAPrivateCrtKey privateKey = this.checkKey(privateKeyEntry.getPrivateKey(), keyAlias, RSAPrivateCrtKey.class, null);
        X509Certificate x509Cert = this.checkCertificate(privateKeyEntry.getCertificate());
        PublicKey publicKey = x509Cert.getPublicKey();
        KeyPair keyPair = new KeyPair(publicKey, privateKey);
        return this.createKeyWrapper(keyPair, x509Cert, this.loadCertificateChain(privateKeyEntry), "RSA", keyUse, null);
    }

    /**
     * Returns the entry's certificate chain as a mutable list of X.509 certificates (empty when the
     * entry has no chain), after validating it.
     *
     * @throws UnrecoverableKeyException when a chain element is not an X.509 certificate
     * @throws GeneralSecurityException  when the chain fails validation
     */
    private List<X509Certificate> loadCertificateChain(KeyStore.PrivateKeyEntry privateKeyEntry) throws GeneralSecurityException {
        Certificate[] certificates = privateKeyEntry.getCertificateChain();
        List<X509Certificate> chain = new ArrayList<>();
        if (certificates != null) {
            for (Certificate certificate : certificates) {
                chain.add(this.checkCertificate(certificate));
            }
        }
        this.validateCertificateChain(chain);
        return chain;
    }

    /**
     * Builds the wrapper for an asymmetric key. The kid is the configured "kid" when present,
     * otherwise derived from the public key. A non-empty chain is copied and the leaf certificate is
     * prepended when it is not already the first element; an empty chain leaves the wrapper's chain
     * unset.
     */
    private KeyWrapper createKeyWrapper(KeyPair keyPair, X509Certificate certificate, List<X509Certificate> certificateChain, String type, KeyUse keyUse, String curve) {
        KeyWrapper key = new KeyWrapper();
        key.setProviderId(this.model.getId());
        key.setProviderPriority(this.model.get("priority", 0L));
        String configuredKid = this.model.get("kid");
        key.setKid(configuredKid != null ? configuredKid : KeyUtils.createKeyId(keyPair.getPublic()));
        key.setUse(keyUse);
        key.setType(type);
        key.setAlgorithm(this.algorithm);
        key.setCurve(curve);
        key.setStatus(this.status);
        key.setPrivateKey(keyPair.getPrivate());
        key.setPublicKey(keyPair.getPublic());
        key.setCertificate(certificate);
        if (!certificateChain.isEmpty()) {
            // Copy so the caller's list is never mutated and is not required to be modifiable.
            List<X509Certificate> chain = new ArrayList<>(certificateChain);
            if (certificate != null && !certificate.equals(chain.get(0))) {
                chain.add(0, certificate);
            }
            key.setCertificateChain(chain);
        }
        return key;
    }

    /**
     * Builds the wrapper for a secret key as an "OCT" key. The kid is the configured "kid", or a
     * freshly generated id when none is configured.
     */
    private KeyWrapper createKeyWrapper(SecretKey secretKey, KeyUse use) {
        KeyWrapper keyWrapper = new KeyWrapper();
        keyWrapper.setProviderId(this.model.getId());
        keyWrapper.setProviderPriority(this.model.get("priority", 0L));
        keyWrapper.setKid(this.model.get("kid", KeycloakModelUtils.generateId()));
        keyWrapper.setUse(use);
        keyWrapper.setType("OCT");
        keyWrapper.setAlgorithm(this.algorithm);
        keyWrapper.setStatus(this.status);
        keyWrapper.setSecretKey(secretKey);
        return keyWrapper;
    }

    /**
     * Validates the chain with the default {@link CertPathValidator}, treating its last certificate as
     * the sole trust anchor and with revocation checking disabled. Null or empty chains are accepted.
     *
     * @throws GeneralSecurityException when the chain does not validate
     */
    private void validateCertificateChain(List<X509Certificate> certificates) throws GeneralSecurityException {
        if (certificates == null || certificates.isEmpty()) {
            return;
        }
        // The last certificate in the chain is considered the most trusted one.
        TrustAnchor anchor = new TrustAnchor(certificates.get(certificates.size() - 1), null);
        PKIXParameters params = new PKIXParameters(Collections.singleton(anchor));
        params.setRevocationEnabled(false);
        CertPath certPath = CertificateFactory.getInstance("X.509").generateCertPath(certificates);
        CertPathValidator validator = CertPathValidator.getInstance(CertPathValidator.getDefaultType());
        validator.validate(certPath, params);
    }

    /** Returns a stream containing the single key held by this provider. */
    @Override
    public Stream<KeyWrapper> getKeysStream() {
        return Stream.of(this.key);
    }
}

