1// Copyright 2023 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package x509
6
7import (
8	"encoding"
9	"encoding/asn1"
10	"math"
11	"testing"
12)
13
// oidTests is a table of DER-encoded OID bodies shared by several tests.
// Each entry holds the raw encoding, whether newOIDFromDER is expected to
// accept it, the expected dotted-decimal form, and the arcs as uint64 values.
// For valid entries, a nil ints field marks an OID with an arc too large to
// be represented as a uint64.
var oidTests = []struct {
	raw   []byte
	valid bool
	str   string
	ints  []uint64
}{
	// Rejected: empty input and subidentifiers starting with a 0x80 padding byte.
	{raw: []byte{}, valid: false},
	{raw: []byte{0x80, 0x01}, valid: false},
	{raw: []byte{0x01, 0x80, 0x01}, valid: false},

	// The first subidentifier packs the first two arcs as 40*X + Y.
	{raw: []byte{1, 2, 3}, valid: true, str: "0.1.2.3", ints: []uint64{0, 1, 2, 3}},
	{raw: []byte{41, 2, 3}, valid: true, str: "1.1.2.3", ints: []uint64{1, 1, 2, 3}},
	{raw: []byte{86, 2, 3}, valid: true, str: "2.6.2.3", ints: []uint64{2, 6, 2, 3}},

	// Large later arcs, up to and past the uint64 range.
	{raw: []byte{41, 255, 255, 255, 127}, valid: true, str: "1.1.268435455", ints: []uint64{1, 1, 268435455}},
	{raw: []byte{41, 0x87, 255, 255, 255, 127}, valid: true, str: "1.1.2147483647", ints: []uint64{1, 1, 2147483647}},
	{raw: []byte{41, 255, 255, 255, 255, 127}, valid: true, str: "1.1.34359738367", ints: []uint64{1, 1, 34359738367}},
	{raw: []byte{42, 255, 255, 255, 255, 255, 255, 255, 255, 127}, valid: true, str: "1.2.9223372036854775807", ints: []uint64{1, 2, 9223372036854775807}},
	{raw: []byte{43, 0x81, 255, 255, 255, 255, 255, 255, 255, 255, 127}, valid: true, str: "1.3.18446744073709551615", ints: []uint64{1, 3, 18446744073709551615}},
	{raw: []byte{44, 0x83, 255, 255, 255, 255, 255, 255, 255, 255, 127}, valid: true, str: "1.4.36893488147419103231"},
	{raw: []byte{85, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, valid: true, str: "2.5.1180591620717411303423"},
	{raw: []byte{85, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, valid: true, str: "2.5.19342813113834066795298815"},

	// Under root arc 2 the first subidentifier itself may span several bytes.
	{raw: []byte{255, 255, 255, 127}, valid: true, str: "2.268435375", ints: []uint64{2, 268435375}},
	{raw: []byte{0x87, 255, 255, 255, 127}, valid: true, str: "2.2147483567", ints: []uint64{2, 2147483567}},
	{raw: []byte{255, 127}, valid: true, str: "2.16303", ints: []uint64{2, 16303}},
	{raw: []byte{255, 255, 255, 255, 127}, valid: true, str: "2.34359738287", ints: []uint64{2, 34359738287}},
	{raw: []byte{255, 255, 255, 255, 255, 255, 255, 255, 127}, valid: true, str: "2.9223372036854775727", ints: []uint64{2, 9223372036854775727}},
	{raw: []byte{0x81, 255, 255, 255, 255, 255, 255, 255, 255, 127}, valid: true, str: "2.18446744073709551535", ints: []uint64{2, 18446744073709551535}},
	{raw: []byte{0x83, 255, 255, 255, 255, 255, 255, 255, 255, 127}, valid: true, str: "2.36893488147419103151"},
	{raw: []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, valid: true, str: "2.1180591620717411303343"},
	{raw: []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 127}, valid: true, str: "2.19342813113834066795298735"},

	// Continuation bytes with mixed payloads rather than all-ones runs.
	{raw: []byte{41, 0x80 | 66, 0x80 | 44, 0x80 | 11, 33}, valid: true, str: "1.1.139134369", ints: []uint64{1, 1, 139134369}},
	{raw: []byte{0x80 | 66, 0x80 | 44, 0x80 | 11, 33}, valid: true, str: "2.139134289", ints: []uint64{2, 139134289}},
}
50
51func TestOID(t *testing.T) {
52	for _, v := range oidTests {
53		oid, ok := newOIDFromDER(v.raw)
54		if ok != v.valid {
55			t.Errorf("newOIDFromDER(%v) = (%v, %v); want = (OID, %v)", v.raw, oid, ok, v.valid)
56			continue
57		}
58
59		if !ok {
60			continue
61		}
62
63		if str := oid.String(); str != v.str {
64			t.Errorf("(%#v).String() = %v, want; %v", oid, str, v.str)
65		}
66
67		var asn1OID asn1.ObjectIdentifier
68		for _, v := range v.ints {
69			if v > math.MaxInt32 {
70				asn1OID = nil
71				break
72			}
73			asn1OID = append(asn1OID, int(v))
74		}
75
76		o, ok := oid.toASN1OID()
77		if shouldOk := asn1OID != nil; shouldOk != ok {
78			t.Errorf("(%#v).toASN1OID() = (%v, %v); want = (%v, %v)", oid, o, ok, asn1OID, shouldOk)
79			continue
80		}
81
82		if asn1OID != nil && !o.Equal(asn1OID) {
83			t.Errorf("(%#v).toASN1OID() = (%v, true); want = (%v, true)", oid, o, asn1OID)
84		}
85
86		if v.ints != nil {
87			oid2, err := OIDFromInts(v.ints)
88			if err != nil {
89				t.Errorf("OIDFromInts(%v) = (%v, %v); want = (%v, nil)", v.ints, oid2, err, oid)
90			}
91			if !oid2.Equal(oid) {
92				t.Errorf("OIDFromInts(%v) = (%v, nil); want = (%v, nil)", v.ints, oid2, oid)
93			}
94		}
95	}
96}
97
98func TestInvalidOID(t *testing.T) {
99	cases := []struct {
100		str  string
101		ints []uint64
102	}{
103		{str: "", ints: []uint64{}},
104		{str: "1", ints: []uint64{1}},
105		{str: "3", ints: []uint64{3}},
106		{str: "3.100.200", ints: []uint64{3, 100, 200}},
107		{str: "1.81", ints: []uint64{1, 81}},
108		{str: "1.81.200", ints: []uint64{1, 81, 200}},
109	}
110
111	for _, tt := range cases {
112		oid, err := OIDFromInts(tt.ints)
113		if err == nil {
114			t.Errorf("OIDFromInts(%v) = (%v, %v); want = (OID{}, %v)", tt.ints, oid, err, errInvalidOID)
115		}
116
117		oid2, err := ParseOID(tt.str)
118		if err == nil {
119			t.Errorf("ParseOID(%v) = (%v, %v); want = (OID{}, %v)", tt.str, oid2, err, errInvalidOID)
120		}
121
122		var oid3 OID
123		err = oid3.UnmarshalText([]byte(tt.str))
124		if err == nil {
125			t.Errorf("(*OID).UnmarshalText(%v) = (%v, %v); want = (OID{}, %v)", tt.str, oid3, err, errInvalidOID)
126		}
127	}
128}
129
130func TestOIDEqual(t *testing.T) {
131	var cases = []struct {
132		oid  OID
133		oid2 OID
134		eq   bool
135	}{
136		{oid: mustNewOIDFromInts(t, []uint64{1, 2, 3}), oid2: mustNewOIDFromInts(t, []uint64{1, 2, 3}), eq: true},
137		{oid: mustNewOIDFromInts(t, []uint64{1, 2, 3}), oid2: mustNewOIDFromInts(t, []uint64{1, 2, 4}), eq: false},
138		{oid: mustNewOIDFromInts(t, []uint64{1, 2, 3}), oid2: mustNewOIDFromInts(t, []uint64{1, 2, 3, 4}), eq: false},
139		{oid: mustNewOIDFromInts(t, []uint64{2, 33, 22}), oid2: mustNewOIDFromInts(t, []uint64{2, 33, 23}), eq: false},
140		{oid: OID{}, oid2: OID{}, eq: true},
141		{oid: OID{}, oid2: mustNewOIDFromInts(t, []uint64{2, 33, 23}), eq: false},
142	}
143
144	for _, tt := range cases {
145		if eq := tt.oid.Equal(tt.oid2); eq != tt.eq {
146			t.Errorf("(%v).Equal(%v) = %v, want %v", tt.oid, tt.oid2, eq, tt.eq)
147		}
148	}
149}
150
151var (
152	_ encoding.BinaryMarshaler   = OID{}
153	_ encoding.BinaryUnmarshaler = new(OID)
154	_ encoding.TextMarshaler     = OID{}
155	_ encoding.TextUnmarshaler   = new(OID)
156)
157
158func TestOIDMarshal(t *testing.T) {
159	cases := []struct {
160		in  string
161		out OID
162		err error
163	}{
164		{in: "", err: errInvalidOID},
165		{in: "0", err: errInvalidOID},
166		{in: "1", err: errInvalidOID},
167		{in: ".1", err: errInvalidOID},
168		{in: ".1.", err: errInvalidOID},
169		{in: "1.", err: errInvalidOID},
170		{in: "1..", err: errInvalidOID},
171		{in: "1.2.", err: errInvalidOID},
172		{in: "1.2.333.", err: errInvalidOID},
173		{in: "1.2.333..", err: errInvalidOID},
174		{in: "1.2..", err: errInvalidOID},
175		{in: "+1.2", err: errInvalidOID},
176		{in: "-1.2", err: errInvalidOID},
177		{in: "1.-2", err: errInvalidOID},
178		{in: "1.2.+333", err: errInvalidOID},
179	}
180
181	for _, v := range oidTests {
182		oid, ok := newOIDFromDER(v.raw)
183		if !ok {
184			continue
185		}
186		cases = append(cases, struct {
187			in  string
188			out OID
189			err error
190		}{
191			in:  v.str,
192			out: oid,
193			err: nil,
194		})
195	}
196
197	for _, tt := range cases {
198		o, err := ParseOID(tt.in)
199		if err != tt.err {
200			t.Errorf("ParseOID(%q) = %v; want = %v", tt.in, err, tt.err)
201			continue
202		}
203
204		var o2 OID
205		err = o2.UnmarshalText([]byte(tt.in))
206		if err != tt.err {
207			t.Errorf("(*OID).UnmarshalText(%q) = %v; want = %v", tt.in, err, tt.err)
208			continue
209		}
210
211		if err != nil {
212			continue
213		}
214
215		if !o.Equal(tt.out) {
216			t.Errorf("(*OID).UnmarshalText(%q) = %v; want = %v", tt.in, o, tt.out)
217			continue
218		}
219
220		if !o2.Equal(tt.out) {
221			t.Errorf("ParseOID(%q) = %v; want = %v", tt.in, o2, tt.out)
222			continue
223		}
224
225		marshalled, err := o.MarshalText()
226		if string(marshalled) != tt.in || err != nil {
227			t.Errorf("(%#v).MarshalText() = (%v, %v); want = (%v, nil)", o, string(marshalled), err, tt.in)
228			continue
229		}
230
231		binary, err := o.MarshalBinary()
232		if err != nil {
233			t.Errorf("(%#v).MarshalBinary() = %v; want = nil", o, err)
234		}
235
236		var o3 OID
237		if err := o3.UnmarshalBinary(binary); err != nil {
238			t.Errorf("(*OID).UnmarshalBinary(%v) = %v; want = nil", binary, err)
239		}
240
241		if !o3.Equal(tt.out) {
242			t.Errorf("(*OID).UnmarshalBinary(%v) = %v; want = %v", binary, o3, tt.out)
243			continue
244		}
245	}
246}
247
248func TestOIDEqualASN1OID(t *testing.T) {
249	maxInt32PlusOne := int64(math.MaxInt32) + 1
250	var cases = []struct {
251		oid  OID
252		oid2 asn1.ObjectIdentifier
253		eq   bool
254	}{
255		{oid: mustNewOIDFromInts(t, []uint64{1, 2, 3}), oid2: asn1.ObjectIdentifier{1, 2, 3}, eq: true},
256		{oid: mustNewOIDFromInts(t, []uint64{1, 2, 3}), oid2: asn1.ObjectIdentifier{1, 2, 4}, eq: false},
257		{oid: mustNewOIDFromInts(t, []uint64{1, 2, 3}), oid2: asn1.ObjectIdentifier{1, 2, 3, 4}, eq: false},
258		{oid: mustNewOIDFromInts(t, []uint64{1, 33, 22}), oid2: asn1.ObjectIdentifier{1, 33, 23}, eq: false},
259		{oid: mustNewOIDFromInts(t, []uint64{1, 33, 23}), oid2: asn1.ObjectIdentifier{1, 33, 22}, eq: false},
260		{oid: mustNewOIDFromInts(t, []uint64{1, 33, 127}), oid2: asn1.ObjectIdentifier{1, 33, 127}, eq: true},
261		{oid: mustNewOIDFromInts(t, []uint64{1, 33, 128}), oid2: asn1.ObjectIdentifier{1, 33, 127}, eq: false},
262		{oid: mustNewOIDFromInts(t, []uint64{1, 33, 128}), oid2: asn1.ObjectIdentifier{1, 33, 128}, eq: true},
263		{oid: mustNewOIDFromInts(t, []uint64{1, 33, 129}), oid2: asn1.ObjectIdentifier{1, 33, 129}, eq: true},
264		{oid: mustNewOIDFromInts(t, []uint64{1, 33, 128}), oid2: asn1.ObjectIdentifier{1, 33, 129}, eq: false},
265		{oid: mustNewOIDFromInts(t, []uint64{1, 33, 129}), oid2: asn1.ObjectIdentifier{1, 33, 128}, eq: false},
266		{oid: mustNewOIDFromInts(t, []uint64{1, 33, 255}), oid2: asn1.ObjectIdentifier{1, 33, 255}, eq: true},
267		{oid: mustNewOIDFromInts(t, []uint64{1, 33, 256}), oid2: asn1.ObjectIdentifier{1, 33, 256}, eq: true},
268		{oid: mustNewOIDFromInts(t, []uint64{2, 33, 257}), oid2: asn1.ObjectIdentifier{2, 33, 256}, eq: false},
269		{oid: mustNewOIDFromInts(t, []uint64{2, 33, 256}), oid2: asn1.ObjectIdentifier{2, 33, 257}, eq: false},
270
271		{oid: mustNewOIDFromInts(t, []uint64{1, 33}), oid2: asn1.ObjectIdentifier{1, 33, math.MaxInt32}, eq: false},
272		{oid: mustNewOIDFromInts(t, []uint64{1, 33, math.MaxInt32}), oid2: asn1.ObjectIdentifier{1, 33}, eq: false},
273		{oid: mustNewOIDFromInts(t, []uint64{1, 33, math.MaxInt32}), oid2: asn1.ObjectIdentifier{1, 33, math.MaxInt32}, eq: true},
274		{
275			oid:  mustNewOIDFromInts(t, []uint64{1, 33, math.MaxInt32 + 1}),
276			oid2: asn1.ObjectIdentifier{1, 33 /*convert to int, so that it compiles on 32bit*/, int(maxInt32PlusOne)},
277			eq:   false,
278		},
279
280		{oid: mustNewOIDFromInts(t, []uint64{1, 33, 256}), oid2: asn1.ObjectIdentifier{}, eq: false},
281		{oid: OID{}, oid2: asn1.ObjectIdentifier{1, 33, 256}, eq: false},
282		{oid: OID{}, oid2: asn1.ObjectIdentifier{}, eq: false},
283	}
284
285	for _, tt := range cases {
286		if eq := tt.oid.EqualASN1OID(tt.oid2); eq != tt.eq {
287			t.Errorf("(%v).EqualASN1OID(%v) = %v, want %v", tt.oid, tt.oid2, eq, tt.eq)
288		}
289	}
290}
291
292func TestOIDUnmarshalBinary(t *testing.T) {
293	for _, tt := range oidTests {
294		var o OID
295		err := o.UnmarshalBinary(tt.raw)
296
297		expectErr := errInvalidOID
298		if tt.valid {
299			expectErr = nil
300		}
301
302		if err != expectErr {
303			t.Errorf("(o *OID).UnmarshalBinary(%v) = %v; want = %v; (o = %v)", tt.raw, err, expectErr, o)
304		}
305	}
306}
307
308func BenchmarkOIDMarshalUnmarshalText(b *testing.B) {
309	oid := mustNewOIDFromInts(b, []uint64{1, 2, 3, 9999, 1024})
310	for range b.N {
311		text, err := oid.MarshalText()
312		if err != nil {
313			b.Fatal(err)
314		}
315		var o OID
316		if err := o.UnmarshalText(text); err != nil {
317			b.Fatal(err)
318		}
319	}
320}
321
322func mustNewOIDFromInts(t testing.TB, ints []uint64) OID {
323	oid, err := OIDFromInts(ints)
324	if err != nil {
325		t.Fatalf("OIDFromInts(%v) unexpected error: %v", ints, err)
326	}
327	return oid
328}
329