xref: /aosp_15_r20/external/tink/go/integration/awskms/internal/fakeawskms/fakeawskms_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package fakeawskms
18
19import (
20	"bytes"
21	"strings"
22	"testing"
23
24	"github.com/aws/aws-sdk-go/aws"
25	"github.com/aws/aws-sdk-go/service/kms"
26)
27
// Key ARNs used by the tests in this file. Each test decides which of them it
// registers with New; a key not passed to New is unknown to that fake.
const (
	validKeyID  = "arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab"
	validKeyID2 = "arn:aws:kms:us-west-2:123:key/different"
)
30
31func TestEncyptDecryptWithValidKeyId(t *testing.T) {
32	fakeKMS, err := New([]string{validKeyID})
33	if err != nil {
34		t.Fatalf("New() err = %s, want nil", err)
35	}
36
37	plaintext := []byte("plaintext")
38	contextValue := "contextValue"
39	context := map[string]*string{"contextName": &contextValue}
40
41	encRequest := &kms.EncryptInput{
42		KeyId:             aws.String(validKeyID),
43		Plaintext:         plaintext,
44		EncryptionContext: context,
45	}
46
47	encResponse, err := fakeKMS.Encrypt(encRequest)
48	if err != nil {
49		t.Fatalf("fakeKMS.Encrypt(encRequest) err = %s, want nil", err)
50	}
51
52	ciphertext := encResponse.CiphertextBlob
53
54	decRequest := &kms.DecryptInput{
55		KeyId:             aws.String(validKeyID),
56		CiphertextBlob:    ciphertext,
57		EncryptionContext: context,
58	}
59	decResponse, err := fakeKMS.Decrypt(decRequest)
60	if err != nil {
61		t.Fatalf("fakeKMS.Decrypt(decRequest) err = %s, want nil", err)
62	}
63	if !bytes.Equal(decResponse.Plaintext, plaintext) {
64		t.Fatalf("decResponse.Plaintext = %q, want %q", decResponse.Plaintext, plaintext)
65	}
66	if strings.Compare(*decResponse.KeyId, validKeyID) != 0 {
67		t.Fatalf("decResponse.KeyId = %q, want %q", *decResponse.KeyId, validKeyID)
68	}
69
70	// decrypt with a different context should fail
71	otherContextValue := "otherContextValue"
72	otherContext := map[string]*string{"contextName": &otherContextValue}
73	otherDecRequest := &kms.DecryptInput{
74		KeyId:             aws.String(validKeyID),
75		CiphertextBlob:    ciphertext,
76		EncryptionContext: otherContext,
77	}
78	if _, err := fakeKMS.Decrypt(otherDecRequest); err == nil {
79		t.Fatal("fakeKMS.Decrypt(otherDecRequest) err = nil, want not nil")
80	}
81}
82
83func TestEncyptWithUnknownKeyID(t *testing.T) {
84	fakeKMS, err := New([]string{validKeyID})
85	if err != nil {
86		t.Fatalf("New() err = %s, want nil", err)
87	}
88
89	plaintext := []byte("plaintext")
90	contextValue := "contextValue"
91	context := map[string]*string{"contextName": &contextValue}
92
93	encRequestWithUnknownKeyID := &kms.EncryptInput{
94		KeyId:             aws.String(validKeyID2),
95		Plaintext:         plaintext,
96		EncryptionContext: context,
97	}
98
99	if _, err := fakeKMS.Encrypt(encRequestWithUnknownKeyID); err == nil {
100		t.Fatal("fakeKMS.Encrypt(encRequestWithvalidKeyID2) err = nil, want not nil")
101	}
102}
103
104func TestDecryptWithInvalidCiphertext(t *testing.T) {
105	fakeKMS, err := New([]string{validKeyID})
106	if err != nil {
107		t.Fatalf("New() err = %s, want nil", err)
108	}
109
110	invalidCiphertext := []byte("plaintext")
111	contextValue := "contextValue"
112	context := map[string]*string{"contextName": &contextValue}
113
114	decRequest := &kms.DecryptInput{
115		CiphertextBlob:    invalidCiphertext,
116		EncryptionContext: context,
117	}
118
119	if _, err := fakeKMS.Decrypt(decRequest); err == nil {
120		t.Fatal("fakeKMS.Decrypt(decRequest) err = nil, want not nil")
121	}
122}
123
124func TestDecryptWithUnknownKeyId(t *testing.T) {
125	fakeKMS, err := New([]string{validKeyID})
126	if err != nil {
127		t.Fatalf("New() err = %s, want nil", err)
128	}
129
130	ciphertext := []byte("invalidCiphertext")
131	contextValue := "contextValue"
132	context := map[string]*string{"contextName": &contextValue}
133
134	decRequest := &kms.DecryptInput{
135		KeyId:             aws.String(validKeyID2),
136		CiphertextBlob:    ciphertext,
137		EncryptionContext: context,
138	}
139
140	if _, err := fakeKMS.Decrypt(decRequest); err == nil {
141		t.Fatal("fakeKMS.Decrypt(decRequest) err = nil, want not nil")
142	}
143}
144
145func TestDecryptWithWrongKeyId(t *testing.T) {
146	fakeKMS, err := New([]string{validKeyID, validKeyID2})
147	if err != nil {
148		t.Fatalf("New() err = %s, want nil", err)
149	}
150
151	plaintext := []byte("plaintext")
152	contextValue := "contextValue"
153	context := map[string]*string{"contextName": &contextValue}
154
155	encRequest := &kms.EncryptInput{
156		KeyId:             aws.String(validKeyID),
157		Plaintext:         plaintext,
158		EncryptionContext: context,
159	}
160
161	encResponse, err := fakeKMS.Encrypt(encRequest)
162	if err != nil {
163		t.Fatalf("fakeKMS.Encrypt(encRequest) err = %s, want nil", err)
164	}
165
166	ciphertext := encResponse.CiphertextBlob
167
168	decRequest := &kms.DecryptInput{
169		KeyId:             aws.String(validKeyID2), // wrong key id
170		CiphertextBlob:    ciphertext,
171		EncryptionContext: context,
172	}
173	if _, err := fakeKMS.Decrypt(decRequest); err == nil {
174		t.Fatal("fakeKMS.Decrypt(decRequest) err = nil, want not nil")
175	}
176}
177
178func TestDecryptWithoutKeyId(t *testing.T) {
179	// setting the keyId in DecryptInput is not required, see
180	// https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/kms#DecryptInput
181
182	fakeKMS, err := New([]string{validKeyID, validKeyID2})
183	if err != nil {
184		t.Fatalf("New() err = %s, want nil", err)
185	}
186
187	plaintext := []byte("plaintext")
188	plaintext2 := []byte("plaintext2")
189	contextValue := "contextValue"
190	context := map[string]*string{"contextName": &contextValue}
191
192	encRequest := &kms.EncryptInput{
193		KeyId:             aws.String(validKeyID),
194		Plaintext:         plaintext,
195		EncryptionContext: context,
196	}
197	encResponse, err := fakeKMS.Encrypt(encRequest)
198	if err != nil {
199		t.Fatalf("fakeKMS.Encrypt(encRequest) err = %s, want nil", err)
200	}
201	if strings.Compare(*encResponse.KeyId, validKeyID) != 0 {
202		t.Fatalf("encResponse.KeyId = %q, want %q", *encResponse.KeyId, validKeyID)
203	}
204
205	encRequest2 := &kms.EncryptInput{
206		KeyId:             aws.String(validKeyID2),
207		Plaintext:         plaintext2,
208		EncryptionContext: context,
209	}
210	encResponse2, err := fakeKMS.Encrypt(encRequest2)
211	if err != nil {
212		t.Fatalf("fakeKMS.Encrypt(encRequest2) err = %s, want nil", err)
213	}
214	if strings.Compare(*encResponse2.KeyId, validKeyID2) != 0 {
215		t.Fatalf("encResponse2.KeyId = %q, want %q", *encResponse2.KeyId, validKeyID2)
216	}
217
218	decRequest := &kms.DecryptInput{
219		// KeyId is not set
220		CiphertextBlob:    encResponse.CiphertextBlob,
221		EncryptionContext: context,
222	}
223	decResponse, err := fakeKMS.Decrypt(decRequest)
224	if err != nil {
225		t.Fatalf("fakeKMS.Decrypt(decRequest) err = %s, want nil", err)
226	}
227	if !bytes.Equal(decResponse.Plaintext, plaintext) {
228		t.Fatalf("decResponse.Plaintext = %q, want %q", decResponse.Plaintext, plaintext)
229	}
230	if strings.Compare(*decResponse.KeyId, validKeyID) != 0 {
231		t.Fatalf("decResponse.KeyId = %q, want %q", *decResponse.KeyId, validKeyID)
232	}
233
234	decRequest2 := &kms.DecryptInput{
235		// KeyId is not set
236		CiphertextBlob:    encResponse2.CiphertextBlob,
237		EncryptionContext: context,
238	}
239	decResponse2, err := fakeKMS.Decrypt(decRequest2)
240	if err != nil {
241		t.Fatalf("fakeKMS.Decrypt(decRequest2) err = %s, want nil", err)
242	}
243	if !bytes.Equal(decResponse2.Plaintext, plaintext2) {
244		t.Fatalf("decResponse.Plaintext = %q, want %q", decResponse.Plaintext, plaintext2)
245	}
246	if strings.Compare(*decResponse2.KeyId, validKeyID2) != 0 {
247		t.Fatalf("decResponse2.KeyId = %q, want %q", *decResponse2.KeyId, validKeyID2)
248	}
249}
250
251func TestSerializeContext(t *testing.T) {
252	uvw := "uvw"
253	xyz := "xyz"
254	rst := "rst"
255	context := map[string]*string{"def": &uvw, "abc": &xyz, "ghi": &rst}
256
257	got := string(serializeContext(context))
258	want := "{\"abc\":\"xyz\",\"def\":\"uvw\",\"ghi\":\"rst\"}"
259	if got != want {
260		t.Fatalf("SerializeContext(context) = %s, want %s", got, want)
261	}
262
263	gotEscaped := string(serializeContext(map[string]*string{"a\"b": &xyz}))
264	wantEscaped := "{\"a\\\"b\":\"xyz\"}"
265	if gotEscaped != wantEscaped {
266		t.Fatalf("SerializeContext(context) = %s, want %s", gotEscaped, wantEscaped)
267	}
268
269	gotEmpty := string(serializeContext(map[string]*string{}))
270	if gotEmpty != "{}" {
271		t.Fatalf("SerializeContext(context) = %s, want %s", gotEmpty, "{}")
272	}
273}
274