1use crate::error::*;
23use core::marker::PhantomData;
24
25pub use ec::{Curve, EcKey};
26use nss::{ec, ecdh};
27
28pub type EphemeralKeyPair = KeyPair<Ephemeral>;
29
30#[derive(PartialEq, Eq)]
32pub struct Algorithm {
33 pub(crate) curve_id: ec::Curve,
34}
35
36pub static ECDH_P256: Algorithm = Algorithm {
37 curve_id: ec::Curve::P256,
38};
39
40pub static ECDH_P384: Algorithm = Algorithm {
41 curve_id: ec::Curve::P384,
42};
43
/// Type-level tag describing how a private key may be used.
///
/// Implementors are zero-sized markers; they only select which methods are available on
/// `PrivateKey<U>` and `KeyPair<U>`.
pub trait Lifetime {}

/// Marker for single-use keys: the only agreement method available consumes the key.
pub struct Ephemeral {}

impl Lifetime for Ephemeral {}

/// Marker for long-lived keys: they can be reused for agreement, imported and exported.
pub struct Static {}

impl Lifetime for Static {}
54
55pub struct KeyPair<U: Lifetime> {
57 private_key: PrivateKey<U>,
58 public_key: PublicKey,
59}
60
61impl<U: Lifetime> KeyPair<U> {
62 pub fn generate(alg: &'static Algorithm) -> Result<Self> {
64 let (prv_key, pub_key) = ec::generate_keypair(alg.curve_id)?;
65 Ok(Self {
66 private_key: PrivateKey {
67 alg,
68 wrapped: prv_key,
69 usage: PhantomData,
70 },
71 public_key: PublicKey {
72 alg,
73 wrapped: pub_key,
74 },
75 })
76 }
77
78 pub fn from_private_key(private_key: PrivateKey<U>) -> Result<Self> {
79 let public_key = private_key
80 .compute_public_key()
81 .map_err(|_| ErrorKind::InternalError)?;
82 Ok(Self {
83 private_key,
84 public_key,
85 })
86 }
87
88 pub fn private_key(&self) -> &PrivateKey<U> {
90 &self.private_key
91 }
92
93 pub fn public_key(&self) -> &PublicKey {
95 &self.public_key
96 }
97
98 pub fn split(self) -> (PrivateKey<U>, PublicKey) {
100 (self.private_key, self.public_key)
101 }
102}
103
104impl KeyPair<Static> {
105 pub fn from(private_key: PrivateKey<Static>) -> Result<Self> {
106 Self::from_private_key(private_key)
107 }
108}
109
110pub struct PublicKey {
112 wrapped: ec::PublicKey,
113 alg: &'static Algorithm,
114}
115
116impl PublicKey {
117 #[inline]
118 pub fn to_bytes(&self) -> Result<Vec<u8>> {
119 Ok(self.wrapped.to_bytes()?)
120 }
121
122 #[inline]
123 pub fn algorithm(&self) -> &'static Algorithm {
124 self.alg
125 }
126}
127
128pub struct UnparsedPublicKey<'a> {
130 alg: &'static Algorithm,
131 bytes: &'a [u8],
132}
133
134impl<'a> UnparsedPublicKey<'a> {
135 pub fn new(algorithm: &'static Algorithm, bytes: &'a [u8]) -> Self {
136 Self {
137 alg: algorithm,
138 bytes,
139 }
140 }
141
142 pub fn algorithm(&self) -> &'static Algorithm {
143 self.alg
144 }
145
146 pub fn bytes(&self) -> &'a [u8] {
147 self.bytes
148 }
149}
150
151pub struct PrivateKey<U: Lifetime> {
153 wrapped: ec::PrivateKey,
154 alg: &'static Algorithm,
155 usage: PhantomData<U>,
156}
157
158impl<U: Lifetime> PrivateKey<U> {
159 #[inline]
160 pub fn algorithm(&self) -> &'static Algorithm {
161 self.alg
162 }
163
164 pub fn compute_public_key(&self) -> Result<PublicKey> {
165 let pub_key = self.wrapped.convert_to_public_key()?;
166 Ok(PublicKey {
167 wrapped: pub_key,
168 alg: self.alg,
169 })
170 }
171
172 pub fn agree(self, peer_public_key: &UnparsedPublicKey<'_>) -> Result<InputKeyMaterial> {
176 agree_(&self.wrapped, self.alg, peer_public_key)
177 }
178}
179
180impl PrivateKey<Static> {
181 pub fn agree_static(
185 &self,
186 peer_public_key: &UnparsedPublicKey<'_>,
187 ) -> Result<InputKeyMaterial> {
188 agree_(&self.wrapped, self.alg, peer_public_key)
189 }
190
191 pub fn import(ec_key: &EcKey) -> Result<Self> {
192 let alg = match ec_key.curve() {
194 Curve::P256 => &ECDH_P256,
195 Curve::P384 => &ECDH_P384,
196 };
197 let private_key = ec::PrivateKey::import(ec_key)?;
198 Ok(Self {
199 wrapped: private_key,
200 alg,
201 usage: PhantomData,
202 })
203 }
204
205 pub fn export(&self) -> Result<EcKey> {
206 Ok(self.wrapped.export()?)
207 }
208
209 pub fn _tests_only_dangerously_convert_to_ephemeral(self) -> PrivateKey<Ephemeral> {
213 PrivateKey::<Ephemeral> {
214 wrapped: self.wrapped,
215 alg: self.alg,
216 usage: PhantomData,
217 }
218 }
219}
220
221fn agree_(
222 my_private_key: &ec::PrivateKey,
223 my_alg: &Algorithm,
224 peer_public_key: &UnparsedPublicKey<'_>,
225) -> Result<InputKeyMaterial> {
226 let alg = &my_alg;
227 if peer_public_key.algorithm() != *alg {
228 return Err(ErrorKind::InternalError.into());
229 }
230 let pub_key = ec::PublicKey::from_bytes(my_private_key.curve(), peer_public_key.bytes())?;
231 let value = ecdh::ecdh_agreement(my_private_key, &pub_key)?;
232 Ok(InputKeyMaterial { value })
233}
234
/// Secret bytes produced by an ECDH agreement.
///
/// The bytes are private to this module. The only public way to read them is
/// [`InputKeyMaterial::derive`], which consumes the value.
#[must_use]
pub struct InputKeyMaterial {
    value: Vec<u8>,
}

impl InputKeyMaterial {
    /// Consumes the material, lends its bytes to `kdf`, and returns whatever `kdf`
    /// produces.
    pub fn derive<F, R>(self, kdf: F) -> R
    where
        F: FnOnce(&[u8]) -> R,
    {
        let secret: &[u8] = &self.value;
        kdf(secret)
    }
}
252
#[cfg(test)]
mod tests {
    use super::*;
    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
    use nss::ensure_initialized;

    // 0x04-prefixed (uncompressed) public point; `test_compute_public_key_known_values`
    // checks that it belongs to private key 1.
    const PUB_KEY_1_B64: &str =
        "BLunVoWkR67xRdAohVblFBWn1Oosb3kH_baxw1yfIYFfthSm4LIY35vDD-5LE454eB7TShn919DVVGZ_7tWdjTE";
    // Base64url (unpadded) JWK-style `d`, `x`, `y` components of private key 1.
    const PRIV_KEY_1_JWK_D: &str = "CQ8uF_-zB1NftLO6ytwKM3Cnuol64PQw5qOuCzQJeFU";
    const PRIV_KEY_1_JWK_X: &str = "u6dWhaRHrvFF0CiFVuUUFafU6ixveQf9trHDXJ8hgV8";
    const PRIV_KEY_1_JWK_Y: &str = "thSm4LIY35vDD-5LE454eB7TShn919DVVGZ_7tWdjTE";

    // Base64url (unpadded) JWK-style `d`, `x`, `y` components of private key 2.
    const PRIV_KEY_2_JWK_D: &str = "uN2YSQvxuxhQQ9Y1XXjYi1vr2ZTdzuoDX18PYu4LU-0";
    const PRIV_KEY_2_JWK_X: &str = "S2S3tjygMB0DkM-N9jYUgGLt_9_H6km5P9V6V_KS4_4";
    const PRIV_KEY_2_JWK_Y: &str = "03j8Tyqgrc4R4FAUV2C7-im96yMmfmO_5Om6Kr8YP3o";

    // Expected agreement output of private key 2 with PUB_KEY_1, as upper-case hex.
    const SHARED_SECRET_HEX: &str =
        "163FAA3FC4815D47345C8E959F707B2F1D3537E7B2EA1DAEC23CA8D0A242CFF3";

    /// Imports a static P-256 private key from base64url-encoded `d`, `x`, `y` values.
    fn import_p256_key(d_b64: &str, x_b64: &str, y_b64: &str) -> PrivateKey<Static> {
        let d = URL_SAFE_NO_PAD.decode(d_b64).unwrap();
        let x = URL_SAFE_NO_PAD.decode(x_b64).unwrap();
        let y = URL_SAFE_NO_PAD.decode(y_b64).unwrap();
        let ec_key = EcKey::from_coordinates(Curve::P256, &d, &x, &y).unwrap();
        PrivateKey::<Static>::import(&ec_key).unwrap()
    }

    fn load_priv_key_1() -> PrivateKey<Static> {
        import_p256_key(PRIV_KEY_1_JWK_D, PRIV_KEY_1_JWK_X, PRIV_KEY_1_JWK_Y)
    }

    fn load_priv_key_2() -> PrivateKey<Static> {
        import_p256_key(PRIV_KEY_2_JWK_D, PRIV_KEY_2_JWK_X, PRIV_KEY_2_JWK_Y)
    }

    /// Extracts the raw agreement bytes without applying any KDF.
    fn raw_secret(ikm: InputKeyMaterial) -> Vec<u8> {
        ikm.derive(|z| z.to_vec())
    }

    /// Asserts that `raw` is rejected when used as a P-256 peer public key.
    fn assert_rejected(prv_key: &PrivateKey<Static>, raw: &[u8]) {
        assert!(prv_key
            .agree_static(&UnparsedPublicKey::new(&ECDH_P256, raw))
            .is_err());
    }

    #[test]
    fn test_static_agreement() {
        ensure_initialized();
        let pub_key_raw = URL_SAFE_NO_PAD.decode(PUB_KEY_1_B64).unwrap();
        let peer_pub_key = UnparsedPublicKey::new(&ECDH_P256, &pub_key_raw);
        let prv_key = load_priv_key_2();
        let ikm = prv_key.agree_static(&peer_pub_key).unwrap();
        let secret_hex = hex::encode_upper(raw_secret(ikm));
        assert_eq!(secret_hex, SHARED_SECRET_HEX);
    }

    #[test]
    fn test_ephemeral_agreement_roundtrip() {
        ensure_initialized();
        let (our_prv_key, our_pub_key) =
            KeyPair::<Ephemeral>::generate(&ECDH_P256).unwrap().split();
        let (their_prv_key, their_pub_key) =
            KeyPair::<Ephemeral>::generate(&ECDH_P256).unwrap().split();

        let their_pub_key_raw = their_pub_key.to_bytes().unwrap();
        let peer_public_key_1 = UnparsedPublicKey::new(&ECDH_P256, &their_pub_key_raw);
        let secret_1 = raw_secret(our_prv_key.agree(&peer_public_key_1).unwrap());

        let our_pub_key_raw = our_pub_key.to_bytes().unwrap();
        let peer_public_key_2 = UnparsedPublicKey::new(&ECDH_P256, &our_pub_key_raw);
        let secret_2 = raw_secret(their_prv_key.agree(&peer_public_key_2).unwrap());

        assert_eq!(secret_1, secret_2);
    }

    #[test]
    fn test_compute_public_key() {
        ensure_initialized();
        let (prv_key, pub_key) = KeyPair::<Static>::generate(&ECDH_P256).unwrap().split();
        let computed_pub_key = prv_key.compute_public_key().unwrap();
        assert_eq!(
            computed_pub_key.to_bytes().unwrap(),
            pub_key.to_bytes().unwrap()
        );
    }

    #[test]
    fn test_compute_public_key_known_values() {
        ensure_initialized();
        let pub_key = URL_SAFE_NO_PAD.decode(PUB_KEY_1_B64).unwrap();

        let prv_key = load_priv_key_1();
        let computed_pub_key = prv_key.compute_public_key().unwrap();
        assert_eq!(computed_pub_key.to_bytes().unwrap(), pub_key.as_slice());

        let prv_key = load_priv_key_2();
        let computed_pub_key = prv_key.compute_public_key().unwrap();
        assert_ne!(computed_pub_key.to_bytes().unwrap(), pub_key.as_slice());
    }

    #[test]
    fn test_keys_byte_representations_roundtrip() {
        ensure_initialized();
        let (prv_key, _) = KeyPair::<Static>::generate(&ECDH_P256).unwrap().split();
        let extracted_pub_key = prv_key.compute_public_key().unwrap();
        let ec_key = prv_key.export().unwrap();

        let prv_key_reconstructed = PrivateKey::<Static>::import(&ec_key).unwrap();
        // Derive from the *reconstructed* key: deriving from `prv_key` again would make
        // the final comparison trivially true.
        let extracted_pub_key_reconstructed = prv_key_reconstructed.compute_public_key().unwrap();
        let ec_key_reconstructed = prv_key_reconstructed.export().unwrap();

        assert_eq!(ec_key.curve(), ec_key_reconstructed.curve());
        assert_eq!(ec_key.public_key(), ec_key_reconstructed.public_key());
        assert_eq!(ec_key.private_key(), ec_key_reconstructed.private_key());
        assert_eq!(
            extracted_pub_key.to_bytes().unwrap(),
            extracted_pub_key_reconstructed.to_bytes().unwrap()
        );
    }

    #[test]
    fn test_agreement_rejects_invalid_pubkeys() {
        ensure_initialized();
        let prv_key = load_priv_key_2();
        let valid_pub_key = URL_SAFE_NO_PAD.decode(PUB_KEY_1_B64).unwrap();

        // Point-format prefix bumped from 0x04 to 0x05.
        let mut invalid_pub_key = valid_pub_key.clone();
        invalid_pub_key[0] = invalid_pub_key[0].wrapping_add(1);
        assert_rejected(&prv_key, &invalid_pub_key);

        // Compressed-point prefix on an uncompressed-length encoding.
        let mut invalid_pub_key = valid_pub_key.clone();
        invalid_pub_key[0] = 0x02;
        assert_rejected(&prv_key, &invalid_pub_key);

        // Last byte of the Y coordinate altered, moving the point off the curve.
        let mut invalid_pub_key = valid_pub_key.clone();
        invalid_pub_key[64] = invalid_pub_key[64].wrapping_add(1);
        assert_rejected(&prv_key, &invalid_pub_key);

        // All-zero buffer: no valid point-format prefix at all.
        let mut invalid_pub_key = [0u8; 65];
        assert_rejected(&prv_key, &invalid_pub_key);
        // Uncompressed prefix with (0, 0) coordinates, which is not a point on P-256.
        invalid_pub_key[0] = 0x04;
        assert_rejected(&prv_key, &invalid_pub_key);

        // Truncated encoding (last byte missing).
        let mut invalid_pub_key = valid_pub_key.clone();
        invalid_pub_key.truncate(64);
        assert_rejected(&prv_key, &invalid_pub_key);

        // A 0x04-prefixed encoding that must still be rejected (presumably off-curve).
        let invalid_pub_key_b64 = "BEogZ-rnm44oJkKsOE6Tc7NwFMgmntf7Btm_Rc4atxcqq99Xq1RWNTFpk99pdQOSjUvwELss51PkmAGCXhLfMV0";
        let invalid_pub_key = URL_SAFE_NO_PAD.decode(invalid_pub_key_b64).unwrap();
        assert_rejected(&prv_key, &invalid_pub_key);
    }
}