1// Copyright 2021 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package elliptic
6
7import (
8	"math/big"
9	"testing"
10)
11
// scalarMultTest is one known-answer vector for a ScalarMult check: the
// scalar k, the input point (xIn, yIn), and the expected result (xOut, yOut).
// In this file the fields are parsed as base-16 strings by TestP256Mult.
//
// Field order matters: vectors are written as positional composite literals.
type scalarMultTest struct {
	k    string // scalar
	xIn  string // input point, x coordinate
	yIn  string // input point, y coordinate
	xOut string // expected result, x coordinate
	yOut string // expected result, y coordinate
}
17
18var p256MultTests = []scalarMultTest{
19	{
20		"2a265f8bcbdcaf94d58519141e578124cb40d64a501fba9c11847b28965bc737",
21		"023819813ac969847059028ea88a1f30dfbcde03fc791d3a252c6b41211882ea",
22		"f93e4ae433cc12cf2a43fc0ef26400c0e125508224cdb649380f25479148a4ad",
23		"4d4de80f1534850d261075997e3049321a0864082d24a917863366c0724f5ae3",
24		"a22d2b7f7818a3563e0f7a76c9bf0921ac55e06e2e4d11795b233824b1db8cc0",
25	},
26	{
27		"313f72ff9fe811bf573176231b286a3bdb6f1b14e05c40146590727a71c3bccd",
28		"cc11887b2d66cbae8f4d306627192522932146b42f01d3c6f92bd5c8ba739b06",
29		"a2f08a029cd06b46183085bae9248b0ed15b70280c7ef13a457f5af382426031",
30		"831c3f6b5f762d2f461901577af41354ac5f228c2591f84f8a6e51e2e3f17991",
31		"93f90934cd0ef2c698cc471c60a93524e87ab31ca2412252337f364513e43684",
32	},
33}
34
35func TestP256BaseMult(t *testing.T) {
36	p256 := P256()
37	p256Generic := genericParamsForCurve(p256)
38
39	scalars := make([]*big.Int, 0, len(p224BaseMultTests)+1)
40	for _, e := range p224BaseMultTests {
41		k, _ := new(big.Int).SetString(e.k, 10)
42		scalars = append(scalars, k)
43	}
44	k := new(big.Int).SetInt64(1)
45	k.Lsh(k, 500)
46	scalars = append(scalars, k)
47
48	for i, k := range scalars {
49		x, y := p256.ScalarBaseMult(k.Bytes())
50		x2, y2 := p256Generic.ScalarBaseMult(k.Bytes())
51		if x.Cmp(x2) != 0 || y.Cmp(y2) != 0 {
52			t.Errorf("#%d: got (%x, %x), want (%x, %x)", i, x, y, x2, y2)
53		}
54
55		if testing.Short() && i > 5 {
56			break
57		}
58	}
59}
60
61func TestP256Mult(t *testing.T) {
62	p256 := P256()
63	for i, e := range p256MultTests {
64		x, _ := new(big.Int).SetString(e.xIn, 16)
65		y, _ := new(big.Int).SetString(e.yIn, 16)
66		k, _ := new(big.Int).SetString(e.k, 16)
67		expectedX, _ := new(big.Int).SetString(e.xOut, 16)
68		expectedY, _ := new(big.Int).SetString(e.yOut, 16)
69
70		xx, yy := p256.ScalarMult(x, y, k.Bytes())
71		if xx.Cmp(expectedX) != 0 || yy.Cmp(expectedY) != 0 {
72			t.Errorf("#%d: got (%x, %x), want (%x, %x)", i, xx, yy, expectedX, expectedY)
73		}
74	}
75}
76
77type synthCombinedMult struct {
78	Curve
79}
80
81func (s synthCombinedMult) CombinedMult(bigX, bigY *big.Int, baseScalar, scalar []byte) (x, y *big.Int) {
82	x1, y1 := s.ScalarBaseMult(baseScalar)
83	x2, y2 := s.ScalarMult(bigX, bigY, scalar)
84	return s.Add(x1, y1, x2, y2)
85}
86
87func TestP256CombinedMult(t *testing.T) {
88	type combinedMult interface {
89		Curve
90		CombinedMult(bigX, bigY *big.Int, baseScalar, scalar []byte) (x, y *big.Int)
91	}
92
93	p256, ok := P256().(combinedMult)
94	if !ok {
95		p256 = &synthCombinedMult{P256()}
96	}
97
98	gx := p256.Params().Gx
99	gy := p256.Params().Gy
100
101	zero := make([]byte, 32)
102	one := make([]byte, 32)
103	one[31] = 1
104	two := make([]byte, 32)
105	two[31] = 2
106
107	// 0×G + 0×G = ∞
108	x, y := p256.CombinedMult(gx, gy, zero, zero)
109	if x.Sign() != 0 || y.Sign() != 0 {
110		t.Errorf("0×G + 0×G = (%d, %d), should be ∞", x, y)
111	}
112
113	// 1×G + 0×G = G
114	x, y = p256.CombinedMult(gx, gy, one, zero)
115	if x.Cmp(gx) != 0 || y.Cmp(gy) != 0 {
116		t.Errorf("1×G + 0×G = (%d, %d), should be (%d, %d)", x, y, gx, gy)
117	}
118
119	// 0×G + 1×G = G
120	x, y = p256.CombinedMult(gx, gy, zero, one)
121	if x.Cmp(gx) != 0 || y.Cmp(gy) != 0 {
122		t.Errorf("0×G + 1×G = (%d, %d), should be (%d, %d)", x, y, gx, gy)
123	}
124
125	// 1×G + 1×G = 2×G
126	x, y = p256.CombinedMult(gx, gy, one, one)
127	ggx, ggy := p256.ScalarBaseMult(two)
128	if x.Cmp(ggx) != 0 || y.Cmp(ggy) != 0 {
129		t.Errorf("1×G + 1×G = (%d, %d), should be (%d, %d)", x, y, ggx, ggy)
130	}
131
132	minusOne := new(big.Int).Sub(p256.Params().N, big.NewInt(1))
133	// 1×G + (-1)×G = ∞
134	x, y = p256.CombinedMult(gx, gy, one, minusOne.Bytes())
135	if x.Sign() != 0 || y.Sign() != 0 {
136		t.Errorf("1×G + (-1)×G = (%d, %d), should be ∞", x, y)
137	}
138}
139
140func TestIssue52075(t *testing.T) {
141	Gx, Gy := P256().Params().Gx, P256().Params().Gy
142	scalar := make([]byte, 33)
143	scalar[32] = 1
144	x, y := P256().ScalarBaseMult(scalar)
145	if x.Cmp(Gx) != 0 || y.Cmp(Gy) != 0 {
146		t.Errorf("unexpected output (%v,%v)", x, y)
147	}
148	x, y = P256().ScalarMult(Gx, Gy, scalar)
149	if x.Cmp(Gx) != 0 || y.Cmp(Gy) != 0 {
150		t.Errorf("unexpected output (%v,%v)", x, y)
151	}
152}
153