1// Copyright 2010 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package elliptic
6
7import (
8	"bytes"
9	"crypto/rand"
10	"encoding/hex"
11	"math/big"
12	"testing"
13)
14
15// genericParamsForCurve returns the dereferenced CurveParams for
16// the specified curve. This is used to avoid the logic for
17// upgrading a curve to its specific implementation, forcing
18// usage of the generic implementation.
19func genericParamsForCurve(c Curve) *CurveParams {
20	d := *(c.Params())
21	return &d
22}
23
24func testAllCurves(t *testing.T, f func(*testing.T, Curve)) {
25	tests := []struct {
26		name  string
27		curve Curve
28	}{
29		{"P256", P256()},
30		{"P256/Params", genericParamsForCurve(P256())},
31		{"P224", P224()},
32		{"P224/Params", genericParamsForCurve(P224())},
33		{"P384", P384()},
34		{"P384/Params", genericParamsForCurve(P384())},
35		{"P521", P521()},
36		{"P521/Params", genericParamsForCurve(P521())},
37	}
38	if testing.Short() {
39		tests = tests[:1]
40	}
41	for _, test := range tests {
42		curve := test.curve
43		t.Run(test.name, func(t *testing.T) {
44			t.Parallel()
45			f(t, curve)
46		})
47	}
48}
49
50func TestOnCurve(t *testing.T) {
51	t.Parallel()
52	testAllCurves(t, func(t *testing.T, curve Curve) {
53		if !curve.IsOnCurve(curve.Params().Gx, curve.Params().Gy) {
54			t.Error("basepoint is not on the curve")
55		}
56	})
57}
58
59func TestOffCurve(t *testing.T) {
60	t.Parallel()
61	testAllCurves(t, func(t *testing.T, curve Curve) {
62		x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1)
63		if curve.IsOnCurve(x, y) {
64			t.Errorf("point off curve is claimed to be on the curve")
65		}
66
67		byteLen := (curve.Params().BitSize + 7) / 8
68		b := make([]byte, 1+2*byteLen)
69		b[0] = 4 // uncompressed point
70		x.FillBytes(b[1 : 1+byteLen])
71		y.FillBytes(b[1+byteLen : 1+2*byteLen])
72
73		x1, y1 := Unmarshal(curve, b)
74		if x1 != nil || y1 != nil {
75			t.Errorf("unmarshaling a point not on the curve succeeded")
76		}
77	})
78}
79
80func TestInfinity(t *testing.T) {
81	t.Parallel()
82	testAllCurves(t, testInfinity)
83}
84
// isInfinity reports whether both x and y are zero, which is the (0, 0) pair
// these tests treat as the point at infinity.
func isInfinity(x, y *big.Int) bool {
	if x.Sign() != 0 {
		return false
	}
	return y.Sign() == 0
}
88
89func testInfinity(t *testing.T, curve Curve) {
90	x0, y0 := new(big.Int), new(big.Int)
91	xG, yG := curve.Params().Gx, curve.Params().Gy
92
93	if !isInfinity(curve.ScalarMult(xG, yG, curve.Params().N.Bytes())) {
94		t.Errorf("x^q != ∞")
95	}
96	if !isInfinity(curve.ScalarMult(xG, yG, []byte{0})) {
97		t.Errorf("x^0 != ∞")
98	}
99
100	if !isInfinity(curve.ScalarMult(x0, y0, []byte{1, 2, 3})) {
101		t.Errorf("∞^k != ∞")
102	}
103	if !isInfinity(curve.ScalarMult(x0, y0, []byte{0})) {
104		t.Errorf("∞^0 != ∞")
105	}
106
107	if !isInfinity(curve.ScalarBaseMult(curve.Params().N.Bytes())) {
108		t.Errorf("b^q != ∞")
109	}
110	if !isInfinity(curve.ScalarBaseMult([]byte{0})) {
111		t.Errorf("b^0 != ∞")
112	}
113
114	if !isInfinity(curve.Double(x0, y0)) {
115		t.Errorf("2∞ != ∞")
116	}
117	// There is no other point of order two on the NIST curves (as they have
118	// cofactor one), so Double can't otherwise return the point at infinity.
119
120	nMinusOne := new(big.Int).Sub(curve.Params().N, big.NewInt(1))
121	x, y := curve.ScalarMult(xG, yG, nMinusOne.Bytes())
122	x, y = curve.Add(x, y, xG, yG)
123	if !isInfinity(x, y) {
124		t.Errorf("x^(q-1) + x != ∞")
125	}
126	x, y = curve.Add(xG, yG, x0, y0)
127	if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 {
128		t.Errorf("x+∞ != x")
129	}
130	x, y = curve.Add(x0, y0, xG, yG)
131	if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 {
132		t.Errorf("∞+x != x")
133	}
134
135	if curve.IsOnCurve(x0, y0) {
136		t.Errorf("IsOnCurve(∞) == true")
137	}
138
139	if xx, yy := Unmarshal(curve, Marshal(curve, x0, y0)); xx != nil || yy != nil {
140		t.Errorf("Unmarshal(Marshal(∞)) did not return an error")
141	}
142	// We don't test UnmarshalCompressed(MarshalCompressed(∞)) because there are
143	// two valid points with x = 0.
144	if xx, yy := Unmarshal(curve, []byte{0x00}); xx != nil || yy != nil {
145		t.Errorf("Unmarshal(∞) did not return an error")
146	}
147	byteLen := (curve.Params().BitSize + 7) / 8
148	buf := make([]byte, byteLen*2+1)
149	buf[0] = 4 // Uncompressed format.
150	if xx, yy := Unmarshal(curve, buf); xx != nil || yy != nil {
151		t.Errorf("Unmarshal((0,0)) did not return an error")
152	}
153}
154
155func TestMarshal(t *testing.T) {
156	t.Parallel()
157	testAllCurves(t, func(t *testing.T, curve Curve) {
158		_, x, y, err := GenerateKey(curve, rand.Reader)
159		if err != nil {
160			t.Fatal(err)
161		}
162		serialized := Marshal(curve, x, y)
163		xx, yy := Unmarshal(curve, serialized)
164		if xx == nil {
165			t.Fatal("failed to unmarshal")
166		}
167		if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
168			t.Fatal("unmarshal returned different values")
169		}
170	})
171}
172
173func TestUnmarshalToLargeCoordinates(t *testing.T) {
174	t.Parallel()
175	// See https://golang.org/issues/20482.
176	testAllCurves(t, testUnmarshalToLargeCoordinates)
177}
178
179func testUnmarshalToLargeCoordinates(t *testing.T, curve Curve) {
180	p := curve.Params().P
181	byteLen := (p.BitLen() + 7) / 8
182
183	// Set x to be greater than curve's parameter P – specifically, to P+5.
184	// Set y to mod_sqrt(x^3 - 3x + B)) so that (x mod P = 5 , y) is on the
185	// curve.
186	x := new(big.Int).Add(p, big.NewInt(5))
187	y := curve.Params().polynomial(x)
188	y.ModSqrt(y, p)
189
190	invalid := make([]byte, byteLen*2+1)
191	invalid[0] = 4 // uncompressed encoding
192	x.FillBytes(invalid[1 : 1+byteLen])
193	y.FillBytes(invalid[1+byteLen:])
194
195	if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
196		t.Errorf("Unmarshal accepts invalid X coordinate")
197	}
198
199	if curve == p256 {
200		// This is a point on the curve with a small y value, small enough that
201		// we can add p and still be within 32 bytes.
202		x, _ = new(big.Int).SetString("31931927535157963707678568152204072984517581467226068221761862915403492091210", 10)
203		y, _ = new(big.Int).SetString("5208467867388784005506817585327037698770365050895731383201516607147", 10)
204		y.Add(y, p)
205
206		if p.Cmp(y) > 0 || y.BitLen() != 256 {
207			t.Fatal("y not within expected range")
208		}
209
210		// marshal
211		x.FillBytes(invalid[1 : 1+byteLen])
212		y.FillBytes(invalid[1+byteLen:])
213
214		if X, Y := Unmarshal(curve, invalid); X != nil || Y != nil {
215			t.Errorf("Unmarshal accepts invalid Y coordinate")
216		}
217	}
218}
219
220// TestInvalidCoordinates tests big.Int values that are not valid field elements
221// (negative or bigger than P). They are expected to return false from
222// IsOnCurve, all other behavior is undefined.
223func TestInvalidCoordinates(t *testing.T) {
224	t.Parallel()
225	testAllCurves(t, testInvalidCoordinates)
226}
227
228func testInvalidCoordinates(t *testing.T, curve Curve) {
229	checkIsOnCurveFalse := func(name string, x, y *big.Int) {
230		if curve.IsOnCurve(x, y) {
231			t.Errorf("IsOnCurve(%s) unexpectedly returned true", name)
232		}
233	}
234
235	p := curve.Params().P
236	_, x, y, _ := GenerateKey(curve, rand.Reader)
237	xx, yy := new(big.Int), new(big.Int)
238
239	// Check if the sign is getting dropped.
240	xx.Neg(x)
241	checkIsOnCurveFalse("-x, y", xx, y)
242	yy.Neg(y)
243	checkIsOnCurveFalse("x, -y", x, yy)
244
245	// Check if negative values are reduced modulo P.
246	xx.Sub(x, p)
247	checkIsOnCurveFalse("x-P, y", xx, y)
248	yy.Sub(y, p)
249	checkIsOnCurveFalse("x, y-P", x, yy)
250
251	// Check if positive values are reduced modulo P.
252	xx.Add(x, p)
253	checkIsOnCurveFalse("x+P, y", xx, y)
254	yy.Add(y, p)
255	checkIsOnCurveFalse("x, y+P", x, yy)
256
257	// Check if the overflow is dropped.
258	xx.Add(x, new(big.Int).Lsh(big.NewInt(1), 535))
259	checkIsOnCurveFalse("x+2⁵³⁵, y", xx, y)
260	yy.Add(y, new(big.Int).Lsh(big.NewInt(1), 535))
261	checkIsOnCurveFalse("x, y+2⁵³⁵", x, yy)
262
263	// Check if P is treated like zero (if possible).
264	// y^2 = x^3 - 3x + B
265	// y = mod_sqrt(x^3 - 3x + B)
266	// y = mod_sqrt(B) if x = 0
267	// If there is no modsqrt, there is no point with x = 0, can't test x = P.
268	if yy := new(big.Int).ModSqrt(curve.Params().B, p); yy != nil {
269		if !curve.IsOnCurve(big.NewInt(0), yy) {
270			t.Fatal("(0, mod_sqrt(B)) is not on the curve?")
271		}
272		checkIsOnCurveFalse("P, y", p, yy)
273	}
274}
275
276func TestMarshalCompressed(t *testing.T) {
277	t.Parallel()
278	t.Run("P-256/03", func(t *testing.T) {
279		data, _ := hex.DecodeString("031e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
280		x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
281		y, _ := new(big.Int).SetString("66200849279091436748794323380043701364391950689352563629885086590854940586447", 10)
282		testMarshalCompressed(t, P256(), x, y, data)
283	})
284	t.Run("P-256/02", func(t *testing.T) {
285		data, _ := hex.DecodeString("021e3987d9f9ea9d7dd7155a56a86b2009e1e0ab332f962d10d8beb6406ab1ad79")
286		x, _ := new(big.Int).SetString("13671033352574878777044637384712060483119675368076128232297328793087057702265", 10)
287		y, _ := new(big.Int).SetString("49591239931264812013903123569363872165694192725937750565648544718012157267504", 10)
288		testMarshalCompressed(t, P256(), x, y, data)
289	})
290
291	t.Run("Invalid", func(t *testing.T) {
292		data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535")
293		X, Y := UnmarshalCompressed(P256(), data)
294		if X != nil || Y != nil {
295			t.Error("expected an error for invalid encoding")
296		}
297	})
298
299	if testing.Short() {
300		t.Skip("skipping other curves on short test")
301	}
302
303	testAllCurves(t, func(t *testing.T, curve Curve) {
304		_, x, y, err := GenerateKey(curve, rand.Reader)
305		if err != nil {
306			t.Fatal(err)
307		}
308		testMarshalCompressed(t, curve, x, y, nil)
309	})
310
311}
312
313func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) {
314	if !curve.IsOnCurve(x, y) {
315		t.Fatal("invalid test point")
316	}
317	got := MarshalCompressed(curve, x, y)
318	if want != nil && !bytes.Equal(got, want) {
319		t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want)
320	}
321
322	X, Y := UnmarshalCompressed(curve, got)
323	if X == nil || Y == nil {
324		t.Fatalf("UnmarshalCompressed failed unexpectedly")
325	}
326
327	if !curve.IsOnCurve(X, Y) {
328		t.Error("UnmarshalCompressed returned a point not on the curve")
329	}
330	if X.Cmp(x) != 0 || Y.Cmp(y) != 0 {
331		t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y)
332	}
333}
334
335func TestLargeIsOnCurve(t *testing.T) {
336	t.Parallel()
337	testAllCurves(t, func(t *testing.T, curve Curve) {
338		large := big.NewInt(1)
339		large.Lsh(large, 1000)
340		if curve.IsOnCurve(large, large) {
341			t.Errorf("(2^1000, 2^1000) is reported on the curve")
342		}
343	})
344}
345
346func benchmarkAllCurves(b *testing.B, f func(*testing.B, Curve)) {
347	tests := []struct {
348		name  string
349		curve Curve
350	}{
351		{"P256", P256()},
352		{"P224", P224()},
353		{"P384", P384()},
354		{"P521", P521()},
355	}
356	for _, test := range tests {
357		curve := test.curve
358		b.Run(test.name, func(b *testing.B) {
359			f(b, curve)
360		})
361	}
362}
363
364func BenchmarkScalarBaseMult(b *testing.B) {
365	benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
366		priv, _, _, _ := GenerateKey(curve, rand.Reader)
367		b.ReportAllocs()
368		b.ResetTimer()
369		for i := 0; i < b.N; i++ {
370			x, _ := curve.ScalarBaseMult(priv)
371			// Prevent the compiler from optimizing out the operation.
372			priv[0] ^= byte(x.Bits()[0])
373		}
374	})
375}
376
377func BenchmarkScalarMult(b *testing.B) {
378	benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
379		_, x, y, _ := GenerateKey(curve, rand.Reader)
380		priv, _, _, _ := GenerateKey(curve, rand.Reader)
381		b.ReportAllocs()
382		b.ResetTimer()
383		for i := 0; i < b.N; i++ {
384			x, y = curve.ScalarMult(x, y, priv)
385		}
386	})
387}
388
389func BenchmarkMarshalUnmarshal(b *testing.B) {
390	benchmarkAllCurves(b, func(b *testing.B, curve Curve) {
391		_, x, y, _ := GenerateKey(curve, rand.Reader)
392		b.Run("Uncompressed", func(b *testing.B) {
393			b.ReportAllocs()
394			for i := 0; i < b.N; i++ {
395				buf := Marshal(curve, x, y)
396				xx, yy := Unmarshal(curve, buf)
397				if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
398					b.Error("Unmarshal output different from Marshal input")
399				}
400			}
401		})
402		b.Run("Compressed", func(b *testing.B) {
403			b.ReportAllocs()
404			for i := 0; i < b.N; i++ {
405				buf := MarshalCompressed(curve, x, y)
406				xx, yy := UnmarshalCompressed(curve, buf)
407				if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
408					b.Error("Unmarshal output different from Marshal input")
409				}
410			}
411		})
412	})
413}
414