001/*
002 * nimbus-jose-jwt
003 *
004 * Copyright 2012-2021, Connect2id Ltd and contributors.
005 *
006 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use
007 * this file except in compliance with the License. You may obtain a copy of the
008 * License at
009 *
010 *    http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software distributed
013 * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
014 * CONDITIONS OF ANY KIND, either express or implied. See the License for the
015 * specific language governing permissions and limitations under the License.
016 */
017
018package com.nimbusds.jose.crypto.impl;
019
020
import com.nimbusds.jose.*;
import com.nimbusds.jose.crypto.utils.ECChecks;
import com.nimbusds.jose.jwk.Curve;
import com.nimbusds.jose.jwk.OctetKeyPair;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jose.util.ByteUtils;

import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Set;
034
035
036/**
037 * Elliptic Curve Diffie-Hellman One-Pass Unified Model (ECDH-1PU) key
038 * agreement functions and utilities.
039 *
040 * @see <a href="https://datatracker.ietf.org/doc/html/draft-madden-jose-ecdh-1pu-04">Public
041 * Key Authenticated Encryption for JOSE: ECDH-1PU</a>
042 *
043 * @author Alexander Martynov
044 * @version 2021-08-03
045 */
046public class ECDH1PU {
047
048    /**
049     * Resolves the ECDH algorithm mode.
050     *
051     * @param alg The JWE algorithm. Must be supported and not {@code null}.
052     *
053     * @return The algorithm mode.
054     *
055     * @throws JOSEException If the JWE algorithm is not supported.
056     */
057    public static ECDH.AlgorithmMode resolveAlgorithmMode(final JWEAlgorithm alg)
058        throws JOSEException {
059
060        Objects.requireNonNull(alg, "The parameter \"alg\" must not be null");
061
062        if (alg.equals(JWEAlgorithm.ECDH_1PU)) {
063
064            return ECDH.AlgorithmMode.DIRECT;
065        }
066
067        if (alg.equals(JWEAlgorithm.ECDH_1PU_A128KW) ||
068                alg.equals(JWEAlgorithm.ECDH_1PU_A192KW) ||
069                alg.equals(JWEAlgorithm.ECDH_1PU_A256KW)
070        ) {
071
072            return ECDH.AlgorithmMode.KW;
073        }
074
075        throw new JOSEException(AlgorithmSupportMessage.unsupportedJWEAlgorithm(
076                alg,
077                ECDHCryptoProvider.SUPPORTED_ALGORITHMS));
078    }
079
080
081    /**
082     * Returns the bit length of the shared key (derived via concat KDF)
083     * for the specified JWE ECDH algorithm.
084     *
085     * @param alg The JWE ECDH algorithm. Must be supported and not
086     *            {@code null}.
087     * @param enc The encryption method. Must be supported and not
088     *            {@code null}.
089     *
090     * @return The bit length of the shared key.
091     *
092     * @throws JOSEException If the JWE algorithm or encryption method is
093     *                       not supported.
094     */
095    public static int sharedKeyLength(final JWEAlgorithm alg, final EncryptionMethod enc)
096        throws JOSEException {
097
098        Objects.requireNonNull(alg, "The parameter \"alg\" must not be null");
099        Objects.requireNonNull(enc, "The parameter \"enc\" must not be null");
100
101        if (alg.equals(JWEAlgorithm.ECDH_1PU)) {
102
103            int length = enc.cekBitLength();
104
105            if (length == 0) {
106                throw new JOSEException("Unsupported JWE encryption method " + enc);
107            }
108
109            return length;
110        }
111
112        if (alg.equals(JWEAlgorithm.ECDH_1PU_A128KW)) {
113            return 128;
114        }
115
116        if (alg.equals(JWEAlgorithm.ECDH_1PU_A192KW)) {
117            return  192;
118        }
119
120        if (alg.equals(JWEAlgorithm.ECDH_1PU_A256KW)) {
121            return  256;
122        }
123
124        throw new JOSEException(AlgorithmSupportMessage.unsupportedJWEAlgorithm(
125                alg, ECDHCryptoProvider.SUPPORTED_ALGORITHMS));
126    }
127
128    /**
129     * Derives a shared key (via concat KDF).
130     *
131     * The method should only be called in the
132     * {@link ECDH.AlgorithmMode#DIRECT} mode.
133     *
134     * The method derives the Content Encryption Key (CEK) for the "enc"
135     * algorithm, in the {@link ECDH.AlgorithmMode#DIRECT} mode.
136     *
137     * The method does not take the auth tag because the auth tag will be
138     * generated using a CEK derived as an output of this method.
139     *
140     * @param header    The JWE header. Its algorithm and encryption method
141     *                  must be supported. Must not be {@code null}.
142     * @param Z         The derived shared secret ('Z'). Must not be
143     *                  {@code null}.
144     * @param concatKDF The concat KDF. Must be initialised and not
145     *                  {@code null}.
146     *
147     * @return The derived shared key.
148     *
149     * @throws JOSEException If derivation of the shared key failed.
150     */
151    public static SecretKey deriveSharedKey(final JWEHeader header,
152                                            final SecretKey Z,
153                                            final ConcatKDF concatKDF)
154            throws JOSEException {
155
156        Objects.requireNonNull(header, "The parameter \"header\" must not be null");
157        Objects.requireNonNull(Z, "The parameter \"Z\" must not be null");
158        Objects.requireNonNull(concatKDF, "The parameter \"concatKDF\" must not be null");
159
160        final int sharedKeyLength = sharedKeyLength(header.getAlgorithm(), header.getEncryptionMethod());
161
162        // Set the alg ID for the concat KDF
163        ECDH.AlgorithmMode algMode = resolveAlgorithmMode(header.getAlgorithm());
164
165        final String algID;
166
167        if (algMode == ECDH.AlgorithmMode.DIRECT) {
168            // algID = enc
169            algID = header.getEncryptionMethod().getName();
170        } else if (algMode == ECDH.AlgorithmMode.KW) {
171            // algID = alg
172            algID = header.getAlgorithm().getName();
173        } else {
174            throw new JOSEException("Unsupported JWE ECDH algorithm mode: " + algMode);
175        }
176
177        return concatKDF.deriveKey(
178                Z,
179                sharedKeyLength,
180                ConcatKDF.encodeDataWithLength(algID.getBytes(StandardCharsets.US_ASCII)),
181                ConcatKDF.encodeDataWithLength(header.getAgreementPartyUInfo()),
182                ConcatKDF.encodeDataWithLength(header.getAgreementPartyVInfo()),
183                ConcatKDF.encodeIntData(sharedKeyLength),
184                ConcatKDF.encodeNoData()
185        );
186    }
187
188    /**
189     * Derives a shared key (via concat KDF).
190     *
191     * The method should only be called in {@link ECDH.AlgorithmMode#KW}.
192     *
193     * In Key Agreement with {@link ECDH.AlgorithmMode#KW} mode, the JWE
194     * Authentication Tag is included in the input to the KDF. This ensures
195     * that the content of the JWE was produced by the original sender and not
196     * by another recipient.
197     *
198     *
199     * @param header    The JWE header. Its algorithm and encryption method
200     *                  must be supported. Must not be {@code null}.
201     * @param Z         The derived shared secret ('Z'). Must not be
202     *                  {@code null}.
203     * @param tag       In Direct Key Agreement mode this is set to an empty
204     *                  octet string. In Key Agreement with Key Wrapping mode,
205     *                  this is set to a value of the form Data, where Data is
206     *                  the raw octets of the JWE Authentication Tag.
207     * @param concatKDF The concat KDF. Must be initialised and not
208     *                  {@code null}.
209     *
210     * @return The derived shared key.
211     *
212     * @throws JOSEException If derivation of the shared key failed.
213     */
214    public static SecretKey deriveSharedKey(final JWEHeader header,
215                        final SecretKey Z,
216                        final Base64URL tag,
217                        final ConcatKDF concatKDF)
218        throws JOSEException {
219
220        Objects.requireNonNull(header, "The parameter \"header\" must not be null");
221        Objects.requireNonNull(Z, "The parameter \"Z\" must not be null");
222        Objects.requireNonNull(tag, "The parameter \"tag\" must not be null");
223        Objects.requireNonNull(concatKDF, "The parameter \"concatKDF\" must not be null");
224
225        final int sharedKeyLength = sharedKeyLength(header.getAlgorithm(), header.getEncryptionMethod());
226
227        // Set the alg ID for the concat KDF
228        ECDH.AlgorithmMode algMode = resolveAlgorithmMode(header.getAlgorithm());
229
230        final String algID;
231
232        if (algMode == ECDH.AlgorithmMode.DIRECT) {
233            // algID = enc
234            algID = header.getEncryptionMethod().getName();
235        } else if (algMode == ECDH.AlgorithmMode.KW) {
236            // algID = alg
237            algID = header.getAlgorithm().getName();
238        } else {
239            throw new JOSEException("Unsupported JWE ECDH algorithm mode: " + algMode);
240        }
241
242        return concatKDF.deriveKey(
243            Z,
244            sharedKeyLength,
245            ConcatKDF.encodeDataWithLength(algID.getBytes(StandardCharsets.US_ASCII)),
246            ConcatKDF.encodeDataWithLength(header.getAgreementPartyUInfo()),
247            ConcatKDF.encodeDataWithLength(header.getAgreementPartyVInfo()),
248            ConcatKDF.encodeIntData(sharedKeyLength),
249            ConcatKDF.encodeNoData(),
250            ConcatKDF.encodeDataWithLength(tag)
251        );
252    }
253
254    /**
255     * Derives a shared secret (also called 'Z') where Z is the concatenation
256     * of Ze and Zs.
257     *
258     * @param Ze The shared secret derived from applying the ECDH primitive to
259     *           the sender's ephemeral private key and the recipient's static
260     *           public key (when sending) or the recipient's static private
261     *           key and the sender's ephemeral public key (when receiving).
262     *           Must not be {@code null}.
263     * @param Zs The shared secret derived from applying the ECDH primitive to
264     *           the sender's static private key and the recipient's static
265     *           public key (when sending) or the recipient's static private
266     *           key and the sender's static public key (when receiving). Must
267     *           not be {@code null}.
268     *
269     * @return The derived shared key.
270     */
271    public static SecretKey deriveZ(final SecretKey Ze, final SecretKey Zs) {
272        Objects.requireNonNull(Ze, "The parameter \"Ze\" must not be null");
273        Objects.requireNonNull(Zs, "The parameter \"Zs\" must not be null");
274
275        byte[] encodedKey = ByteUtils.concat(Ze.getEncoded(), Zs.getEncoded());
276        return new SecretKeySpec(encodedKey, 0, encodedKey.length, "AES");
277    }
278
279
280    /**
281     * Ensures the private key and public key are from the same curve.
282     *
283     * @param privateKey EC private key. Must not be {@code null}.
284     * @param publicKey  EC public key. Must not be {@code null}.
285     *
286     * @throws JOSEException If the key curves don't match.
287     */
288    public static void validateSameCurve(final ECPrivateKey privateKey, final ECPublicKey publicKey)
289            throws JOSEException{
290        
291        Objects.requireNonNull(privateKey, "The parameter \"privateKey\" must not be null");
292        Objects.requireNonNull(publicKey, "The parameter \"publicKey\" must not be null");
293
294        if (!privateKey.getParams().getCurve().equals(publicKey.getParams().getCurve())) {
295            throw new JOSEException("Curve of public key does not match curve of private key");
296        }
297
298        if (!ECChecks.isPointOnCurve(publicKey, privateKey)) {
299            throw new JOSEException("Invalid public EC key: Point(s) not on the expected curve");
300        }
301    }
302
303    /**
304     * Ensures the private key and public key are from the same curve.
305     *
306     * @param privateKey OKP private key. Must not be {@code null}.
307     * @param publicKey  OKP public key. Must not be {@code null}.
308     *
309     * @throws JOSEException If the curves don't match.
310     */
311    public static void validateSameCurve(final OctetKeyPair privateKey, final OctetKeyPair publicKey)
312            throws JOSEException {
313        
314        Objects.requireNonNull(privateKey, "The parameter \"privateKey\" must not be null");
315        Objects.requireNonNull(publicKey, "The parameter \"publicKey\" must not be null");
316
317        if (!privateKey.isPrivate()) {
318            throw new JOSEException("OKP private key should be a private key");
319        }
320
321        if (publicKey.isPrivate()) {
322            throw new JOSEException("OKP public key should not be a private key");
323        }
324
325        if (!publicKey.getCurve().equals(Curve.X25519)) {
326            throw new JOSEException("Only supports OctetKeyPairs with crv=X25519");
327        }
328
329        if (!privateKey.getCurve().equals(publicKey.getCurve())) {
330            throw new JOSEException("Curve of public key does not match curve of private key");
331        }
332    }
333
334    /**
335     * Prevents public instantiation.
336     */
337    private ECDH1PU() {
338
339    }
340}