001package com.nimbusds.openid.connect.sdk.id;
002
003
import java.security.GeneralSecurityException;
import java.util.regex.Pattern;

import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.spec.IvParameterSpec;

import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.oauth2.sdk.id.Subject;
import net.jcip.annotations.ThreadSafe;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
013
014
015/**
016 * AES/CBC/PKCS5Padding based encoder / decoder of pairwise subject
017 * identifiers. The salt is used as the IV. Reversal is supported.
018 *
019 * <p>The plain text is formatted as follows ('\' as delimiter):
020 *
021 * <pre>
022 * sector_id|local_sub
023 * </pre>
024 *
025 * <p>Related specifications:
026 *
027 * <ul>
028 *     <li>OpenID Connect Core 1.0, section 8.1.
029 * </ul>
030 */
031@ThreadSafe
032public class AESBasedPairwiseSubjectCodec extends PairwiseSubjectCodec {
033
034
035        /**
036         * The AES key.
037         */
038        private final SecretKey aesKey;
039
040
041        /**
042         * Creates a new AES-based codec for pairwise subject identifiers.
043         *
044         * @param aesKey The AES key. Must not be {@code null}.
045         * @param salt   The salt. Must not be {@code null}.
046         */
047        public AESBasedPairwiseSubjectCodec(final SecretKey aesKey, final byte[] salt) {
048                super(salt);
049                if (salt == null) {
050                        throw new IllegalArgumentException("The salt must not be null");
051                }
052                if (aesKey == null) {
053                        throw new IllegalArgumentException("The AES key must not be null");
054                }
055                this.aesKey = aesKey;
056        }
057
058
059        /**
060         * Returns the AES key.
061         *
062         * @return The key.
063         */
064        public SecretKey getAESKey() {
065                return aesKey;
066        }
067
068
069        /**
070         * Creates a new AES/CBC/PKCS5Padding cipher using the configured
071         * JCE provider and salt.
072         *
073         * @param mode The cipher mode.
074         *
075         * @return The cipher.
076         */
077        private Cipher createCipher(final int mode) {
078
079                Cipher aesCipher;
080
081                try {
082                        if (getProvider() != null) {
083                                aesCipher = Cipher.getInstance("AES/CBC/PKCS5Padding", getProvider());
084                        } else {
085                                aesCipher = Cipher.getInstance("AES/CBC/PKCS5Padding");
086                        }
087
088                        aesCipher.init(mode, aesKey, new IvParameterSpec(getSalt()));
089                } catch (Exception e) {
090                        throw new RuntimeException(e);
091                }
092
093                return aesCipher;
094        }
095
096
097        @Override
098        public Subject encode(final SectorID sectorID, final Subject localSub) {
099
100                // Join parameters, delimited by '\'
101                byte[] plainText = (sectorID.getValue().replace("|", "\\|") + '|' + localSub.getValue().replace("|", "\\|")).getBytes(CHARSET);
102                byte[] cipherText;
103                try {
104                        cipherText = createCipher(Cipher.ENCRYPT_MODE).doFinal(plainText);
105                } catch (Exception e) {
106                        throw new RuntimeException(e);
107                }
108
109                return new Subject(Base64URL.encode(cipherText).toString());
110        }
111
112
113        @Override
114        public Pair<SectorID, Subject> decode(final Subject pairwiseSubject)
115                throws InvalidPairwiseSubjectException {
116
117                byte[] cipherText = new Base64URL(pairwiseSubject.getValue()).decode();
118
119                Cipher aesCipher = createCipher(Cipher.DECRYPT_MODE);
120
121                byte[] plainText;
122                try {
123                        plainText = aesCipher.doFinal(cipherText);
124                } catch (Exception e) {
125                        throw new InvalidPairwiseSubjectException("Decryption failed: " + e.getMessage(), e);
126                }
127
128                String parts[] = new String(plainText, CHARSET).split("(?<!\\\\)\\|");
129
130                // Unescape delimiter
131                for (int i=0; i<parts.length; i++) {
132                        parts[i] = parts[i].replace("\\|", "|");
133                }
134
135                // Check format
136                if (parts.length != 2) {
137                        throw new InvalidPairwiseSubjectException("Invalid format: Unexpected number of tokens: " + parts.length);
138                }
139
140                return new ImmutablePair<>(new SectorID(parts[0]), new Subject(parts[1]));
141        }
142}