1// Copyright 2022 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package ecdh
6
7import (
8	"crypto/internal/boring"
9	"crypto/internal/nistec"
10	"crypto/internal/randutil"
11	"errors"
12	"internal/byteorder"
13	"io"
14	"math/bits"
15)
16
// nistCurve is the curve implementation shared by the NIST curves declared in
// this file. It is generic over the point type so that the same methods can be
// instantiated once per curve.
//
// Fields:
//   - name is the curve name; it is returned by String and passed to the
//     boring backend to select the curve.
//   - newPoint returns a new Point on which the point operations are invoked.
//   - scalarOrder is the big-endian bound for private keys: a valid private
//     key has the same length, is non-zero, and is strictly less than it.
type nistCurve[Point nistPoint[Point]] struct {
	name        string
	newPoint    func() Point
	scalarOrder []byte
}
22
// nistPoint is a generic constraint for the nistec Point types. T is the
// point type itself, which lets the methods return values of that same type.
//
// How this file relies on each method:
//   - Bytes yields the public key encoding; the identity is encoded as a
//     single 0x00 byte.
//   - BytesX yields the ECDH shared secret and returns an error for the point
//     at infinity.
//   - SetBytes decodes an encoded point and also checks that it is on the
//     curve.
//   - ScalarMult multiplies the given point by a scalar.
//   - ScalarBaseMult derives the public point from a private scalar; its only
//     error condition is an input of the wrong size.
type nistPoint[T any] interface {
	Bytes() []byte
	BytesX() ([]byte, error)
	SetBytes([]byte) (T, error)
	ScalarMult(T, []byte) (T, error)
	ScalarBaseMult([]byte) (T, error)
}
31
// String returns the name of the curve, as stored in the name field.
func (c *nistCurve[Point]) String() string {
	return c.name
}
35
// errInvalidPrivateKey is returned by NewPrivateKey for a correctly sized
// scalar that is zero or not less than the curve's scalarOrder. GenerateKey
// compares against this exact value to decide whether to retry with fresh
// randomness.
var errInvalidPrivateKey = errors.New("crypto/ecdh: invalid private key")
37
38func (c *nistCurve[Point]) GenerateKey(rand io.Reader) (*PrivateKey, error) {
39	if boring.Enabled && rand == boring.RandReader {
40		key, bytes, err := boring.GenerateKeyECDH(c.name)
41		if err != nil {
42			return nil, err
43		}
44		return newBoringPrivateKey(c, key, bytes)
45	}
46
47	key := make([]byte, len(c.scalarOrder))
48	randutil.MaybeReadByte(rand)
49	for {
50		if _, err := io.ReadFull(rand, key); err != nil {
51			return nil, err
52		}
53
54		// Mask off any excess bits if the size of the underlying field is not a
55		// whole number of bytes, which is only the case for P-521. We use a
56		// pointer to the scalarOrder field because comparing generic and
57		// instantiated types is not supported.
58		if &c.scalarOrder[0] == &p521Order[0] {
59			key[0] &= 0b0000_0001
60		}
61
62		// In tests, rand will return all zeros and NewPrivateKey will reject
63		// the zero key as it generates the identity as a public key. This also
64		// makes this function consistent with crypto/elliptic.GenerateKey.
65		key[1] ^= 0x42
66
67		k, err := c.NewPrivateKey(key)
68		if err == errInvalidPrivateKey {
69			continue
70		}
71		return k, err
72	}
73}
74
75func (c *nistCurve[Point]) NewPrivateKey(key []byte) (*PrivateKey, error) {
76	if len(key) != len(c.scalarOrder) {
77		return nil, errors.New("crypto/ecdh: invalid private key size")
78	}
79	if isZero(key) || !isLess(key, c.scalarOrder) {
80		return nil, errInvalidPrivateKey
81	}
82	if boring.Enabled {
83		bk, err := boring.NewPrivateKeyECDH(c.name, key)
84		if err != nil {
85			return nil, err
86		}
87		return newBoringPrivateKey(c, bk, key)
88	}
89	k := &PrivateKey{
90		curve:      c,
91		privateKey: append([]byte{}, key...),
92	}
93	return k, nil
94}
95
96func newBoringPrivateKey(c Curve, bk *boring.PrivateKeyECDH, privateKey []byte) (*PrivateKey, error) {
97	k := &PrivateKey{
98		curve:      c,
99		boring:     bk,
100		privateKey: append([]byte(nil), privateKey...),
101	}
102	return k, nil
103}
104
105func (c *nistCurve[Point]) privateKeyToPublicKey(key *PrivateKey) *PublicKey {
106	boring.Unreachable()
107	if key.curve != c {
108		panic("crypto/ecdh: internal error: converting the wrong key type")
109	}
110	p, err := c.newPoint().ScalarBaseMult(key.privateKey)
111	if err != nil {
112		// This is unreachable because the only error condition of
113		// ScalarBaseMult is if the input is not the right size.
114		panic("crypto/ecdh: internal error: nistec ScalarBaseMult failed for a fixed-size input")
115	}
116	publicKey := p.Bytes()
117	if len(publicKey) == 1 {
118		// The encoding of the identity is a single 0x00 byte. This is
119		// unreachable because the only scalar that generates the identity is
120		// zero, which is rejected by NewPrivateKey.
121		panic("crypto/ecdh: internal error: nistec ScalarBaseMult returned the identity")
122	}
123	return &PublicKey{
124		curve:     key.curve,
125		publicKey: publicKey,
126	}
127}
128
// isZero reports whether every byte of a is zero. All bytes are always
// inspected and OR-ed together, so the running time does not depend on the
// contents of a. An empty slice counts as zero.
func isZero(a []byte) bool {
	var seen byte
	for i := 0; i < len(a); i++ {
		seen |= a[i]
	}
	return seen == 0
}
137
138// isLess returns whether a < b, where a and b are big-endian buffers of the
139// same length and shorter than 72 bytes.
140func isLess(a, b []byte) bool {
141	if len(a) != len(b) {
142		panic("crypto/ecdh: internal error: mismatched isLess inputs")
143	}
144
145	// Copy the values into a fixed-size preallocated little-endian buffer.
146	// 72 bytes is enough for every scalar in this package, and having a fixed
147	// size lets us avoid heap allocations.
148	if len(a) > 72 {
149		panic("crypto/ecdh: internal error: isLess input too large")
150	}
151	bufA, bufB := make([]byte, 72), make([]byte, 72)
152	for i := range a {
153		bufA[i], bufB[i] = a[len(a)-i-1], b[len(b)-i-1]
154	}
155
156	// Perform a subtraction with borrow.
157	var borrow uint64
158	for i := 0; i < len(bufA); i += 8 {
159		limbA, limbB := byteorder.LeUint64(bufA[i:]), byteorder.LeUint64(bufB[i:])
160		_, borrow = bits.Sub64(limbA, limbB, borrow)
161	}
162
163	// If there is a borrow at the end of the operation, then a < b.
164	return borrow == 1
165}
166
167func (c *nistCurve[Point]) NewPublicKey(key []byte) (*PublicKey, error) {
168	// Reject the point at infinity and compressed encodings.
169	if len(key) == 0 || key[0] != 4 {
170		return nil, errors.New("crypto/ecdh: invalid public key")
171	}
172	k := &PublicKey{
173		curve:     c,
174		publicKey: append([]byte{}, key...),
175	}
176	if boring.Enabled {
177		bk, err := boring.NewPublicKeyECDH(c.name, k.publicKey)
178		if err != nil {
179			return nil, err
180		}
181		k.boring = bk
182	} else {
183		// SetBytes also checks that the point is on the curve.
184		if _, err := c.newPoint().SetBytes(key); err != nil {
185			return nil, err
186		}
187	}
188	return k, nil
189}
190
191func (c *nistCurve[Point]) ecdh(local *PrivateKey, remote *PublicKey) ([]byte, error) {
192	// Note that this function can't return an error, as NewPublicKey rejects
193	// invalid points and the point at infinity, and NewPrivateKey rejects
194	// invalid scalars and the zero value. BytesX returns an error for the point
195	// at infinity, but in a prime order group such as the NIST curves that can
196	// only be the result of a scalar multiplication if one of the inputs is the
197	// zero scalar or the point at infinity.
198
199	if boring.Enabled {
200		return boring.ECDH(local.boring, remote.boring)
201	}
202
203	boring.Unreachable()
204	p, err := c.newPoint().SetBytes(remote.publicKey)
205	if err != nil {
206		return nil, err
207	}
208	if _, err := p.ScalarMult(p, local.privateKey); err != nil {
209		return nil, err
210	}
211	return p.BytesX()
212}
213
214// P256 returns a [Curve] which implements NIST P-256 (FIPS 186-3, section D.2.3),
215// also known as secp256r1 or prime256v1.
216//
217// Multiple invocations of this function will return the same value, which can
218// be used for equality checks and switch statements.
219func P256() Curve { return p256 }
220
// p256 is the single shared P-256 curve value returned by P256. It combines
// the nistec P-256 point constructor with the matching scalar order.
var p256 = &nistCurve[*nistec.P256Point]{
	name:        "P-256",
	newPoint:    nistec.NewP256Point,
	scalarOrder: p256Order,
}
226
// p256Order is the 32-byte big-endian value used as the scalarOrder of p256.
// NewPrivateKey requires P-256 private keys to be non-zero and strictly less
// than it.
var p256Order = []byte{
	0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xbc, 0xe6, 0xfa, 0xad, 0xa7, 0x17, 0x9e, 0x84,
	0xf3, 0xb9, 0xca, 0xc2, 0xfc, 0x63, 0x25, 0x51}
232
233// P384 returns a [Curve] which implements NIST P-384 (FIPS 186-3, section D.2.4),
234// also known as secp384r1.
235//
236// Multiple invocations of this function will return the same value, which can
237// be used for equality checks and switch statements.
238func P384() Curve { return p384 }
239
// p384 is the single shared P-384 curve value returned by P384. It combines
// the nistec P-384 point constructor with the matching scalar order.
var p384 = &nistCurve[*nistec.P384Point]{
	name:        "P-384",
	newPoint:    nistec.NewP384Point,
	scalarOrder: p384Order,
}
245
// p384Order is the 48-byte big-endian value used as the scalarOrder of p384.
// NewPrivateKey requires P-384 private keys to be non-zero and strictly less
// than it.
var p384Order = []byte{
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xc7, 0x63, 0x4d, 0x81, 0xf4, 0x37, 0x2d, 0xdf,
	0x58, 0x1a, 0x0d, 0xb2, 0x48, 0xb0, 0xa7, 0x7a,
	0xec, 0xec, 0x19, 0x6a, 0xcc, 0xc5, 0x29, 0x73}
253
254// P521 returns a [Curve] which implements NIST P-521 (FIPS 186-3, section D.2.5),
255// also known as secp521r1.
256//
257// Multiple invocations of this function will return the same value, which can
258// be used for equality checks and switch statements.
259func P521() Curve { return p521 }
260
// p521 is the single shared P-521 curve value returned by P521. It combines
// the nistec P-521 point constructor with the matching scalar order.
var p521 = &nistCurve[*nistec.P521Point]{
	name:        "P-521",
	newPoint:    nistec.NewP521Point,
	scalarOrder: p521Order,
}
266
// p521Order is the 66-byte big-endian value used as the scalarOrder of p521.
// NewPrivateKey requires P-521 private keys to be non-zero and strictly less
// than it.
//
// Its leading byte is 0x01, so GenerateKey masks the first byte of each P-521
// candidate down to its lowest bit. GenerateKey recognizes P-521 by comparing
// the address of this slice's first element, so p521 must keep referring to
// this exact backing array.
var p521Order = []byte{0x01, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfa,
	0x51, 0x86, 0x87, 0x83, 0xbf, 0x2f, 0x96, 0x6b,
	0x7f, 0xcc, 0x01, 0x48, 0xf7, 0x09, 0xa5, 0xd0,
	0x3b, 0xb5, 0xc9, 0xb8, 0x89, 0x9c, 0x47, 0xae,
	0xbb, 0x6f, 0xb7, 0x1e, 0x91, 0x38, 0x64, 0x09}
276