xref: /aosp_15_r20/external/tink/go/keyderivation/keyset_deriver_factory_x_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package keyderivation_test
18
19import (
20	"bytes"
21	"testing"
22
23	"google.golang.org/protobuf/proto"
24	"github.com/google/tink/go/aead"
25	"github.com/google/tink/go/keyderivation"
26	"github.com/google/tink/go/keyset"
27	"github.com/google/tink/go/prf"
28	"github.com/google/tink/go/subtle/random"
29	prfderpb "github.com/google/tink/go/proto/prf_based_deriver_go_proto"
30	tinkpb "github.com/google/tink/go/proto/tink_go_proto"
31)
32
33func TestWrappedKeysetDeriver(t *testing.T) {
34	// Construct deriving keyset handle containing one key.
35	aes128GCMKeyFormat, err := proto.Marshal(&prfderpb.PrfBasedDeriverKeyFormat{
36		PrfKeyTemplate: prf.HKDFSHA256PRFKeyTemplate(),
37		Params: &prfderpb.PrfBasedDeriverParams{
38			DerivedKeyTemplate: aead.AES128GCMKeyTemplate(),
39		},
40	})
41	if err != nil {
42		t.Fatalf("proto.Marshal(aes128GCMKeyFormat) err = %v, want nil", err)
43	}
44	singleKeyHandle, err := keyset.NewHandle(&tinkpb.KeyTemplate{
45		TypeUrl:          prfBasedDeriverTypeURL,
46		OutputPrefixType: tinkpb.OutputPrefixType_RAW,
47		Value:            aes128GCMKeyFormat,
48	})
49	if err != nil {
50		t.Fatalf("keyset.NewHandle() err = %v, want nil", err)
51	}
52
53	// Construct deriving keyset handle containing three keys.
54	xChaChaKeyFormat, err := proto.Marshal(&prfderpb.PrfBasedDeriverKeyFormat{
55		PrfKeyTemplate: prf.HKDFSHA256PRFKeyTemplate(),
56		Params: &prfderpb.PrfBasedDeriverParams{
57			DerivedKeyTemplate: aead.XChaCha20Poly1305KeyTemplate(),
58		},
59	})
60	if err != nil {
61		t.Fatalf("proto.Marshal(xChaChaKeyFormat) err = %v, want nil", err)
62	}
63	aes256GCMKeyFormat, err := proto.Marshal(&prfderpb.PrfBasedDeriverKeyFormat{
64		PrfKeyTemplate: prf.HKDFSHA256PRFKeyTemplate(),
65		Params: &prfderpb.PrfBasedDeriverParams{
66			DerivedKeyTemplate: aead.AES256GCMKeyTemplate(),
67		},
68	})
69	if err != nil {
70		t.Fatalf("proto.Marshal(aes256GCMKeyFormat) err = %v, want nil", err)
71	}
72	manager := keyset.NewManager()
73	aes128GCMKeyID, err := manager.Add(&tinkpb.KeyTemplate{
74		TypeUrl:          prfBasedDeriverTypeURL,
75		OutputPrefixType: tinkpb.OutputPrefixType_RAW,
76		Value:            aes128GCMKeyFormat,
77	})
78	if err != nil {
79		t.Fatalf("manager.Add(aes128GCMTemplate) err = %v, want nil", err)
80	}
81	if err := manager.SetPrimary(aes128GCMKeyID); err != nil {
82		t.Fatalf("manager.SetPrimary() err = %v, want nil", err)
83	}
84	if _, err := manager.Add(&tinkpb.KeyTemplate{
85		TypeUrl:          prfBasedDeriverTypeURL,
86		OutputPrefixType: tinkpb.OutputPrefixType_TINK,
87		Value:            xChaChaKeyFormat,
88	}); err != nil {
89		t.Fatalf("manager.Add(xChaChaTemplate) err = %v, want nil", err)
90	}
91	if _, err := manager.Add(&tinkpb.KeyTemplate{
92		TypeUrl:          prfBasedDeriverTypeURL,
93		OutputPrefixType: tinkpb.OutputPrefixType_CRUNCHY,
94		Value:            aes256GCMKeyFormat,
95	}); err != nil {
96		t.Fatalf("manager.Add(aes256GCMTemplate) err = %v, want nil", err)
97	}
98	multipleKeysHandle, err := manager.Handle()
99	if err != nil {
100		t.Fatalf("manager.Handle() err = %v, want nil", err)
101	}
102	if got, want := len(multipleKeysHandle.KeysetInfo().GetKeyInfo()), 3; got != want {
103		t.Fatalf("len(multipleKeysHandle) = %d, want %d", got, want)
104	}
105
106	for _, test := range []struct {
107		name         string
108		handle       *keyset.Handle
109		wantTypeURLs []string
110	}{
111		{
112			name:   "single key",
113			handle: singleKeyHandle,
114			wantTypeURLs: []string{
115				"type.googleapis.com/google.crypto.tink.AesGcmKey",
116			},
117		},
118		{
119			name:   "multiple keys",
120			handle: multipleKeysHandle,
121			wantTypeURLs: []string{
122				"type.googleapis.com/google.crypto.tink.AesGcmKey",
123				"type.googleapis.com/google.crypto.tink.XChaCha20Poly1305Key",
124				"type.googleapis.com/google.crypto.tink.AesGcmKey",
125			},
126		},
127	} {
128		t.Run(test.name, func(t *testing.T) {
129			// Derive keyset handle.
130			kd, err := keyderivation.New(test.handle)
131			if err != nil {
132				t.Fatalf("keyderivation.New() err = %v, want nil", err)
133			}
134			derivedHandle, err := kd.DeriveKeyset([]byte("salt"))
135			if err != nil {
136				t.Fatalf("DeriveKeyset() err = %v, want nil", err)
137			}
138
139			// Verify number of derived keys = number of deriving keys.
140			derivedKeyInfo := derivedHandle.KeysetInfo().GetKeyInfo()
141			keyInfo := test.handle.KeysetInfo().GetKeyInfo()
142			if len(derivedKeyInfo) != len(keyInfo) {
143				t.Errorf("number of derived keys = %d, want %d", len(derivedKeyInfo), len(keyInfo))
144			}
145			if len(derivedKeyInfo) != len(test.wantTypeURLs) {
146				t.Errorf("number of derived keys = %d, want %d", len(derivedKeyInfo), len(keyInfo))
147			}
148
149			// Verify derived keys.
150			hasPrimaryKey := false
151			for i, derivedKey := range derivedKeyInfo {
152				derivingKey := keyInfo[i]
153				if got, want := derivedKey.GetOutputPrefixType(), derivingKey.GetOutputPrefixType(); got != want {
154					t.Errorf("GetOutputPrefixType() = %s, want %s", got, want)
155				}
156				if got, want := derivedKey.GetKeyId(), derivingKey.GetKeyId(); got != want {
157					t.Errorf("GetKeyId() = %d, want %d", got, want)
158				}
159				if got, want := derivedKey.GetTypeUrl(), test.wantTypeURLs[i]; got != want {
160					t.Errorf("GetTypeUrl() = %q, want %q", got, want)
161				}
162				if got, want := derivedKey.GetStatus(), derivingKey.GetStatus(); got != want {
163					t.Errorf("GetStatus() = %s, want %s", got, want)
164				}
165				if derivedKey.GetKeyId() == derivedHandle.KeysetInfo().GetPrimaryKeyId() {
166					hasPrimaryKey = true
167				}
168			}
169			if !hasPrimaryKey {
170				t.Errorf("derived keyset has no primary key")
171			}
172
173			// Verify derived keyset handle works for AEAD.
174			pt := random.GetRandomBytes(16)
175			ad := random.GetRandomBytes(4)
176			a, err := aead.New(derivedHandle)
177			if err != nil {
178				t.Fatalf("aead.New() err = %v, want nil", err)
179			}
180			ct, err := a.Encrypt(pt, ad)
181			if err != nil {
182				t.Fatalf("Encrypt() err = %v, want nil", err)
183			}
184			gotPT, err := a.Decrypt(ct, ad)
185			if err != nil {
186				t.Fatalf("Decrypt() err = %v, want nil", err)
187			}
188			if !bytes.Equal(gotPT, pt) {
189				t.Errorf("Decrypt() = %v, want %v", gotPT, pt)
190			}
191		})
192	}
193}
194
195func TestNewRejectsNilKeysetHandle(t *testing.T) {
196	if _, err := keyderivation.New(nil); err == nil {
197		t.Error("keyderivation.New() err = nil, want non-nil")
198	}
199}
200
201func TestNewRejectsIncorrectKey(t *testing.T) {
202	kh, err := keyset.NewHandle(aead.AES128GCMKeyTemplate())
203	if err != nil {
204		t.Fatalf("keyset.NewHandle() err = %v, want nil", err)
205	}
206	if _, err := keyderivation.New(kh); err == nil {
207		t.Error("keyderivation.New() err = nil, want non-nil")
208	}
209}
210