xref: /aosp_15_r20/external/boringssl/src/ssl/test/runner/key_agreement.go (revision 8fb009dc861624b67b6cdb62ea21f0f22d0c584b)
1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package runner
6
7import (
8	"crypto"
9	"crypto/ecdsa"
10	"crypto/ed25519"
11	"crypto/elliptic"
12	"crypto/rsa"
13	"crypto/subtle"
14	"crypto/x509"
15	"errors"
16	"fmt"
17	"io"
18	"math/big"
19
20	"boringssl.googlesource.com/boringssl/ssl/test/runner/kyber"
21	"golang.org/x/crypto/curve25519"
22)
23
// keyType names the family of long-term server key that a signed key
// agreement expects the peer to present.
type keyType int

// The constants start at 1, so the zero keyType matches neither of them.
const (
	keyTypeRSA   keyType = 1
	keyTypeECDSA keyType = 2
)

// Sentinel errors for structurally malformed key exchange messages.
var (
	errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message")
	errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message")
)
33
34// rsaKeyAgreement implements the standard TLS key agreement where the client
35// encrypts the pre-master secret to the server's public key.
36type rsaKeyAgreement struct {
37	version       uint16
38	clientVersion uint16
39	exportKey     *rsa.PrivateKey
40}
41
42func (ka *rsaKeyAgreement) generateServerKeyExchange(config *Config, cert *Credential, clientHello *clientHelloMsg, hello *serverHelloMsg, version uint16) (*serverKeyExchangeMsg, error) {
43	// Save the client version for comparison later.
44	ka.clientVersion = clientHello.vers
45
46	if !config.Bugs.RSAEphemeralKey {
47		return nil, nil
48	}
49
50	// Generate an ephemeral RSA key to use instead of the real
51	// one, as in RSA_EXPORT.
52	key, err := rsa.GenerateKey(config.rand(), 512)
53	if err != nil {
54		return nil, err
55	}
56	ka.exportKey = key
57
58	modulus := key.N.Bytes()
59	exponent := big.NewInt(int64(key.E)).Bytes()
60	serverRSAParams := make([]byte, 0, 2+len(modulus)+2+len(exponent))
61	serverRSAParams = append(serverRSAParams, byte(len(modulus)>>8), byte(len(modulus)))
62	serverRSAParams = append(serverRSAParams, modulus...)
63	serverRSAParams = append(serverRSAParams, byte(len(exponent)>>8), byte(len(exponent)))
64	serverRSAParams = append(serverRSAParams, exponent...)
65
66	var sigAlg signatureAlgorithm
67	if ka.version >= VersionTLS12 {
68		sigAlg, err = selectSignatureAlgorithm(false /* server */, ka.version, cert, config, clientHello.signatureAlgorithms)
69		if err != nil {
70			return nil, err
71		}
72	}
73
74	sig, err := signMessage(false /* server */, ka.version, cert.PrivateKey, config, sigAlg, serverRSAParams)
75	if err != nil {
76		return nil, errors.New("failed to sign RSA parameters: " + err.Error())
77	}
78
79	skx := new(serverKeyExchangeMsg)
80	sigAlgsLen := 0
81	if ka.version >= VersionTLS12 {
82		sigAlgsLen = 2
83	}
84	skx.key = make([]byte, len(serverRSAParams)+sigAlgsLen+2+len(sig))
85	copy(skx.key, serverRSAParams)
86	k := skx.key[len(serverRSAParams):]
87	if ka.version >= VersionTLS12 {
88		k[0] = byte(sigAlg >> 8)
89		k[1] = byte(sigAlg)
90		k = k[2:]
91	}
92	k[0] = byte(len(sig) >> 8)
93	k[1] = byte(len(sig))
94	copy(k[2:], sig)
95
96	return skx, nil
97}
98
99func (ka *rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Credential, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
100	preMasterSecret := make([]byte, 48)
101	_, err := io.ReadFull(config.rand(), preMasterSecret[2:])
102	if err != nil {
103		return nil, err
104	}
105
106	if len(ckx.ciphertext) < 2 {
107		return nil, errClientKeyExchange
108	}
109
110	ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1])
111	if ciphertextLen != len(ckx.ciphertext)-2 {
112		return nil, errClientKeyExchange
113	}
114	ciphertext := ckx.ciphertext[2:]
115
116	key := cert.PrivateKey.(*rsa.PrivateKey)
117	if ka.exportKey != nil {
118		key = ka.exportKey
119	}
120	err = rsa.DecryptPKCS1v15SessionKey(config.rand(), key, ciphertext, preMasterSecret)
121	if err != nil {
122		return nil, err
123	}
124	// This check should be done in constant-time, but this is a testing
125	// implementation. See the discussion at the end of section 7.4.7.1 of
126	// RFC 4346.
127	vers := uint16(preMasterSecret[0])<<8 | uint16(preMasterSecret[1])
128	if ka.clientVersion != vers {
129		return nil, fmt.Errorf("tls: invalid version in RSA premaster (got %04x, wanted %04x)", vers, ka.clientVersion)
130	}
131	return preMasterSecret, nil
132}
133
134func (ka *rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, key crypto.PublicKey, skx *serverKeyExchangeMsg) error {
135	return errors.New("tls: unexpected ServerKeyExchange")
136}
137
// rsaSize returns the byte length of pub's modulus, which is also the length
// of a raw RSA input or output under pub.
func rsaSize(pub *rsa.PublicKey) int {
	bits := pub.N.BitLen()
	size := bits / 8
	if bits%8 != 0 {
		size++
	}
	return size
}

// rsaRawEncrypt computes msg^E mod N without applying any padding itself.
// msg must be exactly rsaSize(pub) bytes (callers pad it themselves). The
// result is big-endian and left-padded with zeros to that same length.
func rsaRawEncrypt(pub *rsa.PublicKey, msg []byte) ([]byte, error) {
	size := rsaSize(pub)
	if len(msg) != size {
		return nil, errors.New("tls: bad padded RSA input")
	}
	base := new(big.Int).SetBytes(msg)
	c := new(big.Int).Exp(base, big.NewInt(int64(pub.E)), pub.N)
	// c < N, so it always fits in size bytes.
	return c.FillBytes(make([]byte, size)), nil
}
155
// nonZeroRandomBytes fills s from rand and then redraws, one byte at a time,
// every position that came out zero, so no byte of s is 0x00 on return. It
// panics if rand returns an error.
func nonZeroRandomBytes(s []byte, rand io.Reader) {
	fill := func(dst []byte) {
		if _, err := io.ReadFull(rand, dst); err != nil {
			panic(err)
		}
	}

	fill(s)
	for i := 0; i < len(s); i++ {
		cell := s[i : i+1]
		for cell[0] == 0 {
			fill(cell)
		}
	}
}
170
171func (ka *rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
172	bad := config.Bugs.BadRSAClientKeyExchange
173	preMasterSecret := make([]byte, 48)
174	vers := clientHello.vers
175	if bad == RSABadValueWrongVersion1 {
176		vers ^= 1
177	} else if bad == RSABadValueWrongVersion2 {
178		vers ^= 0x100
179	}
180	preMasterSecret[0] = byte(vers >> 8)
181	preMasterSecret[1] = byte(vers)
182	_, err := io.ReadFull(config.rand(), preMasterSecret[2:])
183	if err != nil {
184		return nil, nil, err
185	}
186
187	sentPreMasterSecret := preMasterSecret
188	if bad == RSABadValueTooLong {
189		sentPreMasterSecret = make([]byte, 1, len(sentPreMasterSecret)+1)
190		sentPreMasterSecret = append(sentPreMasterSecret, preMasterSecret...)
191	} else if bad == RSABadValueTooShort {
192		sentPreMasterSecret = sentPreMasterSecret[:len(sentPreMasterSecret)-1]
193	}
194
195	// Pad for PKCS#1 v1.5.
196	padded := make([]byte, rsaSize(cert.PublicKey.(*rsa.PublicKey)))
197	padded[1] = 2
198	nonZeroRandomBytes(padded[2:len(padded)-len(sentPreMasterSecret)-1], config.rand())
199	copy(padded[len(padded)-len(sentPreMasterSecret):], sentPreMasterSecret)
200
201	if bad == RSABadValueWrongBlockType {
202		padded[1] = 3
203	} else if bad == RSABadValueWrongLeadingByte {
204		padded[0] = 1
205	} else if bad == RSABadValueNoZero {
206		for i := 2; i < len(padded); i++ {
207			if padded[i] == 0 {
208				padded[i]++
209			}
210		}
211	}
212
213	encrypted, err := rsaRawEncrypt(cert.PublicKey.(*rsa.PublicKey), padded)
214	if err != nil {
215		return nil, nil, err
216	}
217	if bad == RSABadValueCorrupt {
218		encrypted[len(encrypted)-1] ^= 1
219		// Clear the high byte to ensure |encrypted| is still below the RSA modulus.
220		encrypted[0] = 0
221	}
222	ckx := new(clientKeyExchangeMsg)
223	ckx.ciphertext = make([]byte, len(encrypted)+2)
224	ckx.ciphertext[0] = byte(len(encrypted) >> 8)
225	ckx.ciphertext[1] = byte(len(encrypted))
226	copy(ckx.ciphertext[2:], encrypted)
227	return preMasterSecret, ckx, nil
228}
229
230func (ka *rsaKeyAgreement) peerSignatureAlgorithm() signatureAlgorithm {
231	return 0
232}
233
// kemImplementation models a TLS key share as a key encapsulation mechanism.
// One side calls generate and sends the public key. The peer calls encap on
// it and replies with the ciphertext. The first side then calls decap on that
// reply, and both hold the same secret.
type kemImplementation interface {
	// generate creates a fresh key pair from rand, keeps the private half in
	// the receiver, and returns the encoded public key.
	generate(rand io.Reader) (publicKey []byte, err error)

	// encap derives a shared secret for peerKey using randomness from rand.
	// It returns the ciphertext to send to the peer and the secret itself.
	encap(rand io.Reader, peerKey []byte) (ciphertext []byte, secret []byte, err error)

	// decap recovers the shared secret from the peer's ciphertext, using the
	// private key stored by an earlier generate call.
	decap(ciphertext []byte) (secret []byte, err error)
}
246
// ecdhKEM adapts classic ECDH over an elliptic.Curve to kemImplementation. The
// "ciphertext" is an ephemeral public point, and the shared secret is the X
// coordinate of the ECDH product.
//
// TODO(davidben): Move this to Go's crypto/ecdh.
type ecdhKEM struct {
	curve          elliptic.Curve
	privateKey     []byte
	sendCompressed bool
}

// generate creates a key pair on e.curve. It returns the public point in
// uncompressed form, or, when sendCompressed is set, as a 0x02/0x03 prefix
// (the parity of Y) followed by X only.
func (e *ecdhKEM) generate(rand io.Reader) (publicKey []byte, err error) {
	priv, x, y, genErr := elliptic.GenerateKey(e.curve, rand)
	e.privateKey = priv
	if genErr != nil {
		return nil, genErr
	}

	point := elliptic.Marshal(e.curve, x, y)
	if !e.sendCompressed {
		return point, nil
	}
	coordLen := (len(point) - 1) / 2
	compressed := make([]byte, 0, 1+coordLen)
	compressed = append(compressed, byte(2|y.Bit(0)))
	compressed = append(compressed, point[1:1+coordLen]...)
	return compressed, nil
}

// encap generates an ephemeral key pair, whose public key is the ciphertext,
// and performs ECDH against peerKey.
func (e *ecdhKEM) encap(rand io.Reader, peerKey []byte) (ciphertext []byte, secret []byte, err error) {
	if ciphertext, err = e.generate(rand); err != nil {
		return nil, nil, err
	}
	if secret, err = e.decap(peerKey); err != nil {
		return nil, nil, err
	}
	return ciphertext, secret, nil
}

// decap parses ciphertext as an uncompressed point on e.curve and multiplies
// it by the stored private key. It returns the resulting X coordinate,
// left-padded with zeros to the curve's size in bytes.
func (e *ecdhKEM) decap(ciphertext []byte) (secret []byte, err error) {
	px, py := elliptic.Unmarshal(e.curve, ciphertext)
	if px == nil {
		return nil, errors.New("tls: invalid peer key")
	}
	sx, _ := e.curve.ScalarMult(px, py, e.privateKey)
	byteLen := (e.curve.Params().BitSize + 7) / 8
	return sx.FillBytes(make([]byte, byteLen)), nil
}
296
297// x25519KEM implements kemImplementation with X25519.
298type x25519KEM struct {
299	privateKey [32]byte
300	setHighBit bool
301}
302
303func (e *x25519KEM) generate(rand io.Reader) (publicKey []byte, err error) {
304	_, err = io.ReadFull(rand, e.privateKey[:])
305	if err != nil {
306		return
307	}
308	var out [32]byte
309	curve25519.ScalarBaseMult(&out, &e.privateKey)
310	if e.setHighBit {
311		out[31] |= 0x80
312	}
313	return out[:], nil
314}
315
316func (e *x25519KEM) encap(rand io.Reader, peerKey []byte) (ciphertext []byte, secret []byte, err error) {
317	ciphertext, err = e.generate(rand)
318	if err != nil {
319		return nil, nil, err
320	}
321	secret, err = e.decap(peerKey)
322	if err != nil {
323		return nil, nil, err
324	}
325	return
326}
327
328func (e *x25519KEM) decap(ciphertext []byte) (secret []byte, err error) {
329	if len(ciphertext) != 32 {
330		return nil, errors.New("tls: invalid peer key")
331	}
332	var out [32]byte
333	curve25519.ScalarMult(&out, &e.privateKey, (*[32]byte)(ciphertext))
334
335	// Per RFC 7748, reject the all-zero value in constant time.
336	var zeros [32]byte
337	if subtle.ConstantTimeCompare(zeros[:], out[:]) == 1 {
338		return nil, errors.New("tls: X25519 value with wrong order")
339	}
340
341	return out[:], nil
342}
343
344// kyberKEM implements Kyber combined with X25519.
345type kyberKEM struct {
346	x25519PrivateKey [32]byte
347	kyberPrivateKey  *kyber.PrivateKey
348}
349
350func (e *kyberKEM) generate(rand io.Reader) (publicKey []byte, err error) {
351	if _, err := io.ReadFull(rand, e.x25519PrivateKey[:]); err != nil {
352		return nil, err
353	}
354	var x25519Public [32]byte
355	curve25519.ScalarBaseMult(&x25519Public, &e.x25519PrivateKey)
356
357	var kyberEntropy [64]byte
358	if _, err := io.ReadFull(rand, kyberEntropy[:]); err != nil {
359		return nil, err
360	}
361	var kyberPublic *[kyber.PublicKeySize]byte
362	e.kyberPrivateKey, kyberPublic = kyber.NewPrivateKey(&kyberEntropy)
363
364	var ret []byte
365	ret = append(ret, x25519Public[:]...)
366	ret = append(ret, kyberPublic[:]...)
367	return ret, nil
368}
369
370func (e *kyberKEM) encap(rand io.Reader, peerKey []byte) (ciphertext []byte, secret []byte, err error) {
371	if len(peerKey) != 32+kyber.PublicKeySize {
372		return nil, nil, errors.New("tls: bad length Kyber offer")
373	}
374
375	if _, err := io.ReadFull(rand, e.x25519PrivateKey[:]); err != nil {
376		return nil, nil, err
377	}
378
379	var x25519Shared, x25519PeerKey, x25519Public [32]byte
380	copy(x25519PeerKey[:], peerKey)
381	curve25519.ScalarBaseMult(&x25519Public, &e.x25519PrivateKey)
382	curve25519.ScalarMult(&x25519Shared, &e.x25519PrivateKey, &x25519PeerKey)
383
384	// Per RFC 7748, reject the all-zero value in constant time.
385	var zeros [32]byte
386	if subtle.ConstantTimeCompare(zeros[:], x25519Shared[:]) == 1 {
387		return nil, nil, errors.New("tls: X25519 value with wrong order")
388	}
389
390	kyberPublicKey, ok := kyber.UnmarshalPublicKey((*[kyber.PublicKeySize]byte)(peerKey[32:]))
391	if !ok {
392		return nil, nil, errors.New("tls: bad Kyber offer")
393	}
394
395	var kyberShared, kyberEntropy [32]byte
396	if _, err := io.ReadFull(rand, kyberEntropy[:]); err != nil {
397		return nil, nil, err
398	}
399	kyberCiphertext := kyberPublicKey.Encap(kyberShared[:], &kyberEntropy)
400
401	ciphertext = append(ciphertext, x25519Public[:]...)
402	ciphertext = append(ciphertext, kyberCiphertext[:]...)
403	secret = append(secret, x25519Shared[:]...)
404	secret = append(secret, kyberShared[:]...)
405
406	return ciphertext, secret, nil
407}
408
409func (e *kyberKEM) decap(ciphertext []byte) (secret []byte, err error) {
410	if len(ciphertext) != 32+kyber.CiphertextSize {
411		return nil, errors.New("tls: bad length Kyber reply")
412	}
413
414	var x25519Shared, x25519PeerKey [32]byte
415	copy(x25519PeerKey[:], ciphertext)
416	curve25519.ScalarMult(&x25519Shared, &e.x25519PrivateKey, &x25519PeerKey)
417
418	// Per RFC 7748, reject the all-zero value in constant time.
419	var zeros [32]byte
420	if subtle.ConstantTimeCompare(zeros[:], x25519Shared[:]) == 1 {
421		return nil, errors.New("tls: X25519 value with wrong order")
422	}
423
424	var kyberShared [32]byte
425	e.kyberPrivateKey.Decap(kyberShared[:], (*[kyber.CiphertextSize]byte)(ciphertext[32:]))
426
427	secret = append(secret, x25519Shared[:]...)
428	secret = append(secret, kyberShared[:]...)
429
430	return secret, nil
431}
432
433func kemForCurveID(id CurveID, config *Config) (kemImplementation, bool) {
434	switch id {
435	case CurveP224:
436		return &ecdhKEM{curve: elliptic.P224(), sendCompressed: config.Bugs.SendCompressedCoordinates}, true
437	case CurveP256:
438		return &ecdhKEM{curve: elliptic.P256(), sendCompressed: config.Bugs.SendCompressedCoordinates}, true
439	case CurveP384:
440		return &ecdhKEM{curve: elliptic.P384(), sendCompressed: config.Bugs.SendCompressedCoordinates}, true
441	case CurveP521:
442		return &ecdhKEM{curve: elliptic.P521(), sendCompressed: config.Bugs.SendCompressedCoordinates}, true
443	case CurveX25519:
444		return &x25519KEM{setHighBit: config.Bugs.SetX25519HighBit}, true
445	case CurveX25519Kyber768:
446		return &kyberKEM{}, true
447	default:
448		return nil, false
449	}
450
451}
452
453// keyAgreementAuthentication is a helper interface that specifies how
454// to authenticate the ServerKeyExchange parameters.
455type keyAgreementAuthentication interface {
456	signParameters(config *Config, cert *Credential, clientHello *clientHelloMsg, hello *serverHelloMsg, params []byte) (*serverKeyExchangeMsg, error)
457	verifyParameters(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, key crypto.PublicKey, params []byte, sig []byte) error
458}
459
460// nilKeyAgreementAuthentication does not authenticate the key
461// agreement parameters.
462type nilKeyAgreementAuthentication struct{}
463
464func (ka *nilKeyAgreementAuthentication) signParameters(config *Config, cert *Credential, clientHello *clientHelloMsg, hello *serverHelloMsg, params []byte) (*serverKeyExchangeMsg, error) {
465	skx := new(serverKeyExchangeMsg)
466	skx.key = params
467	return skx, nil
468}
469
470func (ka *nilKeyAgreementAuthentication) verifyParameters(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, key crypto.PublicKey, params []byte, sig []byte) error {
471	return nil
472}
473
474// signedKeyAgreement signs the ServerKeyExchange parameters with the
475// server's private key.
476type signedKeyAgreement struct {
477	keyType                keyType
478	version                uint16
479	peerSignatureAlgorithm signatureAlgorithm
480}
481
482func (ka *signedKeyAgreement) signParameters(config *Config, cert *Credential, clientHello *clientHelloMsg, hello *serverHelloMsg, params []byte) (*serverKeyExchangeMsg, error) {
483	// The message to be signed is prepended by the randoms.
484	var msg []byte
485	msg = append(msg, clientHello.random...)
486	msg = append(msg, hello.random...)
487	msg = append(msg, params...)
488
489	var sigAlg signatureAlgorithm
490	var err error
491	if ka.version >= VersionTLS12 {
492		sigAlg, err = selectSignatureAlgorithm(false /* server */, ka.version, cert, config, clientHello.signatureAlgorithms)
493		if err != nil {
494			return nil, err
495		}
496	}
497
498	sig, err := signMessage(false /* server */, ka.version, cert.PrivateKey, config, sigAlg, msg)
499	if err != nil {
500		return nil, err
501	}
502	if config.Bugs.SendSignatureAlgorithm != 0 {
503		sigAlg = config.Bugs.SendSignatureAlgorithm
504	}
505
506	skx := new(serverKeyExchangeMsg)
507	if config.Bugs.UnauthenticatedECDH {
508		skx.key = params
509	} else {
510		sigAlgsLen := 0
511		if ka.version >= VersionTLS12 {
512			sigAlgsLen = 2
513		}
514		skx.key = make([]byte, len(params)+sigAlgsLen+2+len(sig))
515		copy(skx.key, params)
516		k := skx.key[len(params):]
517		if ka.version >= VersionTLS12 {
518			k[0] = byte(sigAlg >> 8)
519			k[1] = byte(sigAlg)
520			k = k[2:]
521		}
522		k[0] = byte(len(sig) >> 8)
523		k[1] = byte(len(sig))
524		copy(k[2:], sig)
525	}
526
527	return skx, nil
528}
529
530func (ka *signedKeyAgreement) verifyParameters(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, publicKey crypto.PublicKey, params []byte, sig []byte) error {
531	// The peer's key must match the cipher type.
532	switch ka.keyType {
533	case keyTypeECDSA:
534		_, edsaOk := publicKey.(*ecdsa.PublicKey)
535		_, ed25519Ok := publicKey.(ed25519.PublicKey)
536		if !edsaOk && !ed25519Ok {
537			return errors.New("tls: ECDHE ECDSA requires a ECDSA or Ed25519 server public key")
538		}
539	case keyTypeRSA:
540		_, ok := publicKey.(*rsa.PublicKey)
541		if !ok {
542			return errors.New("tls: ECDHE RSA requires a RSA server public key")
543		}
544	default:
545		return errors.New("tls: unknown key type")
546	}
547
548	// The message to be signed is prepended by the randoms.
549	var msg []byte
550	msg = append(msg, clientHello.random...)
551	msg = append(msg, serverHello.random...)
552	msg = append(msg, params...)
553
554	var sigAlg signatureAlgorithm
555	if ka.version >= VersionTLS12 {
556		if len(sig) < 2 {
557			return errServerKeyExchange
558		}
559		sigAlg = signatureAlgorithm(sig[0])<<8 | signatureAlgorithm(sig[1])
560		sig = sig[2:]
561		// Stash the signature algorithm to be extracted by the handshake.
562		ka.peerSignatureAlgorithm = sigAlg
563	}
564
565	if len(sig) < 2 {
566		return errServerKeyExchange
567	}
568	sigLen := int(sig[0])<<8 | int(sig[1])
569	if sigLen+2 != len(sig) {
570		return errServerKeyExchange
571	}
572	sig = sig[2:]
573
574	return verifyMessage(true /* client */, ka.version, publicKey, config, sigAlg, msg, sig)
575}
576
577// ecdheKeyAgreement implements a TLS key agreement where the server
578// generates a ephemeral EC public/private key pair and signs it. The
579// pre-master secret is then calculated using ECDH. The signature may
580// either be ECDSA or RSA.
581type ecdheKeyAgreement struct {
582	auth    keyAgreementAuthentication
583	kem     kemImplementation
584	curveID CurveID
585	peerKey []byte
586}
587
588func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Credential, clientHello *clientHelloMsg, hello *serverHelloMsg, version uint16) (*serverKeyExchangeMsg, error) {
589	var curveid CurveID
590	preferredCurves := config.curvePreferences()
591
592NextCandidate:
593	for _, candidate := range preferredCurves {
594		if isPqGroup(candidate) && version < VersionTLS13 {
595			// Post-quantum "groups" require TLS 1.3.
596			continue
597		}
598
599		for _, c := range clientHello.supportedCurves {
600			if candidate == c {
601				curveid = c
602				break NextCandidate
603			}
604		}
605	}
606
607	if curveid == 0 {
608		return nil, errors.New("tls: no supported elliptic curves offered")
609	}
610
611	var ok bool
612	if ka.kem, ok = kemForCurveID(curveid, config); !ok {
613		return nil, errors.New("tls: preferredCurves includes unsupported curve")
614	}
615	ka.curveID = curveid
616
617	publicKey, err := ka.kem.generate(config.rand())
618	if err != nil {
619		return nil, err
620	}
621
622	// http://tools.ietf.org/html/rfc4492#section-5.4
623	serverECDHParams := make([]byte, 1+2+1+len(publicKey))
624	serverECDHParams[0] = 3 // named curve
625	if config.Bugs.SendCurve != 0 {
626		curveid = config.Bugs.SendCurve
627	}
628	serverECDHParams[1] = byte(curveid >> 8)
629	serverECDHParams[2] = byte(curveid)
630	serverECDHParams[3] = byte(len(publicKey))
631	copy(serverECDHParams[4:], publicKey)
632	if config.Bugs.InvalidECDHPoint {
633		serverECDHParams[4] ^= 0xff
634	}
635
636	return ka.auth.signParameters(config, cert, clientHello, hello, serverECDHParams)
637}
638
639func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Credential, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
640	if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
641		return nil, errClientKeyExchange
642	}
643	return ka.kem.decap(ckx.ciphertext[1:])
644}
645
646func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, key crypto.PublicKey, skx *serverKeyExchangeMsg) error {
647	if len(skx.key) < 4 {
648		return errServerKeyExchange
649	}
650	if skx.key[0] != 3 { // named curve
651		return errors.New("tls: server selected unsupported curve")
652	}
653	curveID := CurveID(skx.key[1])<<8 | CurveID(skx.key[2])
654	ka.curveID = curveID
655
656	var ok bool
657	if ka.kem, ok = kemForCurveID(curveID, config); !ok {
658		return errors.New("tls: server selected unsupported curve")
659	}
660
661	publicLen := int(skx.key[3])
662	if publicLen+4 > len(skx.key) {
663		return errServerKeyExchange
664	}
665	// Save the peer key for later.
666	ka.peerKey = skx.key[4 : 4+publicLen]
667
668	// Check the signature.
669	serverECDHParams := skx.key[:4+publicLen]
670	sig := skx.key[4+publicLen:]
671	return ka.auth.verifyParameters(config, clientHello, serverHello, key, serverECDHParams, sig)
672}
673
674func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
675	if ka.kem == nil {
676		return nil, nil, errors.New("missing ServerKeyExchange message")
677	}
678
679	ciphertext, secret, err := ka.kem.encap(config.rand(), ka.peerKey)
680	if err != nil {
681		return nil, nil, err
682	}
683
684	ckx := new(clientKeyExchangeMsg)
685	ckx.ciphertext = make([]byte, 1+len(ciphertext))
686	ckx.ciphertext[0] = byte(len(ciphertext))
687	copy(ckx.ciphertext[1:], ciphertext)
688	if config.Bugs.InvalidECDHPoint {
689		ckx.ciphertext[1] ^= 0xff
690	}
691
692	return secret, ckx, nil
693}
694
695func (ka *ecdheKeyAgreement) peerSignatureAlgorithm() signatureAlgorithm {
696	if auth, ok := ka.auth.(*signedKeyAgreement); ok {
697		return auth.peerSignatureAlgorithm
698	}
699	return 0
700}
701
702// nilKeyAgreement is a fake key agreement used to implement the plain PSK key
703// exchange.
704type nilKeyAgreement struct{}
705
706func (ka *nilKeyAgreement) generateServerKeyExchange(config *Config, cert *Credential, clientHello *clientHelloMsg, hello *serverHelloMsg, version uint16) (*serverKeyExchangeMsg, error) {
707	return nil, nil
708}
709
710func (ka *nilKeyAgreement) processClientKeyExchange(config *Config, cert *Credential, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
711	if len(ckx.ciphertext) != 0 {
712		return nil, errClientKeyExchange
713	}
714
715	// Although in plain PSK, otherSecret is all zeros, the base key
716	// agreement does not access to the length of the pre-shared
717	// key. pskKeyAgreement instead interprets nil to mean to use all zeros
718	// of the appropriate length.
719	return nil, nil
720}
721
722func (ka *nilKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, key crypto.PublicKey, skx *serverKeyExchangeMsg) error {
723	if len(skx.key) != 0 {
724		return errServerKeyExchange
725	}
726	return nil
727}
728
729func (ka *nilKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
730	// Although in plain PSK, otherSecret is all zeros, the base key
731	// agreement does not access to the length of the pre-shared
732	// key. pskKeyAgreement instead interprets nil to mean to use all zeros
733	// of the appropriate length.
734	return nil, &clientKeyExchangeMsg{}, nil
735}
736
737func (ka *nilKeyAgreement) peerSignatureAlgorithm() signatureAlgorithm {
738	return 0
739}
740
// makePSKPremaster builds the PSK pre-master secret (RFC 4279, section 2):
// otherSecret and then psk, each preceded by its length as a big-endian
// uint16.
func makePSKPremaster(otherSecret, psk []byte) []byte {
	out := make([]byte, 0, 2+len(otherSecret)+2+len(psk))
	for _, field := range [][]byte{otherSecret, psk} {
		n := len(field)
		out = append(out, byte(n>>8), byte(n))
		out = append(out, field...)
	}
	return out
}
751
752// pskKeyAgreement implements the PSK key agreement.
753type pskKeyAgreement struct {
754	base         keyAgreement
755	identityHint string
756}
757
758func (ka *pskKeyAgreement) generateServerKeyExchange(config *Config, cert *Credential, clientHello *clientHelloMsg, hello *serverHelloMsg, version uint16) (*serverKeyExchangeMsg, error) {
759	// Assemble the identity hint.
760	bytes := make([]byte, 2+len(config.PreSharedKeyIdentity))
761	bytes[0] = byte(len(config.PreSharedKeyIdentity) >> 8)
762	bytes[1] = byte(len(config.PreSharedKeyIdentity))
763	copy(bytes[2:], []byte(config.PreSharedKeyIdentity))
764
765	// If there is one, append the base key agreement's
766	// ServerKeyExchange.
767	baseSkx, err := ka.base.generateServerKeyExchange(config, cert, clientHello, hello, version)
768	if err != nil {
769		return nil, err
770	}
771
772	if baseSkx != nil {
773		bytes = append(bytes, baseSkx.key...)
774	} else if config.PreSharedKeyIdentity == "" && !config.Bugs.AlwaysSendPreSharedKeyIdentityHint {
775		// ServerKeyExchange is optional if the identity hint is empty
776		// and there would otherwise be no ServerKeyExchange.
777		return nil, nil
778	}
779
780	skx := new(serverKeyExchangeMsg)
781	skx.key = bytes
782	return skx, nil
783}
784
785func (ka *pskKeyAgreement) processClientKeyExchange(config *Config, cert *Credential, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
786	// First, process the PSK identity.
787	if len(ckx.ciphertext) < 2 {
788		return nil, errClientKeyExchange
789	}
790	identityLen := (int(ckx.ciphertext[0]) << 8) | int(ckx.ciphertext[1])
791	if 2+identityLen > len(ckx.ciphertext) {
792		return nil, errClientKeyExchange
793	}
794	identity := string(ckx.ciphertext[2 : 2+identityLen])
795
796	if identity != config.PreSharedKeyIdentity {
797		return nil, errors.New("tls: unexpected identity")
798	}
799
800	if config.PreSharedKey == nil {
801		return nil, errors.New("tls: pre-shared key not configured")
802	}
803
804	// Process the remainder of the ClientKeyExchange to compute the base
805	// pre-master secret.
806	newCkx := new(clientKeyExchangeMsg)
807	newCkx.ciphertext = ckx.ciphertext[2+identityLen:]
808	otherSecret, err := ka.base.processClientKeyExchange(config, cert, newCkx, version)
809	if err != nil {
810		return nil, err
811	}
812
813	if otherSecret == nil {
814		// Special-case for the plain PSK key exchanges.
815		otherSecret = make([]byte, len(config.PreSharedKey))
816	}
817	return makePSKPremaster(otherSecret, config.PreSharedKey), nil
818}
819
820func (ka *pskKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, key crypto.PublicKey, skx *serverKeyExchangeMsg) error {
821	if len(skx.key) < 2 {
822		return errServerKeyExchange
823	}
824	identityLen := (int(skx.key[0]) << 8) | int(skx.key[1])
825	if 2+identityLen > len(skx.key) {
826		return errServerKeyExchange
827	}
828	ka.identityHint = string(skx.key[2 : 2+identityLen])
829
830	// Process the remainder of the ServerKeyExchange.
831	newSkx := new(serverKeyExchangeMsg)
832	newSkx.key = skx.key[2+identityLen:]
833	return ka.base.processServerKeyExchange(config, clientHello, serverHello, key, newSkx)
834}
835
836func (ka *pskKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
837	// The server only sends an identity hint but, for purposes of
838	// test code, the server always sends the hint and it is
839	// required to match.
840	if ka.identityHint != config.PreSharedKeyIdentity {
841		return nil, nil, errors.New("tls: unexpected identity")
842	}
843
844	// Serialize the identity.
845	bytes := make([]byte, 2+len(config.PreSharedKeyIdentity))
846	bytes[0] = byte(len(config.PreSharedKeyIdentity) >> 8)
847	bytes[1] = byte(len(config.PreSharedKeyIdentity))
848	copy(bytes[2:], []byte(config.PreSharedKeyIdentity))
849
850	// Append the base key exchange's ClientKeyExchange.
851	otherSecret, baseCkx, err := ka.base.generateClientKeyExchange(config, clientHello, cert)
852	if err != nil {
853		return nil, nil, err
854	}
855	ckx := new(clientKeyExchangeMsg)
856	ckx.ciphertext = append(bytes, baseCkx.ciphertext...)
857
858	if config.PreSharedKey == nil {
859		return nil, nil, errors.New("tls: pre-shared key not configured")
860	}
861	if otherSecret == nil {
862		otherSecret = make([]byte, len(config.PreSharedKey))
863	}
864	return makePSKPremaster(otherSecret, config.PreSharedKey), ckx, nil
865}
866
867func (ka *pskKeyAgreement) peerSignatureAlgorithm() signatureAlgorithm {
868	return 0
869}
870