/*
 * Decompiled with CFR 0.152.
 */
package org.apache.activemq.artemis.shaded.org.jgroups.protocols;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.security.InvalidKeyException;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.spec.X509EncodedKeySpec;
import java.util.Iterator;
import java.util.function.Supplier;
import javax.crypto.Cipher;
import javax.crypto.KeyAgreement;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import org.apache.activemq.artemis.shaded.org.jgroups.Address;
import org.apache.activemq.artemis.shaded.org.jgroups.EmptyMessage;
import org.apache.activemq.artemis.shaded.org.jgroups.Event;
import org.apache.activemq.artemis.shaded.org.jgroups.Header;
import org.apache.activemq.artemis.shaded.org.jgroups.Message;
import org.apache.activemq.artemis.shaded.org.jgroups.annotations.MBean;
import org.apache.activemq.artemis.shaded.org.jgroups.annotations.Property;
import org.apache.activemq.artemis.shaded.org.jgroups.conf.AttributeType;
import org.apache.activemq.artemis.shaded.org.jgroups.protocols.ASYM_ENCRYPT;
import org.apache.activemq.artemis.shaded.org.jgroups.protocols.KeyExchange;
import org.apache.activemq.artemis.shaded.org.jgroups.util.MessageBatch;
import org.apache.activemq.artemis.shaded.org.jgroups.util.Tuple;
import org.apache.activemq.artemis.shaded.org.jgroups.util.Util;

/**
 * Key exchange protocol that fetches the shared secret group key from another member using
 * Diffie-Hellman (DH) key agreement.
 * <p>
 * Requesting side ({@link #fetchSecretKeyFrom(Address)}, {@link #handleSecretKeyResponse}):
 * <ol>
 *   <li>An ephemeral DH key pair is generated.</li>
 *   <li>The public key is sent to the target in a {@link Type#SECRET_KEY_REQ} header.</li>
 *   <li>When the reply arrives, the agreed DH secret is hashed (SHA-256) and used to decrypt the
 *       secret key.</li>
 *   <li>The decrypted key is sent up the stack together with its version.</li>
 * </ol>
 * <p>
 * Responding side ({@link #handleSecretKeyRequest}):
 * <ol>
 *   <li>A fresh DH key pair is generated for every request.</li>
 *   <li>The agreed secret is hashed and used to encrypt the secret key obtained from the protocol
 *       above.</li>
 *   <li>The result is returned in a {@link Type#SECRET_KEY_RSP} header.</li>
 * </ol>
 */
@MBean(description="Key exchange protocol to fetch a shared secret group key from the key server.That shared (symmetric) key is subsequently used to encrypt communication between cluster members")
public class DH_KEY_EXCHANGE
extends KeyExchange {
    @Property(description="The type of secret key to be sent up the stack (converted from DH). Should be the same as the algorithm part of ASYM_ENCRYPT.sym_algorithm if ASYM_ENCRYPT is used")
    protected String secret_key_algorithm = "AES";
    @Property(description="The length of the secret key (in bits) to be sent up the stack. AES requires 128 bits. Should be the same as ASYM_ENCRYPT.sym_keylength if ASYM_ENCRYPT is used.")
    protected int secret_key_length = 128;
    @Property(description="Max time (in ms) that a FETCH_SECRET_KEY down event will be ignored (if an existing request is in progress) until a new request for the secret key is sent to the keyserver", type=AttributeType.TIME)
    protected long timeout = 2000L;
    /** DH key agreement used by the requesting side; initialized with the private key of the pending request. */
    protected KeyAgreement key_agreement;
    /** Public DH key of the pending secret-key request, or null when no request is outstanding. */
    protected PublicKey dh_key;
    /** Time (ms since epoch) at which the pending request was last sent; used to throttle re-sends. */
    protected long last_key_request;
    protected static final KeyPairGenerator key_pair_gen;
    protected static final KeyFactory dh_key_factory;
    /** The session key is cut from a SHA-256 digest, so it can be at most this many bits long. */
    private static final int MAX_SECRET_KEY_BITS = 256;

    /**
     * Adopts the symmetric algorithm and key length of an {@link ASYM_ENCRYPT} above this protocol
     * (if present), validates the effective key length and creates the DH key agreement.
     *
     * @throws IllegalStateException if the effective secret_key_length is not a positive multiple
     *                               of 8 no larger than 256 bits
     */
    @Override
    public void init() throws Exception {
        super.init();
        ASYM_ENCRYPT asym_encrypt = (ASYM_ENCRYPT)findProtocolAbove(ASYM_ENCRYPT.class);
        if (asym_encrypt != null) {
            String sym_alg = asym_encrypt.symKeyAlgorithm();
            int sym_keylen = asym_encrypt.symKeylength();
            if (!Util.match(sym_alg, secret_key_algorithm)) {
                log.warn("overriding %s=%s to %s from %s", "secret_key_algorithm", secret_key_algorithm, sym_alg, ASYM_ENCRYPT.class.getSimpleName());
                secret_key_algorithm = sym_alg;
            }
            if (sym_keylen != secret_key_length) {
                log.warn("overriding %s=%d to %d from %s", "secret_key_length", secret_key_length, sym_keylen, ASYM_ENCRYPT.class.getSimpleName());
                secret_key_length = sym_keylen;
            }
        }
        // Validate the *effective* length (after a possible override above). hash() slices
        // secret_key_length/8 bytes out of a 32-byte SHA-256 digest, so 0 or > 256 bits would fail
        // there on every key exchange instead of failing fast here.
        if (secret_key_length % 8 != 0) {
            throw new IllegalStateException(String.format("secret_key_length (%d) must be a multiple of 8", secret_key_length));
        }
        if (secret_key_length <= 0 || secret_key_length > MAX_SECRET_KEY_BITS) {
            throw new IllegalStateException(String.format("secret_key_length (%d) must be between 8 and %d bits", secret_key_length, MAX_SECRET_KEY_BITS));
        }
        key_agreement = KeyAgreement.getInstance("DH");
    }

    /**
     * Sends a secret-key request carrying this member's public DH key to {@code target}.
     * <p>
     * If a request is already outstanding, it is re-sent only when at least {@code timeout} ms
     * have passed since it was last sent; otherwise the call is a no-op.
     *
     * @param target the member to fetch the secret key from
     */
    @Override
    public void fetchSecretKeyFrom(Address target) throws NoSuchAlgorithmException, InvalidKeyException {
        byte[] encoded_dh_key = null;
        synchronized (this) {
            long now = System.currentTimeMillis();
            if (dh_key != null) {
                // a request is in progress: only re-send it after 'timeout' ms
                if (now - last_key_request >= timeout) {
                    last_key_request = now;
                    encoded_dh_key = dh_key.getEncoded();
                }
            } else {
                KeyPair kp = key_pair_gen.generateKeyPair();
                // init the agreement before publishing dh_key, so a failing init() does not leave
                // a "pending" request behind whose private key was never installed
                key_agreement.init(kp.getPrivate());
                dh_key = kp.getPublic();
                encoded_dh_key = dh_key.getEncoded();
                last_key_request = now; // the initial send starts the re-send window
                log.debug("%s: sending public key %s.. to %s", local_addr, print16(dh_key), target);
            }
        }
        if (encoded_dh_key != null) {
            Message msg = new EmptyMessage(target).putHeader(id, DhHeader.createSecretKeyRequest(encoded_dh_key));
            down_prot.down(msg);
        }
    }

    /** Always returns null; this implementation does not itself track a key-server address. */
    @Override
    public Address getServerLocation() {
        return null;
    }

    /** Consumes messages carrying a {@link DhHeader}; all other messages are passed up unchanged. */
    @Override
    public Object up(Message msg) {
        DhHeader hdr = msg.getHeader(id);
        if (hdr != null) {
            handle(hdr, msg.getSrc());
            return null;
        }
        return up_prot.up(msg);
    }

    /** Removes and handles messages carrying a {@link DhHeader}; passes the rest of the batch up. */
    @Override
    public void up(MessageBatch batch) {
        Iterator<Message> it = batch.iterator();
        while (it.hasNext()) {
            Message msg = it.next();
            DhHeader hdr = msg.getHeader(id);
            if (hdr == null) {
                continue;
            }
            it.remove();
            handle(hdr, msg.getSrc());
        }
        if (!batch.isEmpty()) {
            up_prot.up(batch);
        }
    }

    /**
     * Decodes the peer's public DH key and dispatches on the header type.
     * Any failure is logged rather than propagated.
     */
    protected void handle(DhHeader hdr, Address sender) {
        try {
            PublicKey pub_key = dh_key_factory.generatePublic(new X509EncodedKeySpec(hdr.dh_key));
            switch (hdr.type) {
                case SECRET_KEY_REQ:
                    handleSecretKeyRequest(pub_key, sender);
                    break;
                case SECRET_KEY_RSP:
                    handleSecretKeyResponse(pub_key, hdr.encrypted_secret_key, hdr.secret_key_version, sender);
                    break;
                default:
                    log.warn("unknown header type %s", hdr.type); // %s: the argument is an enum, not a number
                    break;
            }
        } catch (Throwable t) {
            log.error(String.format("failed handling request %s", hdr), t);
        }
    }

    /**
     * Responds to a secret-key request.
     * <ol>
     *   <li>Generates a fresh DH key pair.</li>
     *   <li>Derives a session key from it and the requester's public key.</li>
     *   <li>Encrypts the current secret key (fetched from the protocol above) with that session
     *       key.</li>
     *   <li>Sends the result, together with its own public key and the key version, back to
     *       {@code sender}.</li>
     * </ol>
     *
     * @param dh_public_key the requester's public DH key
     * @param sender        the requester
     * @throws IllegalStateException if no secret key is available from the protocol above
     */
    protected void handleSecretKeyRequest(PublicKey dh_public_key, Address sender) throws Exception {
        KeyPair kp = key_pair_gen.generateKeyPair();
        PublicKey public_key_rsp = kp.getPublic(); // sent back to the requester
        log.debug("%s: received public key %s.. from %s", local_addr, print16(dh_public_key), sender);

        // Use a KeyAgreement local to this request: the shared 'key_agreement' field holds the
        // private key of this member's own pending request (if any). Re-initializing it here would
        // make a subsequent handleSecretKeyResponse() derive the wrong session key.
        KeyAgreement agreement = KeyAgreement.getInstance("DH");
        agreement.init(kp.getPrivate());
        agreement.doPhase(dh_public_key, true);
        SecretKey hashed_session_key = hash(agreement.generateSecret());
        Cipher encrypter = Cipher.getInstance(secret_key_algorithm);
        encrypter.init(Cipher.ENCRYPT_MODE, hashed_session_key);

        // 111: presumably Event.GET_SECRET_KEY (returns Tuple<secret key, version>) - confirm
        @SuppressWarnings("unchecked")
        Tuple<SecretKey, byte[]> tuple = (Tuple<SecretKey, byte[]>)up_prot.up(new Event(111));
        if (tuple == null || tuple.getVal1() == null) {
            throw new IllegalStateException(String.format("%s: no secret key available to send to %s", local_addr, sender));
        }
        byte[] version = tuple.getVal2();
        byte[] encrypted_secret_key = encrypter.doFinal(tuple.getVal1().getEncoded());

        log.debug("%s: sending public key rsp %s.. to %s", local_addr, print16(public_key_rsp), sender);
        Message rsp = new EmptyMessage(sender).putHeader(id, DhHeader.createSecretKeyResponse(public_key_rsp.getEncoded(), encrypted_secret_key, version));
        down_prot.down(rsp);
    }

    /**
     * Completes this member's pending request.
     * <ol>
     *   <li>Derives the session key from the responder's public key.</li>
     *   <li>Decrypts the secret key with it.</li>
     *   <li>Clears the pending request.</li>
     *   <li>Sends the key and its version up the stack.</li>
     * </ol>
     *
     * @param dh_public_key        the responder's public DH key
     * @param encrypted_secret_key the secret key, encrypted with the derived session key
     * @param version              the secret key's version (passed through unchanged)
     * @param sender               the responder
     */
    protected void handleSecretKeyResponse(PublicKey dh_public_key, byte[] encrypted_secret_key, byte[] version, Address sender) throws Exception {
        Tuple<SecretKeySpec, byte[]> tuple;
        log.debug("%s: received public key rsp %s.. from %s", local_addr, print16(dh_public_key), sender);
        synchronized (this) {
            key_agreement.doPhase(dh_public_key, true);
            SecretKey hashed_session_key = hash(key_agreement.generateSecret());
            Cipher decrypter = Cipher.getInstance(secret_key_algorithm);
            decrypter.init(Cipher.DECRYPT_MODE, hashed_session_key);
            byte[] secret_key = decrypter.doFinal(encrypted_secret_key);
            tuple = new Tuple<>(new SecretKeySpec(secret_key, secret_key_algorithm), version);
            dh_key = null; // request completed
        }
        log.debug("%s: sending up secret key (version: %s)", local_addr, Util.byteArrayToHexString(version));
        // 112: presumably Event.SET_SECRET_KEY - confirm
        up_prot.up(new Event(112, tuple));
    }

    /**
     * Hashes {@code key} with SHA-256 and returns the first {@code secret_key_length / 8} bytes of
     * the digest as a key for {@code secret_key_algorithm}.
     */
    protected SecretKey hash(byte[] key) throws Exception {
        MessageDigest digest = MessageDigest.getInstance("SHA-256");
        digest.update(key);
        byte[] hashed_key = digest.digest();
        return new SecretKeySpec(hashed_key, 0, secret_key_length / 8, secret_key_algorithm);
    }

    /** Returns the first 16 bytes of the SHA-256 hash of the key's encoding, as hex (for logging). */
    protected static String print16(PublicKey pub_key) {
        try {
            MessageDigest digest = MessageDigest.getInstance("SHA-256");
            digest.update(pub_key.getEncoded());
            return Util.byteArrayToHexString(digest.digest(), 0, 16);
        } catch (NoSuchAlgorithmException e) {
            return e.toString();
        }
    }

    static {
        try {
            key_pair_gen = KeyPairGenerator.getInstance("DH");
            dh_key_factory = KeyFactory.getInstance("DH");
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Header carrying a DH public key.
     * For {@link Type#SECRET_KEY_RSP} it also carries the encrypted secret key and its version.
     */
    public static class DhHeader
    extends Header {
        protected Type type;
        protected byte[] dh_key;
        protected byte[] encrypted_secret_key;
        protected byte[] secret_key_version;

        public static DhHeader createSecretKeyRequest(byte[] dh_key) {
            DhHeader hdr = new DhHeader();
            hdr.type = Type.SECRET_KEY_REQ;
            hdr.dh_key = dh_key;
            return hdr;
        }

        public static DhHeader createSecretKeyResponse(byte[] dh_pub_key, byte[] encrypted_secret_key, byte[] version) {
            DhHeader hdr = new DhHeader();
            hdr.type = Type.SECRET_KEY_RSP;
            hdr.dh_key = dh_pub_key;
            hdr.encrypted_secret_key = encrypted_secret_key;
            hdr.secret_key_version = version;
            return hdr;
        }

        @Override
        public Supplier<? extends Header> create() {
            return DhHeader::new;
        }

        @Override
        public short getMagicId() {
            return 92;
        }

        public byte[] dhKey() {
            return dh_key;
        }

        public byte[] encryptedSecret() {
            return encrypted_secret_key;
        }

        public byte[] version() {
            return secret_key_version;
        }

        /** Length of {@code buf}, treating null as empty. */
        private static int length(byte[] buf) {
            return buf != null ? buf.length : 0;
        }

        /** Writes an int length prefix followed by the bytes (nothing more when {@code buf} is null). */
        private static void writeArray(DataOutput out, byte[] buf) throws IOException {
            out.writeInt(length(buf));
            if (buf != null) {
                out.write(buf);
            }
        }

        /** Reads an int length prefix and that many bytes; a length &lt;= 0 yields null. */
        private static byte[] readArray(DataInput in) throws IOException {
            int size = in.readInt();
            if (size <= 0) {
                return null;
            }
            byte[] buf = new byte[size];
            in.readFully(buf);
            return buf;
        }

        @Override
        public int serializedSize() {
            switch (type) {
                case SECRET_KEY_REQ:
                    // type byte + length-prefixed dh_key
                    return Byte.BYTES + Integer.BYTES + length(dh_key);
                case SECRET_KEY_RSP:
                    // type byte + three length-prefixed arrays
                    return Byte.BYTES + 3 * Integer.BYTES + length(dh_key) + length(encrypted_secret_key) + length(secret_key_version);
                default:
                    return 0;
            }
        }

        @Override
        public void writeTo(DataOutput out) throws IOException {
            // the enum ordinal is the wire representation of the type
            out.writeByte(type.ordinal());
            switch (type) {
                case SECRET_KEY_REQ:
                    writeArray(out, dh_key);
                    break;
                case SECRET_KEY_RSP:
                    writeArray(out, dh_key);
                    writeArray(out, encrypted_secret_key);
                    writeArray(out, secret_key_version);
                    break;
                default:
                    break;
            }
        }

        /**
         * Reads a header written by {@link #writeTo}.
         *
         * @throws IOException if the type byte is not a valid {@link Type} ordinal
         */
        @Override
        public void readFrom(DataInput in) throws IOException {
            byte ordinal = in.readByte();
            Type[] types = Type.values();
            // the type byte comes off the network: reject out-of-range values instead of
            // failing with ArrayIndexOutOfBoundsException
            if (ordinal < 0 || ordinal >= types.length) {
                throw new IOException(String.format("invalid DhHeader type ordinal: %d", ordinal));
            }
            type = types[ordinal];
            dh_key = readArray(in);
            if (type == Type.SECRET_KEY_RSP) {
                encrypted_secret_key = readArray(in);
                secret_key_version = readArray(in);
            }
        }

        /** Null-safe description; absent arrays are reported as 0 bytes. */
        @Override
        public String toString() {
            if (type == null) {
                return "n/a";
            }
            switch (type) {
                case SECRET_KEY_REQ:
                    return String.format("%s dh-key %d bytes", type, length(dh_key));
                case SECRET_KEY_RSP:
                    return String.format("%s dh-key %d bytes, encrypted secret %d bytes, version: %s", type, length(dh_key), length(encrypted_secret_key),
                                         secret_key_version != null ? Util.byteArrayToHexString(secret_key_version) : "n/a");
                default:
                    return "n/a";
            }
        }
    }

    /** Header types; the ordinal is written on the wire, so the declaration order must not change. */
    protected enum Type {
        SECRET_KEY_REQ,
        SECRET_KEY_RSP
    }
}

