xref: /aosp_15_r20/external/tink/go/aead/aes_gcm_key_manager_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package aead_test
18
19import (
20	"bytes"
21	"fmt"
22	"testing"
23
24	"github.com/google/go-cmp/cmp"
25	"google.golang.org/protobuf/proto"
26	"github.com/google/tink/go/aead/subtle"
27	"github.com/google/tink/go/core/registry"
28	"github.com/google/tink/go/internal/internalregistry"
29	"github.com/google/tink/go/subtle/random"
30	"github.com/google/tink/go/testutil"
31	gcmpb "github.com/google/tink/go/proto/aes_gcm_go_proto"
32	tinkpb "github.com/google/tink/go/proto/tink_go_proto"
33)
34
// keySizes lists the AES key lengths, in bytes, that the size-parameterized
// tests in this file iterate over.
var keySizes = []uint32{
	16,
	32,
}
36
37func TestAESGCMGetPrimitiveBasic(t *testing.T) {
38	keyManager, err := registry.GetKeyManager(testutil.AESGCMTypeURL)
39	if err != nil {
40		t.Errorf("cannot obtain AES-GCM key manager: %s", err)
41	}
42	for _, keySize := range keySizes {
43		key := testutil.NewAESGCMKey(testutil.AESGCMKeyVersion, keySize)
44		serializedKey, _ := proto.Marshal(key)
45		p, err := keyManager.Primitive(serializedKey)
46		if err != nil {
47			t.Errorf("unexpected error: %s", err)
48		}
49		if err := validateAESGCMPrimitive(p, key); err != nil {
50			t.Errorf("%s", err)
51		}
52	}
53}
54
55func TestAESGCMGetPrimitiveWithInvalidInput(t *testing.T) {
56	keyManager, err := registry.GetKeyManager(testutil.AESGCMTypeURL)
57	if err != nil {
58		t.Errorf("cannot obtain AES-GCM key manager: %s", err)
59	}
60	// invalid AESGCMKey
61	testKeys := genInvalidAESGCMKeys()
62	for i := 0; i < len(testKeys); i++ {
63		serializedKey, _ := proto.Marshal(testKeys[i])
64		if _, err := keyManager.Primitive(serializedKey); err == nil {
65			t.Errorf("expect an error in test case %d", i)
66		}
67	}
68	// nil
69	if _, err := keyManager.Primitive(nil); err == nil {
70		t.Errorf("expect an error when input is nil")
71	}
72	// empty array
73	if _, err := keyManager.Primitive([]byte{}); err == nil {
74		t.Errorf("expect an error when input is empty")
75	}
76}
77
78func TestAESGCMNewKeyMultipleTimes(t *testing.T) {
79	keyManager, err := registry.GetKeyManager(testutil.AESGCMTypeURL)
80	if err != nil {
81		t.Errorf("cannot obtain AES-GCM key manager: %s", err)
82	}
83	format := testutil.NewAESGCMKeyFormat(32)
84	serializedFormat, _ := proto.Marshal(format)
85	keys := make(map[string]bool)
86	nTest := 26
87	for i := 0; i < nTest; i++ {
88		key, _ := keyManager.NewKey(serializedFormat)
89		serializedKey, _ := proto.Marshal(key)
90		keys[string(serializedKey)] = true
91
92		keyData, _ := keyManager.NewKeyData(serializedFormat)
93		serializedKey = keyData.Value
94		keys[string(serializedKey)] = true
95	}
96	if len(keys) != nTest*2 {
97		t.Errorf("key is repeated")
98	}
99}
100
101func TestAESGCMNewKeyBasic(t *testing.T) {
102	keyManager, err := registry.GetKeyManager(testutil.AESGCMTypeURL)
103	if err != nil {
104		t.Errorf("cannot obtain AES-GCM key manager: %s", err)
105	}
106	for _, keySize := range keySizes {
107		format := testutil.NewAESGCMKeyFormat(keySize)
108		serializedFormat, _ := proto.Marshal(format)
109		m, err := keyManager.NewKey(serializedFormat)
110		if err != nil {
111			t.Errorf("unexpected error: %s", err)
112		}
113		key := m.(*gcmpb.AesGcmKey)
114		if err := validateAESGCMKey(key, format); err != nil {
115			t.Errorf("%s", err)
116		}
117	}
118}
119
120func TestAESGCMNewKeyWithInvalidInput(t *testing.T) {
121	keyManager, err := registry.GetKeyManager(testutil.AESGCMTypeURL)
122	if err != nil {
123		t.Errorf("cannot obtain AES-GCM key manager: %s", err)
124	}
125	// bad format
126	badFormats := genInvalidAESGCMKeyFormats()
127	for i := 0; i < len(badFormats); i++ {
128		serializedFormat, _ := proto.Marshal(badFormats[i])
129		if _, err := keyManager.NewKey(serializedFormat); err == nil {
130			t.Errorf("expect an error in test case %d", i)
131		}
132	}
133	// nil
134	if _, err := keyManager.NewKey(nil); err == nil {
135		t.Errorf("expect an error when input is nil")
136	}
137	// empty array
138	if _, err := keyManager.NewKey([]byte{}); err == nil {
139		t.Errorf("expect an error when input is empty")
140	}
141}
142
143func TestAESGCMNewKeyDataBasic(t *testing.T) {
144	keyManager, err := registry.GetKeyManager(testutil.AESGCMTypeURL)
145	if err != nil {
146		t.Errorf("cannot obtain AES-GCM key manager: %s", err)
147	}
148	for _, keySize := range keySizes {
149		format := testutil.NewAESGCMKeyFormat(keySize)
150		serializedFormat, _ := proto.Marshal(format)
151		keyData, err := keyManager.NewKeyData(serializedFormat)
152		if err != nil {
153			t.Errorf("unexpected error: %s", err)
154		}
155		if keyData.TypeUrl != testutil.AESGCMTypeURL {
156			t.Errorf("incorrect type url")
157		}
158		if keyData.KeyMaterialType != tinkpb.KeyData_SYMMETRIC {
159			t.Errorf("incorrect key material type")
160		}
161		key := new(gcmpb.AesGcmKey)
162		if err := proto.Unmarshal(keyData.Value, key); err != nil {
163			t.Errorf("incorrect key value")
164		}
165		if err := validateAESGCMKey(key, format); err != nil {
166			t.Errorf("%s", err)
167		}
168	}
169}
170
171func TestAESGCMNewKeyDataWithInvalidInput(t *testing.T) {
172	keyManager, err := registry.GetKeyManager(testutil.AESGCMTypeURL)
173	if err != nil {
174		t.Errorf("cannot obtain AES-GCM key manager: %s", err)
175	}
176	badFormats := genInvalidAESGCMKeyFormats()
177	for i := 0; i < len(badFormats); i++ {
178		serializedFormat, _ := proto.Marshal(badFormats[i])
179		if _, err := keyManager.NewKeyData(serializedFormat); err == nil {
180			t.Errorf("expect an error in test case %d", i)
181		}
182	}
183	// nil input
184	if _, err := keyManager.NewKeyData(nil); err == nil {
185		t.Errorf("expect an error when input is nil")
186	}
187	// empty input
188	if _, err := keyManager.NewKeyData([]byte{}); err == nil {
189		t.Errorf("expect an error when input is empty")
190	}
191}
192
193func TestAESGCMDoesSupport(t *testing.T) {
194	keyManager, err := registry.GetKeyManager(testutil.AESGCMTypeURL)
195	if err != nil {
196		t.Errorf("cannot obtain AES-GCM key manager: %s", err)
197	}
198	if !keyManager.DoesSupport(testutil.AESGCMTypeURL) {
199		t.Errorf("AESGCMKeyManager must support %s", testutil.AESGCMTypeURL)
200	}
201	if keyManager.DoesSupport("some bad type") {
202		t.Errorf("AESGCMKeyManager must support only %s", testutil.AESGCMTypeURL)
203	}
204}
205
206func TestAESGCMTypeURL(t *testing.T) {
207	keyManager, err := registry.GetKeyManager(testutil.AESGCMTypeURL)
208	if err != nil {
209		t.Errorf("cannot obtain AES-GCM key manager: %s", err)
210	}
211	if keyManager.TypeURL() != testutil.AESGCMTypeURL {
212		t.Errorf("incorrect key type")
213	}
214}
215
216func TestAESGCMKeyMaterialType(t *testing.T) {
217	km, err := registry.GetKeyManager(testutil.AESGCMTypeURL)
218	if err != nil {
219		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", testutil.AESGCMTypeURL, err)
220	}
221	keyManager, ok := km.(internalregistry.DerivableKeyManager)
222	if !ok {
223		t.Fatalf("key manager is not DerivableKeyManager")
224	}
225	if got, want := keyManager.KeyMaterialType(), tinkpb.KeyData_SYMMETRIC; got != want {
226		t.Errorf("KeyMaterialType() = %v, want %v", got, want)
227	}
228}
229
230func TestAESGCMDeriveKey(t *testing.T) {
231	km, err := registry.GetKeyManager(testutil.AESGCMTypeURL)
232	if err != nil {
233		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", testutil.AESGCMTypeURL, err)
234	}
235	keyManager, ok := km.(internalregistry.DerivableKeyManager)
236	if !ok {
237		t.Fatalf("key manager is not DerivableKeyManager")
238	}
239
240	for _, test := range []struct {
241		name    string
242		keySize uint32
243	}{
244		{
245			name:    "AES-128-GCM",
246			keySize: 16,
247		},
248		{
249			name:    "AES-256-GCM",
250			keySize: 32,
251		},
252	} {
253		t.Run(test.name, func(t *testing.T) {
254			keyFormat := testutil.NewAESGCMKeyFormat(test.keySize)
255			serializedKeyFormat, err := proto.Marshal(keyFormat)
256			if err != nil {
257				t.Fatalf("proto.Marshal(%v) err = %v, want nil", keyFormat, err)
258			}
259
260			rand := random.GetRandomBytes(test.keySize)
261			buf := &bytes.Buffer{}
262			buf.Write(rand) // never returns a non-nil error
263
264			k, err := keyManager.DeriveKey(serializedKeyFormat, buf)
265			if err != nil {
266				t.Fatalf("keyManager.DeriveKey() err = %v, want nil", err)
267			}
268			key := k.(*gcmpb.AesGcmKey)
269			if got, want := len(key.GetKeyValue()), int(test.keySize); got != want {
270				t.Errorf("key length = %d, want %d", got, want)
271			}
272			if diff := cmp.Diff(key.GetKeyValue(), rand); diff != "" {
273				t.Errorf("incorrect derived key: diff = %v", diff)
274			}
275		})
276	}
277}
278
279func TestAESGCMDeriveKeyFailsWithInvalidKeyFormats(t *testing.T) {
280	km, err := registry.GetKeyManager(testutil.AESGCMTypeURL)
281	if err != nil {
282		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", testutil.AESGCMTypeURL, err)
283	}
284	keyManager, ok := km.(internalregistry.DerivableKeyManager)
285	if !ok {
286		t.Fatalf("key manager is not DerivableKeyManager")
287	}
288
289	for _, test := range []struct {
290		name      string
291		keyFormat *gcmpb.AesGcmKeyFormat
292		randLen   uint32
293	}{
294		{
295			name:      "invalid key size",
296			keyFormat: &gcmpb.AesGcmKeyFormat{KeySize: 50, Version: 0},
297			randLen:   50,
298		},
299		{
300			name:      "not enough randomness",
301			keyFormat: &gcmpb.AesGcmKeyFormat{KeySize: 32, Version: 0},
302			randLen:   10,
303		},
304		{
305			name:      "invalid version",
306			keyFormat: &gcmpb.AesGcmKeyFormat{KeySize: 32, Version: 100000},
307			randLen:   32,
308		},
309		{
310			name:      "empty key format",
311			keyFormat: &gcmpb.AesGcmKeyFormat{},
312			randLen:   16,
313		},
314		{
315			name:    "nil key format",
316			randLen: 16,
317		},
318	} {
319		t.Run(test.name, func(t *testing.T) {
320			serializedKeyFormat, err := proto.Marshal(test.keyFormat)
321			if err != nil {
322				t.Fatalf("proto.Marshal(%v) err = %v, want nil", test.keyFormat, err)
323			}
324			buf := bytes.NewBuffer(random.GetRandomBytes(test.randLen))
325			if _, err := keyManager.DeriveKey(serializedKeyFormat, buf); err == nil {
326				t.Error("keyManager.DeriveKey() err = nil, want non-nil")
327			}
328		})
329	}
330}
331
332func TestAESGCMDeriveKeyFailsWithMalformedSerializedKeyFormat(t *testing.T) {
333	km, err := registry.GetKeyManager(testutil.AESGCMTypeURL)
334	if err != nil {
335		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", testutil.AESGCMTypeURL, err)
336	}
337	keyManager, ok := km.(internalregistry.DerivableKeyManager)
338	if !ok {
339		t.Fatalf("key manager is not DerivableKeyManager")
340	}
341	size := proto.Size(&gcmpb.AesGcmKeyFormat{KeySize: 16, Version: 0})
342	malformedSerializedKeyFormat := random.GetRandomBytes(uint32(size))
343	buf := bytes.NewBuffer(random.GetRandomBytes(32))
344	if _, err := keyManager.DeriveKey(malformedSerializedKeyFormat, buf); err == nil {
345		t.Error("keyManager.DeriveKey() err = nil, want non-nil")
346	}
347}
348
349func TestAESGCMDeriveKeyFailsWithInsufficientRandomness(t *testing.T) {
350	km, err := registry.GetKeyManager(testutil.AESGCMTypeURL)
351	if err != nil {
352		t.Fatalf("registry.GetKeyManager(%q) err = %v, want nil", testutil.AESGCMTypeURL, err)
353	}
354	keyManager, ok := km.(internalregistry.DerivableKeyManager)
355	if !ok {
356		t.Fatalf("key manager is not DerivableKeyManager")
357	}
358	var keySize uint32 = 16
359	keyFormat, err := proto.Marshal(testutil.NewAESGCMKeyFormat(keySize))
360	if err != nil {
361		t.Fatalf("proto.Marshal() err = %v, want nil", err)
362	}
363	{
364		buf := bytes.NewBuffer(random.GetRandomBytes(keySize))
365		if _, err := keyManager.DeriveKey(keyFormat, buf); err != nil {
366			t.Errorf("keyManager.DeriveKey() err = %v, want nil", err)
367		}
368	}
369	{
370		insufficientBuf := bytes.NewBuffer(random.GetRandomBytes(keySize - 1))
371		if _, err := keyManager.DeriveKey(keyFormat, insufficientBuf); err == nil {
372			t.Errorf("keyManager.DeriveKey() err = nil, want non-nil")
373		}
374	}
375}
376
377func genInvalidAESGCMKeys() []proto.Message {
378	return []proto.Message{
379		// not a AESGCMKey
380		testutil.NewAESGCMKeyFormat(32),
381		// bad key size
382		testutil.NewAESGCMKey(testutil.AESGCMKeyVersion, 17),
383		testutil.NewAESGCMKey(testutil.AESGCMKeyVersion, 25),
384		testutil.NewAESGCMKey(testutil.AESGCMKeyVersion, 33),
385		// bad version
386		testutil.NewAESGCMKey(testutil.AESGCMKeyVersion+1, 16),
387	}
388}
389
390func genInvalidAESGCMKeyFormats() []proto.Message {
391	return []proto.Message{
392		// not AESGCMKeyFormat
393		testutil.NewAESGCMKey(testutil.AESGCMKeyVersion, 16),
394		// invalid key size
395		testutil.NewAESGCMKeyFormat(uint32(15)),
396		testutil.NewAESGCMKeyFormat(uint32(23)),
397		testutil.NewAESGCMKeyFormat(uint32(31)),
398	}
399}
400
401func validateAESGCMKey(key *gcmpb.AesGcmKey, format *gcmpb.AesGcmKeyFormat) error {
402	if uint32(len(key.KeyValue)) != format.KeySize {
403		return fmt.Errorf("incorrect key size")
404	}
405	if key.Version != testutil.AESGCMKeyVersion {
406		return fmt.Errorf("incorrect key version")
407	}
408	// try to encrypt and decrypt
409	p, err := subtle.NewAESGCM(key.KeyValue)
410	if err != nil {
411		return fmt.Errorf("invalid key")
412	}
413	return validateAESGCMPrimitive(p, key)
414}
415
416func validateAESGCMPrimitive(p interface{}, key *gcmpb.AesGcmKey) error {
417	cipher := p.(*subtle.AESGCM)
418	if !bytes.Equal(cipher.Key(), key.KeyValue) {
419		return fmt.Errorf("key and primitive don't match")
420	}
421	// try to encrypt and decrypt
422	pt := random.GetRandomBytes(32)
423	aad := random.GetRandomBytes(32)
424	ct, err := cipher.Encrypt(pt, aad)
425	if err != nil {
426		return fmt.Errorf("encryption failed")
427	}
428	decrypted, err := cipher.Decrypt(ct, aad)
429	if err != nil {
430		return fmt.Errorf("decryption failed")
431	}
432	if !bytes.Equal(decrypted, pt) {
433		return fmt.Errorf("decryption failed")
434	}
435	return nil
436}
437