xref: /aosp_15_r20/external/tink/go/integration/awskms/aws_kms_client_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package awskms
18
19import (
20	"bytes"
21	"encoding/hex"
22	"os"
23	"path/filepath"
24	"strings"
25	"testing"
26
27	"github.com/google/tink/go/integration/awskms/internal/fakeawskms"
28	"github.com/google/tink/go/core/registry"
29	"github.com/aws/aws-sdk-go/aws"
30	"github.com/aws/aws-sdk-go/service/kms"
31)
32
33func TestNewClientWithOptions_URIPrefix(t *testing.T) {
34	srcDir, ok := os.LookupEnv("TEST_SRCDIR")
35	if !ok {
36		t.Skip("TEST_SRCDIR not set")
37	}
38
39	// Necessary for testing deprecated factory functions.
40	credFile := filepath.Join(srcDir, "tink_go/testdata/aws/credentials.csv")
41	keyARN := "arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f"
42	fakekms, err := fakeawskms.New([]string{keyARN})
43	if err != nil {
44		t.Fatalf("fakeawskms.New() failed: %v", err)
45	}
46
47	tests := []struct {
48		name      string
49		uriPrefix string
50		valid     bool
51	}{
52		{
53			name:      "AWS partition",
54			uriPrefix: "aws-kms://arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f",
55			valid:     true,
56		},
57		{
58			name:      "AWS US government partition",
59			uriPrefix: "aws-kms://arn:aws-us-gov:kms:us-gov-east-1:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f",
60			valid:     true,
61		},
62		{
63			name:      "AWS CN partition",
64			uriPrefix: "aws-kms://arn:aws-cn:kms:cn-north-1:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f",
65			valid:     true,
66		},
67		{
68			name:      "invalid",
69			uriPrefix: "bad-prefix://arn:aws-cn:kms:cn-north-1:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f",
70			valid:     false,
71		},
72	}
73
74	for _, test := range tests {
75		t.Run(test.name, func(t *testing.T) {
76			_, err := NewClientWithOptions(test.uriPrefix)
77			if test.valid && err != nil {
78				t.Errorf("NewClientWithOptions(%q) err = %v, want nil", test.uriPrefix, err)
79			}
80			if !test.valid && err == nil {
81				t.Errorf("NewClientWithOptions(%q) err = nil, want error", test.uriPrefix)
82			}
83
84			// Test deprecated factory functions.
85			_, err = NewClient(test.uriPrefix)
86			if test.valid && err != nil {
87				t.Errorf("NewClient(%q) err = %v, want nil", test.uriPrefix, err)
88			}
89			if !test.valid && err == nil {
90				t.Errorf("NewClient(%q) err = nil, want error", test.uriPrefix)
91			}
92
93			_, err = NewClientWithCredentials(test.uriPrefix, credFile)
94			if test.valid && err != nil {
95				t.Errorf("NewClientWithCredentialPath(%q, _) err = %v, want nil", test.uriPrefix, err)
96			}
97			if !test.valid && err == nil {
98				t.Errorf("NewClientWithCredentialPath(%q, _) err = nil, want error", test.uriPrefix)
99			}
100
101			_, err = NewClientWithKMS(test.uriPrefix, fakekms)
102			if test.valid && err != nil {
103				t.Errorf("NewClientWithKMS(%q, _) err = %v, want nil", test.uriPrefix, err)
104			}
105			if !test.valid && err == nil {
106				t.Errorf("NewClientWithKMS(%q, _) err = nil, want error", test.uriPrefix)
107			}
108		})
109	}
110}
111
112func TestNewClientWithOptions_WithCredentialPath(t *testing.T) {
113	srcDir, ok := os.LookupEnv("TEST_SRCDIR")
114	if !ok {
115		t.Skip("TEST_SRCDIR not set")
116	}
117
118	uriPrefix := "aws-kms://arn:aws-us-gov:kms:us-gov-east-1:235739564943:key/"
119
120	tests := []struct {
121		name     string
122		credFile string
123		valid    bool
124	}{
125		{
126			name:     "valid CSV credentials file",
127			credFile: filepath.Join(srcDir, "tink_go/testdata/aws/credentials.csv"),
128			valid:    true,
129		},
130		{
131			name:     "valid INI credentials file",
132			credFile: filepath.Join(srcDir, "tink_go/testdata/aws/credentials.cred"),
133			valid:    true,
134		},
135		{
136			name:     "invalid credentials file",
137			credFile: filepath.Join(srcDir, "tink_go/testdata/aws/access_keys_bad.csv"),
138			valid:    false,
139		},
140	}
141
142	for _, test := range tests {
143		t.Run(test.name, func(t *testing.T) {
144			_, err := NewClientWithOptions(uriPrefix, WithCredentialPath(test.credFile))
145			if test.valid && err != nil {
146				t.Errorf("NewClientWithOptions(uriPrefix, WithCredentialPath(%q)) err = %v, want nil", test.credFile, err)
147			}
148			if !test.valid && err == nil {
149				t.Errorf("NewClientWithOptions(uriPrefix, WithCredentialPath(%q)) err = nil, want error", test.credFile)
150			}
151
152			// Test deprecated factory function.
153			_, err = NewClientWithCredentials(uriPrefix, test.credFile)
154			if test.valid && err != nil {
155				t.Errorf("NewClientWithCredentials(uriPrefix, %q) err = %v, want nil", test.credFile, err)
156			}
157			if !test.valid && err == nil {
158				t.Errorf("NewClientWithCredentials(uriPrefix, %q) err = nil, want error", test.credFile)
159			}
160
161		})
162	}
163}
164
165func TestNewClientWithOptions_RepeatedWithKMSFails(t *testing.T) {
166	keyARN := "arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f"
167	fakekms, err := fakeawskms.New([]string{keyARN})
168	if err != nil {
169		t.Fatalf("fakekms.New() failed: %v", err)
170	}
171
172	_, err = NewClientWithOptions("aws-kms://", WithKMS(fakekms), WithKMS(fakekms))
173	if err == nil {
174		t.Fatalf("NewClientWithOptions(_, WithKMS(_), WithKMS(_)) err = nil, want error")
175	}
176}
177
178func TestNewClientWithOptions_RepeatedWithEncryptionContextNameFails(t *testing.T) {
179	_, err := NewClientWithOptions("aws-kms://", WithEncryptionContextName(LegacyAdditionalData), WithEncryptionContextName(AssociatedData))
180	if err == nil {
181		t.Fatalf("NewClientWithOptions(_, WithEncryptionContextName(_), WithEncryptionContextName(_)) err = nil, want error")
182	}
183}
184
185func TestSupported(t *testing.T) {
186	uriPrefix := "aws-kms://arn:aws-us-gov:kms:us-gov-east-1:235739564943:key/"
187	supportedKeyURI := "aws-kms://arn:aws-us-gov:kms:us-gov-east-1:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f"
188	nonSupportedKeyURI := "aws-kms://arn:aws-us-gov:kms:us-gov-east-DOES-NOT-EXIST:key/"
189
190	client, err := NewClientWithOptions(uriPrefix)
191	if err != nil {
192		t.Fatalf("NewClientWithOptions() failed: %v", err)
193	}
194
195	if !client.Supported(supportedKeyURI) {
196		t.Errorf("client with URI prefix %q should support key URI %q", uriPrefix, supportedKeyURI)
197	}
198
199	if client.Supported(nonSupportedKeyURI) {
200		t.Errorf("client with URI prefix %q should NOT support key URI %q", uriPrefix, nonSupportedKeyURI)
201	}
202}
203
204func TestGetAEADSupportedURI(t *testing.T) {
205	uriPrefix := "aws-kms://arn:aws-us-gov:kms:us-gov-east-1:235739564943:key/"
206	supportedKeyURI := "aws-kms://arn:aws-us-gov:kms:us-gov-east-1:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f"
207
208	client, err := NewClientWithOptions(uriPrefix)
209	if err != nil {
210		t.Fatalf("NewClientWithOptions() failed: %v", err)
211	}
212
213	_, err = client.GetAEAD(supportedKeyURI)
214	if err != nil {
215		t.Errorf("client with URI prefix %q should support key URI %q", uriPrefix, supportedKeyURI)
216	}
217}
218
219func TestGetAEADEncryptDecrypt(t *testing.T) {
220	keyARN := "arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f"
221	keyURI := "aws-kms://arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f"
222	fakekms, err := fakeawskms.New([]string{keyARN})
223	if err != nil {
224		t.Fatalf("fakekms.New() failed: %v", err)
225	}
226
227	client, err := NewClientWithOptions("aws-kms://", WithKMS(fakekms))
228	if err != nil {
229		t.Fatalf("NewClientWithOptions() failed: %v", err)
230	}
231
232	a, err := client.GetAEAD(keyURI)
233	if err != nil {
234		t.Fatalf("client.GetAEAD(keyURI) err = %v, want nil", err)
235	}
236
237	plaintext := []byte("plaintext")
238	associatedData := []byte("associatedData")
239	ciphertext, err := a.Encrypt(plaintext, associatedData)
240	if err != nil {
241		t.Fatalf("a.Encrypt(plaintext, associatedData) err = %v, want nil", err)
242	}
243	decrypted, err := a.Decrypt(ciphertext, associatedData)
244	if err != nil {
245		t.Fatalf("a.Decrypt(ciphertext, associatedData) err = %v, want nil", err)
246	}
247	if !bytes.Equal(decrypted, plaintext) {
248		t.Errorf("decrypted = %q, want %q", decrypted, plaintext)
249	}
250
251	_, err = a.Decrypt(ciphertext, []byte("invalidAssociatedData"))
252	if err == nil {
253		t.Error("a.Decrypt(ciphertext, []byte(\"invalidAssociatedData\")) err = nil, want error")
254	}
255
256	_, err = a.Decrypt([]byte("invalidCiphertext"), associatedData)
257	if err == nil {
258		t.Error("a.Decrypt([]byte(\"invalidCiphertext\"), associatedData) err = nil, want error")
259	}
260}
261
262func TestUsesAdditionalDataAsContextName(t *testing.T) {
263	keyARN := "arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f"
264	keyURI := "aws-kms://arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f"
265	fakekms, err := fakeawskms.New([]string{keyARN})
266	if err != nil {
267		t.Fatalf("fakeawskms.New() failed: %v", err)
268	}
269
270	client, err := NewClientWithKMS("aws-kms://", fakekms)
271	if err != nil {
272		t.Fatalf("NewClientWithKMS() failed: %v", err)
273	}
274
275	a, err := client.GetAEAD(keyURI)
276	if err != nil {
277		t.Fatalf("client.GetAEAD(keyURI) failed: %s", err)
278	}
279
280	plaintext := []byte("plaintext")
281	associatedData := []byte("associatedData")
282	ciphertext, err := a.Encrypt(plaintext, associatedData)
283	if err != nil {
284		t.Fatalf("a.Encrypt(plaintext, associatedData) err = %v, want nil", err)
285	}
286
287	hexAD := hex.EncodeToString(associatedData)
288	context := map[string]*string{"additionalData": &hexAD}
289	decRequest := &kms.DecryptInput{
290		KeyId:             aws.String(keyARN),
291		CiphertextBlob:    ciphertext,
292		EncryptionContext: context,
293	}
294	decResponse, err := fakekms.Decrypt(decRequest)
295	if err != nil {
296		t.Fatalf("fakeKMS.Decrypt(decRequest) err = %s, want nil", err)
297	}
298	if !bytes.Equal(decResponse.Plaintext, plaintext) {
299		t.Errorf("decResponse.Plaintext = %q, want %q", decResponse.Plaintext, plaintext)
300	}
301	if strings.Compare(*decResponse.KeyId, keyARN) != 0 {
302		t.Errorf("decResponse.KeyId = %q, want %q", *decResponse.KeyId, keyARN)
303	}
304}
305
306func TestEncryptionContextName(t *testing.T) {
307	keyARN := "arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f"
308	keyURI := "aws-kms://arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f"
309	fakekms, err := fakeawskms.New([]string{keyARN})
310	if err != nil {
311		t.Fatalf("fakeawskms.New() failed: %v", err)
312	}
313
314	tests := []struct {
315		contextName     EncryptionContextName
316		wantContextName string
317	}{
318		{
319			contextName:     LegacyAdditionalData,
320			wantContextName: "additionalData",
321		},
322		{
323			contextName:     AssociatedData,
324			wantContextName: "associatedData",
325		},
326	}
327
328	for _, test := range tests {
329		t.Run(test.wantContextName, func(t *testing.T) {
330			client, err := NewClientWithOptions("aws-kms://", WithKMS(fakekms), WithEncryptionContextName(test.contextName))
331			if err != nil {
332				t.Fatalf("NewClientWithOptions() failed: %v", err)
333			}
334
335			a, err := client.GetAEAD(keyURI)
336			if err != nil {
337				t.Fatalf("client.GetAEAD(keyURI) failed: %s", err)
338			}
339
340			plaintext := []byte("plaintext")
341			associatedData := []byte("associatedData")
342			ciphertext, err := a.Encrypt(plaintext, associatedData)
343			if err != nil {
344				t.Fatalf("a.Encrypt(plaintext, associatedData) err = %v, want nil", err)
345			}
346
347			hexAD := hex.EncodeToString(associatedData)
348			context := map[string]*string{test.wantContextName: &hexAD}
349			decRequest := &kms.DecryptInput{
350				KeyId:             aws.String(keyARN),
351				CiphertextBlob:    ciphertext,
352				EncryptionContext: context,
353			}
354			decResponse, err := fakekms.Decrypt(decRequest)
355			if err != nil {
356				t.Fatalf("fakeKMS.Decrypt(decRequest) err = %s, want nil", err)
357			}
358			if !bytes.Equal(decResponse.Plaintext, plaintext) {
359				t.Errorf("decResponse.Plaintext = %q, want %q", decResponse.Plaintext, plaintext)
360			}
361			if strings.Compare(*decResponse.KeyId, keyARN) != 0 {
362				t.Errorf("decResponse.KeyId = %q, want %q", *decResponse.KeyId, keyARN)
363			}
364		})
365	}
366}
367
368func TestEncryptionContextName_defaultEncryptionContextName(t *testing.T) {
369	keyARN := "arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f"
370	keyURI := "aws-kms://arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f"
371	fakekms, err := fakeawskms.New([]string{keyARN})
372	if err != nil {
373		t.Fatalf("fakeawskms.New() failed: %v", err)
374	}
375
376	tests := []struct {
377		name string
378		client		func(t *testing.T) registry.KMSClient
379		wantContextName string
380	}{
381		{
382			name: "NewClientWithOptions",
383			client: func(t *testing.T) registry.KMSClient{
384				t.Helper()
385				c, err := NewClientWithOptions(keyURI, WithKMS(fakekms))
386				if err != nil {
387					t.Fatalf("NewClientWithOptions() failed: %v", err)
388				}
389				return c
390
391			},
392			wantContextName: "associatedData",
393		},
394		// Test deprecated factory function.
395		{
396			name: "NewClientWithKMS",
397			client: func(t *testing.T) registry.KMSClient{
398				t.Helper()
399				c, err := NewClientWithKMS(keyURI, fakekms)
400				if err != nil {
401					t.Fatalf("NewClientWithKMS() failed: %v", err)
402				}
403				return c
404			},
405			wantContextName: "additionalData",
406		},
407	}
408
409	for _, test := range tests {
410		t.Run(test.name, func(t *testing.T) {
411			client := test.client(t)
412			a, err := client.GetAEAD(keyURI)
413			if err != nil {
414				t.Fatalf("client.GetAEAD(keyURI) failed: %s", err)
415			}
416
417			plaintext := []byte("plaintext")
418			associatedData := []byte("associatedData")
419			ciphertext, err := a.Encrypt(plaintext, associatedData)
420			if err != nil {
421				t.Fatalf("a.Encrypt(plaintext, associatedData) err = %v, want nil", err)
422			}
423
424			hexAD := hex.EncodeToString(associatedData)
425			context := map[string]*string{test.wantContextName: &hexAD}
426			decRequest := &kms.DecryptInput{
427				KeyId:             aws.String(keyARN),
428				CiphertextBlob:    ciphertext,
429				EncryptionContext: context,
430			}
431			decResponse, err := fakekms.Decrypt(decRequest)
432			if err != nil {
433				t.Fatalf("fakeKMS.Decrypt(decRequest) err = %s, want nil", err)
434			}
435			if !bytes.Equal(decResponse.Plaintext, plaintext) {
436				t.Errorf("decResponse.Plaintext = %q, want %q", decResponse.Plaintext, plaintext)
437			}
438			if strings.Compare(*decResponse.KeyId, keyARN) != 0 {
439				t.Errorf("decResponse.KeyId = %q, want %q", *decResponse.KeyId, keyARN)
440			}
441		})
442	}
443}
444