xref: /aosp_15_r20/external/tink/go/signature/rsassapkcs1_signer_key_manager_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package signature_test
18
19import (
20	"encoding/hex"
21	"math/big"
22	"testing"
23
24	"github.com/google/go-cmp/cmp"
25	"google.golang.org/protobuf/proto"
26	"github.com/google/tink/go/core/registry"
27	internal "github.com/google/tink/go/internal/signature"
28	"github.com/google/tink/go/subtle/random"
29	"github.com/google/tink/go/tink"
30	cpb "github.com/google/tink/go/proto/common_go_proto"
31	rsassapkcs1pb "github.com/google/tink/go/proto/rsa_ssa_pkcs1_go_proto"
32	tinkpb "github.com/google/tink/go/proto/tink_go_proto"
33)
34
// rsaPKCS1PrivateKeyTypeURL is the type URL used by the tests in this file to
// look up the RSA-SSA-PKCS1 private key manager in the registry.
const rsaPKCS1PrivateKeyTypeURL = "type.googleapis.com/google.crypto.tink.RsaSsaPkcs1PrivateKey"
38
39func TestRSASSAPKCS1SignerKeyManagerDoesSupport(t *testing.T) {
40	skm, err := registry.GetKeyManager(rsaPKCS1PrivateKeyTypeURL)
41	if err != nil {
42		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", rsaPKCS1PrivateKeyTypeURL, err)
43	}
44	if !skm.DoesSupport(rsaPKCS1PrivateKeyTypeURL) {
45		t.Errorf("DoesSupport(%q) = false, want true", rsaPKCS1PrivateKeyTypeURL)
46	}
47	if skm.DoesSupport("not.valid.type") {
48		t.Errorf("DoesSupport(%q) = true, want false", "not.valid.type")
49	}
50}
51
52func TestRSASSAPKCS1SignerTypeURL(t *testing.T) {
53	skm, err := registry.GetKeyManager(rsaPKCS1PrivateKeyTypeURL)
54	if err != nil {
55		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", rsaPKCS1PrivateKeyTypeURL, err)
56	}
57	if skm.TypeURL() != rsaPKCS1PrivateKeyTypeURL {
58		t.Errorf("TypeURL() = %q, want %q", skm.TypeURL(), rsaPKCS1PrivateKeyTypeURL)
59	}
60}
61
62func TestRSASSAPKCS1SignerKeyManagerPublicKeyData(t *testing.T) {
63	skm, err := registry.GetKeyManager(rsaPKCS1PrivateKeyTypeURL)
64	if err != nil {
65		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", rsaPKCS1PrivateKeyTypeURL, err)
66	}
67	vkm, err := registry.GetKeyManager(rsaPKCS1PublicTypeURL)
68	if err != nil {
69		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", rsaPKCS1PublicTypeURL, err)
70	}
71	privKey, err := makeValidRSAPKCS1Key()
72	if err != nil {
73		t.Fatalf("makeValidRSAPKCS1Key() err = %v, want nil", err)
74	}
75	serializedPrivate, err := proto.Marshal(privKey)
76	if err != nil {
77		t.Fatalf("proto.Marshal() err = %v, want nil", err)
78	}
79	got, err := skm.(registry.PrivateKeyManager).PublicKeyData(serializedPrivate)
80	if err != nil {
81		t.Fatalf("PublicKeyData() err = %v, want nil", err)
82	}
83	if got.GetKeyMaterialType() != tinkpb.KeyData_ASYMMETRIC_PUBLIC {
84		t.Errorf("GetKeyMaterialType() = %q, want %q", got.GetKeyMaterialType(), tinkpb.KeyData_ASYMMETRIC_PUBLIC)
85	}
86	if got.GetTypeUrl() != rsaPKCS1PublicTypeURL {
87		t.Errorf("GetTypeUrl() = %q, want %q", got.GetTypeUrl(), rsaPKCS1PublicTypeURL)
88	}
89	if _, err := vkm.Primitive(got.GetValue()); err != nil {
90		t.Errorf("Primitive() err = %v, want nil", err)
91	}
92}
93
94func TestRSASSAPKCS1SignerKeyManagerPrimitiveSignVerify(t *testing.T) {
95	skm, err := registry.GetKeyManager(rsaPKCS1PrivateKeyTypeURL)
96	if err != nil {
97		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", rsaPKCS1PrivateKeyTypeURL, err)
98	}
99	privKey, err := makeValidRSAPKCS1Key()
100	if err != nil {
101		t.Fatalf("makeValidRSAPKCS1Key() err = %v, want nil", err)
102	}
103	serializedPrivate, err := proto.Marshal(privKey)
104	if err != nil {
105		t.Fatalf("proto.Marshal() err = %v, want nil", err)
106	}
107	p, err := skm.Primitive(serializedPrivate)
108	if err != nil {
109		t.Fatalf("Primitive() err = %v, want nil", err)
110	}
111	signer, ok := p.(*internal.RSA_SSA_PKCS1_Signer)
112	if !ok {
113		t.Fatalf("primitive is not of type RSA_SSA_PKCS1_Signer")
114	}
115	vkm, err := registry.GetKeyManager(rsaPKCS1PublicTypeURL)
116	if err != nil {
117		t.Fatalf("regitry.GetKeyManager(%q) err = %v, want nil", rsaPKCS1PublicTypeURL, err)
118	}
119	serializedPublic, err := proto.Marshal(privKey.PublicKey)
120	if err != nil {
121		t.Fatalf("Failed serializing public key proto: %v", err)
122	}
123	p, err = vkm.Primitive(serializedPublic)
124	if err != nil {
125		t.Fatalf("rsaSSAPKCS1VerifierKeyManager.Primitive() failed: %v", err)
126	}
127	v, ok := p.(*internal.RSA_SSA_PKCS1_Verifier)
128	if !ok {
129		t.Fatalf("primitve is not of type RSA_SSA_PKCS1_Verifier")
130	}
131	data := random.GetRandomBytes(1281)
132	signature, err := signer.Sign(data)
133	if err != nil {
134		t.Fatalf("Sign() err = %v, want nil", err)
135	}
136	if err := v.Verify(signature, data); err != nil {
137		t.Fatalf("Verify() err = %v, want nil", err)
138	}
139}
140
141func TestRSASSAPKCS1SignerKeyManagerPrimitiveWithInvalidInputFails(t *testing.T) {
142	km, err := registry.GetKeyManager(rsaPKCS1PrivateKeyTypeURL)
143	if err != nil {
144		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", rsaPKCS1PrivateKeyTypeURL, err)
145	}
146	validPrivKey, err := makeValidRSAPKCS1Key()
147	if err != nil {
148		t.Fatalf("makeValidRSAPKCS1Key() err = %v, want nil", err)
149	}
150	serializedValidPrivate, err := proto.Marshal(validPrivKey)
151	if err != nil {
152		t.Fatalf("proto.Marshal() err = %v, want nil", err)
153	}
154	if _, err := km.Primitive(serializedValidPrivate); err != nil {
155		t.Fatalf("Primitive(serializedValidPrivate) err = %v, want nil", err)
156	}
157	type testCase struct {
158		name string
159		key  *rsassapkcs1pb.RsaSsaPkcs1PrivateKey
160	}
161	for _, tc := range []testCase{
162		{
163			name: "empty key",
164			key:  &rsassapkcs1pb.RsaSsaPkcs1PrivateKey{},
165		},
166		{
167			name: "nil key",
168			key:  nil,
169		},
170		{
171			name: "invalid version",
172			key: &rsassapkcs1pb.RsaSsaPkcs1PrivateKey{
173				Version:   validPrivKey.GetVersion() + 1,
174				PublicKey: validPrivKey.GetPublicKey(),
175				D:         validPrivKey.GetD(),
176				P:         validPrivKey.GetP(),
177				Q:         validPrivKey.GetQ(),
178				Dp:        validPrivKey.GetDp(),
179				Dq:        validPrivKey.GetDq(),
180				Crt:       validPrivKey.GetCrt(),
181			},
182		},
183		{
184			name: "invalid hash algorithm ",
185			key: &rsassapkcs1pb.RsaSsaPkcs1PrivateKey{
186				Version: validPrivKey.GetVersion(),
187				PublicKey: &rsassapkcs1pb.RsaSsaPkcs1PublicKey{
188					Version: validPrivKey.GetPublicKey().GetVersion(),
189					E:       validPrivKey.GetPublicKey().GetE(),
190					N:       validPrivKey.GetPublicKey().GetN(),
191					Params: &rsassapkcs1pb.RsaSsaPkcs1Params{
192						HashType: cpb.HashType_SHA224,
193					},
194				},
195				D:   validPrivKey.GetD(),
196				P:   validPrivKey.GetP(),
197				Q:   validPrivKey.GetQ(),
198				Dp:  validPrivKey.GetDp(),
199				Dq:  validPrivKey.GetDq(),
200				Crt: validPrivKey.GetCrt(),
201			},
202		},
203		{
204			name: "invalid modulus",
205			key: &rsassapkcs1pb.RsaSsaPkcs1PrivateKey{
206				Version: validPrivKey.GetVersion(),
207				PublicKey: &rsassapkcs1pb.RsaSsaPkcs1PublicKey{
208					Version: validPrivKey.GetPublicKey().GetVersion(),
209					E:       validPrivKey.GetPublicKey().GetE(),
210					N:       []byte{3, 4, 5},
211					Params:  validPrivKey.GetPublicKey().GetParams(),
212				},
213				D:   validPrivKey.GetD(),
214				P:   validPrivKey.GetP(),
215				Q:   validPrivKey.GetQ(),
216				Dp:  validPrivKey.GetDp(),
217				Dq:  validPrivKey.GetDq(),
218				Crt: validPrivKey.GetCrt(),
219			},
220		},
221		{
222			name: "invalid public key exponent",
223			key: &rsassapkcs1pb.RsaSsaPkcs1PrivateKey{
224				Version: validPrivKey.GetVersion(),
225				PublicKey: &rsassapkcs1pb.RsaSsaPkcs1PublicKey{
226					Version: validPrivKey.GetPublicKey().GetVersion(),
227					E:       []byte{0x06},
228					N:       validPrivKey.GetPublicKey().GetN(),
229					Params:  validPrivKey.GetPublicKey().GetParams(),
230				},
231				D:   validPrivKey.GetD(),
232				P:   validPrivKey.GetP(),
233				Q:   validPrivKey.GetQ(),
234				Dp:  validPrivKey.GetDp(),
235				Dq:  validPrivKey.GetDq(),
236				Crt: validPrivKey.GetCrt(),
237			},
238		},
239		{
240			name: "invalid private key D value",
241			key: &rsassapkcs1pb.RsaSsaPkcs1PrivateKey{
242				Version:   validPrivKey.GetVersion(),
243				PublicKey: validPrivKey.GetPublicKey(),
244				D:         nil,
245				P:         validPrivKey.GetP(),
246				Q:         validPrivKey.GetQ(),
247				Dp:        validPrivKey.GetDp(),
248				Dq:        validPrivKey.GetDq(),
249				Crt:       validPrivKey.GetCrt(),
250			},
251		},
252
253		{
254			name: "invalid private key P value",
255			key: &rsassapkcs1pb.RsaSsaPkcs1PrivateKey{
256				Version:   validPrivKey.GetVersion(),
257				PublicKey: validPrivKey.GetPublicKey(),
258				D:         validPrivKey.GetD(),
259				P:         nil,
260				Q:         validPrivKey.GetQ(),
261				Dp:        validPrivKey.GetDp(),
262				Dq:        validPrivKey.GetDq(),
263				Crt:       validPrivKey.GetCrt(),
264			},
265		},
266		{
267			name: "invalid private key Q value",
268			key: &rsassapkcs1pb.RsaSsaPkcs1PrivateKey{
269				Version:   validPrivKey.GetVersion(),
270				PublicKey: validPrivKey.GetPublicKey(),
271				D:         validPrivKey.GetD(),
272				P:         validPrivKey.GetP(),
273				Q:         nil,
274				Dp:        validPrivKey.GetDp(),
275				Dq:        validPrivKey.GetDq(),
276				Crt:       validPrivKey.GetCrt(),
277			},
278		},
279		{
280			name: "invalid precomputed Dp values in private key",
281			key: &rsassapkcs1pb.RsaSsaPkcs1PrivateKey{
282				Version:   validPrivKey.GetVersion(),
283				PublicKey: validPrivKey.GetPublicKey(),
284				D:         validPrivKey.GetD(),
285				P:         validPrivKey.GetP(),
286				Q:         validPrivKey.GetQ(),
287				Dp:        nil,
288				Dq:        validPrivKey.GetDq(),
289				Crt:       validPrivKey.GetCrt(),
290			},
291		},
292		{
293			name: "invalid precomputed Dq values in private key",
294			key: &rsassapkcs1pb.RsaSsaPkcs1PrivateKey{
295				Version:   validPrivKey.GetVersion(),
296				PublicKey: validPrivKey.GetPublicKey(),
297				D:         validPrivKey.GetD(),
298				P:         validPrivKey.GetP(),
299				Q:         validPrivKey.GetQ(),
300				Dp:        validPrivKey.GetDp(),
301				Dq:        nil,
302				Crt:       validPrivKey.GetCrt(),
303			},
304		},
305		{
306			name: "invalid precomputed Crt values in private key",
307			key: &rsassapkcs1pb.RsaSsaPkcs1PrivateKey{
308				Version:   validPrivKey.GetVersion(),
309				PublicKey: validPrivKey.GetPublicKey(),
310				D:         validPrivKey.GetD(),
311				P:         validPrivKey.GetP(),
312				Q:         validPrivKey.GetQ(),
313				Dp:        validPrivKey.GetDp(),
314				Dq:        validPrivKey.GetDq(),
315				Crt:       nil,
316			},
317		},
318	} {
319		t.Run(tc.name, func(t *testing.T) {
320			serializedKey, err := proto.Marshal(tc.key)
321			if err != nil {
322				t.Fatalf("proto.Marshal() err = %v, want nil", err)
323			}
324			if _, err := km.Primitive(serializedKey); err == nil {
325				t.Errorf("Primitive() err = nil, want error")
326			}
327			if _, err := km.(registry.PrivateKeyManager).PublicKeyData(serializedKey); err == nil {
328				t.Errorf("PublicKeyData() err = nil, want error")
329			}
330		})
331	}
332}
333
334func TestRSASSAPKCS1SignerKeyManagerPrimitiveWithCorruptedKeyFails(t *testing.T) {
335	km, err := registry.GetKeyManager(rsaPKCS1PrivateKeyTypeURL)
336	if err != nil {
337		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", rsaPKCS1PrivateKeyTypeURL, err)
338	}
339	corruptedPrivKey, err := makeValidRSAPKCS1Key()
340	if err != nil {
341		t.Fatalf("makeValidRSAPKCS1Key() err = %v, want nil", err)
342	}
343	corruptedPrivKey.P[5] = byte(uint8(corruptedPrivKey.P[5] + 1))
344	corruptedPrivKey.P[10] = byte(uint8(corruptedPrivKey.P[10] + 1))
345	serializedCorruptedPrivate, err := proto.Marshal(corruptedPrivKey)
346	if err != nil {
347		t.Fatalf("proto.Marshal() err = %v, want nil", err)
348	}
349	if _, err := km.Primitive(serializedCorruptedPrivate); err == nil {
350		t.Errorf("Primitive() err = nil, want error")
351	}
352}
353
354func TestRSASSAPKCS1SignerKeyManagerPrimitiveNewKey(t *testing.T) {
355	km, err := registry.GetKeyManager(rsaPKCS1PrivateKeyTypeURL)
356	if err != nil {
357		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", rsaPKCS1PrivateKeyTypeURL, err)
358	}
359	validPrivKey, err := makeValidRSAPKCS1Key()
360	if err != nil {
361		t.Fatalf("makeValidRSAPKCS1Key() err = %v, want nil", err)
362	}
363	keyFormat := &rsassapkcs1pb.RsaSsaPkcs1KeyFormat{
364		Params: &rsassapkcs1pb.RsaSsaPkcs1Params{
365			HashType: cpb.HashType_SHA256,
366		},
367		ModulusSizeInBits: 3072,
368		PublicExponent:    []byte{0x01, 0x00, 0x01},
369	}
370	serializedFormat, err := proto.Marshal(keyFormat)
371	if err != nil {
372		t.Fatalf("proto.Marshal() err = %v, want nil", err)
373	}
374	m, err := km.NewKey(serializedFormat)
375	if err != nil {
376		t.Fatalf("NewKey() err = %v, want nil", err)
377	}
378	privKey, ok := m.(*rsassapkcs1pb.RsaSsaPkcs1PrivateKey)
379	if !ok {
380		t.Fatalf("privateKey is not a RsaSsaPkcs1PrivateKey")
381	}
382	if privKey.GetVersion() != validPrivKey.GetVersion() {
383		t.Errorf("GetVersion() = %d, want %d", privKey.GetVersion(), validPrivKey.GetVersion())
384	}
385	wantPubKey := validPrivKey.GetPublicKey()
386	gotPubKey := privKey.GetPublicKey()
387	if gotPubKey.GetParams().GetHashType() != wantPubKey.GetParams().GetHashType() {
388		t.Errorf("GetHashType() = %v, want %v", gotPubKey.GetParams().GetHashType(), wantPubKey.GetParams().GetHashType())
389	}
390	if !cmp.Equal(gotPubKey.GetE(), wantPubKey.GetE()) {
391		t.Errorf("GetE() = %v, want %v", gotPubKey.GetE(), wantPubKey.GetE())
392	}
393	gotModSize := new(big.Int).SetBytes(gotPubKey.GetN()).BitLen()
394	if gotModSize != 3072 {
395		t.Errorf("Modulus Size = %d, want %d", gotModSize, 3072)
396	}
397}
398
399func TestRSASSAPKCS1SignerKeyManagerPrimitiveNewKeyWithInvalidInputFails(t *testing.T) {
400	type testCase struct {
401		name   string
402		format *rsassapkcs1pb.RsaSsaPkcs1KeyFormat
403	}
404	km, err := registry.GetKeyManager(rsaPKCS1PrivateKeyTypeURL)
405	if err != nil {
406		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", rsaPKCS1PrivateKeyTypeURL, err)
407	}
408	for _, tc := range []testCase{
409		{
410			name:   "empty format",
411			format: &rsassapkcs1pb.RsaSsaPkcs1KeyFormat{},
412		},
413		{
414			name: "invalid hash",
415			format: &rsassapkcs1pb.RsaSsaPkcs1KeyFormat{
416				ModulusSizeInBits: 2048,
417				PublicExponent:    []byte{0x01, 0x00, 0x01},
418				Params: &rsassapkcs1pb.RsaSsaPkcs1Params{
419					HashType: cpb.HashType_SHA224,
420				},
421			},
422		},
423		{
424			name: "invalid public exponent",
425			format: &rsassapkcs1pb.RsaSsaPkcs1KeyFormat{
426				ModulusSizeInBits: 2048,
427				PublicExponent:    []byte{0x01},
428				Params: &rsassapkcs1pb.RsaSsaPkcs1Params{
429					HashType: cpb.HashType_SHA256,
430				},
431			},
432		},
433		{
434			name: "invalid modulus size",
435			format: &rsassapkcs1pb.RsaSsaPkcs1KeyFormat{
436				ModulusSizeInBits: 1024,
437				PublicExponent:    []byte{0x01},
438				Params: &rsassapkcs1pb.RsaSsaPkcs1Params{
439					HashType: cpb.HashType_SHA256,
440				},
441			},
442		},
443	} {
444		t.Run(tc.name, func(t *testing.T) {
445			serializedFormat, err := proto.Marshal(tc.format)
446			if err != nil {
447				t.Fatalf("proto.Marshal() err = %v, want nil", err)
448			}
449			if _, err := km.NewKey(serializedFormat); err == nil {
450				t.Fatalf("NewKey() err = nil, want error")
451			}
452		})
453	}
454}
455
456func TestRSASSAPKCS1SignerKeyManagerPrimitiveNewKeyData(t *testing.T) {
457	km, err := registry.GetKeyManager(rsaPKCS1PrivateKeyTypeURL)
458	if err != nil {
459		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", rsaPKCS1PrivateKeyTypeURL, err)
460	}
461	keyFormat := &rsassapkcs1pb.RsaSsaPkcs1KeyFormat{
462		ModulusSizeInBits: 2048,
463		PublicExponent:    []byte{0x01, 0x00, 0x01},
464		Params: &rsassapkcs1pb.RsaSsaPkcs1Params{
465			HashType: cpb.HashType_SHA256,
466		},
467	}
468	serializedFormat, err := proto.Marshal(keyFormat)
469	if err != nil {
470		t.Fatalf("proto.Marshal() err = %v, want nil", err)
471	}
472	keyData, err := km.NewKeyData(serializedFormat)
473	if err != nil {
474		t.Fatalf("NewKeyData() err = %v, want nil", err)
475	}
476	if keyData.GetTypeUrl() != rsaPKCS1PrivateKeyTypeURL {
477		t.Errorf("GetTypeUrl() = %v, want %v", keyData.GetTypeUrl(), rsaPKCS1PrivateKeyTypeURL)
478	}
479	if keyData.GetKeyMaterialType() != tinkpb.KeyData_ASYMMETRIC_PRIVATE {
480		t.Errorf("GetKeyMaterialType() = %v, want %v", keyData.GetKeyMaterialType(), tinkpb.KeyData_ASYMMETRIC_PRIVATE)
481	}
482	if _, err := km.Primitive(keyData.GetValue()); err != nil {
483		t.Errorf("Primitive() err = %v, want nil", err)
484	}
485}
486
487func TestRSASSAPKCS1SignerKeyManagerPrimitiveNISTTestVectors(t *testing.T) {
488	km, err := registry.GetKeyManager(rsaPKCS1PrivateKeyTypeURL)
489	if err != nil {
490		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", rsaPKCS1PrivateKeyTypeURL, err)
491	}
492	for _, tc := range nistPKCS1TestVectors {
493		t.Run(tc.name, func(t *testing.T) {
494			key, err := tc.ToProtoKey()
495			if err != nil {
496				t.Fatalf("tc.ToProtoKey() err = %v, want nil", err)
497			}
498			serializedKey, err := proto.Marshal(key)
499			if err != nil {
500				t.Fatalf("proto.Marshal() err = %v, want nil", err)
501			}
502			p, err := km.Primitive(serializedKey)
503			if err != nil {
504				t.Fatalf("km.Primitive() err = %v, want nil", err)
505			}
506			msg, err := hex.DecodeString(tc.msg)
507			if err != nil {
508				t.Fatalf("hex.DecodeString(tc.msg) err = %v, want nil", err)
509			}
510			signer, ok := p.(tink.Signer)
511			if !ok {
512				t.Fatalf("primitive isn't a Tink.Signer")
513			}
514			sig, err := signer.Sign(msg)
515			if err != nil {
516				t.Fatalf("p.(tink.Signer).Sign(msg) err = %v, want nil", err)
517			}
518			gotSig := hex.EncodeToString(sig)
519			if !cmp.Equal(gotSig, tc.sig) {
520				t.Errorf("Sign() = %q, want %q", gotSig, tc.sig)
521			}
522		})
523	}
524}
525