1// Copyright 2022 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15//////////////////////////////////////////////////////////////////////////////// 16 17package fakeawskms 18 19import ( 20 "bytes" 21 "strings" 22 "testing" 23 24 "github.com/aws/aws-sdk-go/aws" 25 "github.com/aws/aws-sdk-go/service/kms" 26) 27 28const validKeyID = "arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab" 29const validKeyID2 = "arn:aws:kms:us-west-2:123:key/different" 30 31func TestEncyptDecryptWithValidKeyId(t *testing.T) { 32 fakeKMS, err := New([]string{validKeyID}) 33 if err != nil { 34 t.Fatalf("New() err = %s, want nil", err) 35 } 36 37 plaintext := []byte("plaintext") 38 contextValue := "contextValue" 39 context := map[string]*string{"contextName": &contextValue} 40 41 encRequest := &kms.EncryptInput{ 42 KeyId: aws.String(validKeyID), 43 Plaintext: plaintext, 44 EncryptionContext: context, 45 } 46 47 encResponse, err := fakeKMS.Encrypt(encRequest) 48 if err != nil { 49 t.Fatalf("fakeKMS.Encrypt(encRequest) err = %s, want nil", err) 50 } 51 52 ciphertext := encResponse.CiphertextBlob 53 54 decRequest := &kms.DecryptInput{ 55 KeyId: aws.String(validKeyID), 56 CiphertextBlob: ciphertext, 57 EncryptionContext: context, 58 } 59 decResponse, err := fakeKMS.Decrypt(decRequest) 60 if err != nil { 61 t.Fatalf("fakeKMS.Decrypt(decRequest) err = %s, want nil", err) 62 } 63 if !bytes.Equal(decResponse.Plaintext, plaintext) { 64 t.Fatalf("decResponse.Plaintext = %q, want %q", decResponse.Plaintext, plaintext) 65 } 66 if strings.Compare(*decResponse.KeyId, validKeyID) != 0 { 67 t.Fatalf("decResponse.KeyId = %q, want %q", *decResponse.KeyId, validKeyID) 68 } 69 70 // decrypt with a different context should fail 71 otherContextValue := "otherContextValue" 72 otherContext := map[string]*string{"contextName": &otherContextValue} 73 otherDecRequest := &kms.DecryptInput{ 74 KeyId: aws.String(validKeyID), 75 CiphertextBlob: ciphertext, 76 EncryptionContext: otherContext, 77 } 78 if _, err := fakeKMS.Decrypt(otherDecRequest); err == nil { 79 t.Fatal("fakeKMS.Decrypt(otherDecRequest) err = nil, want not nil") 80 } 81} 82 83func TestEncyptWithUnknownKeyID(t *testing.T) { 84 fakeKMS, err := New([]string{validKeyID}) 85 if err != nil { 86 t.Fatalf("New() err = %s, want nil", err) 87 } 88 89 plaintext := []byte("plaintext") 90 contextValue := "contextValue" 91 context := map[string]*string{"contextName": &contextValue} 92 93 encRequestWithUnknownKeyID := &kms.EncryptInput{ 94 KeyId: aws.String(validKeyID2), 95 Plaintext: plaintext, 96 EncryptionContext: context, 97 } 98 99 if _, err := fakeKMS.Encrypt(encRequestWithUnknownKeyID); err == nil { 100 t.Fatal("fakeKMS.Encrypt(encRequestWithvalidKeyID2) err = nil, want not nil") 101 } 102} 103 104func TestDecryptWithInvalidCiphertext(t *testing.T) { 105 fakeKMS, err := New([]string{validKeyID}) 106 if err != nil { 107 t.Fatalf("New() err = %s, want nil", err) 108 } 109 110 invalidCiphertext := []byte("plaintext") 111 contextValue := "contextValue" 112 context := map[string]*string{"contextName": &contextValue} 113 114 decRequest := &kms.DecryptInput{ 115 CiphertextBlob: invalidCiphertext, 116 EncryptionContext: context, 117 } 118 119 if _, err := fakeKMS.Decrypt(decRequest); err == nil { 120 t.Fatal("fakeKMS.Decrypt(decRequest) err = nil, want not nil") 121 } 122} 123 124func TestDecryptWithUnknownKeyId(t *testing.T) { 125 fakeKMS, err := New([]string{validKeyID}) 126 if err != nil { 127 t.Fatalf("New() err = %s, want nil", err) 128 } 129 130 ciphertext := []byte("invalidCiphertext") 131 contextValue := "contextValue" 132 context := map[string]*string{"contextName": &contextValue} 133 134 decRequest := &kms.DecryptInput{ 135 KeyId: aws.String(validKeyID2), 136 CiphertextBlob: ciphertext, 137 EncryptionContext: context, 138 } 139 140 if _, err := fakeKMS.Decrypt(decRequest); err == nil { 141 t.Fatal("fakeKMS.Decrypt(decRequest) err = nil, want not nil") 142 } 143} 144 145func TestDecryptWithWrongKeyId(t *testing.T) { 146 fakeKMS, err := New([]string{validKeyID, validKeyID2}) 147 if err != nil { 148 t.Fatalf("New() err = %s, want nil", err) 149 } 150 151 plaintext := []byte("plaintext") 152 contextValue := "contextValue" 153 context := map[string]*string{"contextName": &contextValue} 154 155 encRequest := &kms.EncryptInput{ 156 KeyId: aws.String(validKeyID), 157 Plaintext: plaintext, 158 EncryptionContext: context, 159 } 160 161 encResponse, err := fakeKMS.Encrypt(encRequest) 162 if err != nil { 163 t.Fatalf("fakeKMS.Encrypt(encRequest) err = %s, want nil", err) 164 } 165 166 ciphertext := encResponse.CiphertextBlob 167 168 decRequest := &kms.DecryptInput{ 169 KeyId: aws.String(validKeyID2), // wrong key id 170 CiphertextBlob: ciphertext, 171 EncryptionContext: context, 172 } 173 if _, err := fakeKMS.Decrypt(decRequest); err == nil { 174 t.Fatal("fakeKMS.Decrypt(decRequest) err = nil, want not nil") 175 } 176} 177 178func TestDecryptWithoutKeyId(t *testing.T) { 179 // setting the keyId in DecryptInput is not required, see 180 // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/kms#DecryptInput 181 182 fakeKMS, err := New([]string{validKeyID, validKeyID2}) 183 if err != nil { 184 t.Fatalf("New() err = %s, want nil", err) 185 } 186 187 plaintext := []byte("plaintext") 188 plaintext2 := []byte("plaintext2") 189 contextValue := "contextValue" 190 context := map[string]*string{"contextName": &contextValue} 191 192 encRequest := &kms.EncryptInput{ 193 KeyId: aws.String(validKeyID), 194 Plaintext: plaintext, 195 EncryptionContext: context, 196 } 197 encResponse, err := fakeKMS.Encrypt(encRequest) 198 if err != nil { 199 t.Fatalf("fakeKMS.Encrypt(encRequest) err = %s, want nil", err) 200 } 201 if strings.Compare(*encResponse.KeyId, validKeyID) != 0 { 202 t.Fatalf("encResponse.KeyId = %q, want %q", *encResponse.KeyId, validKeyID) 203 } 204 205 encRequest2 := &kms.EncryptInput{ 206 KeyId: aws.String(validKeyID2), 207 Plaintext: plaintext2, 208 EncryptionContext: context, 209 } 210 encResponse2, err := fakeKMS.Encrypt(encRequest2) 211 if err != nil { 212 t.Fatalf("fakeKMS.Encrypt(encRequest2) err = %s, want nil", err) 213 } 214 if strings.Compare(*encResponse2.KeyId, validKeyID2) != 0 { 215 t.Fatalf("encResponse2.KeyId = %q, want %q", *encResponse2.KeyId, validKeyID2) 216 } 217 218 decRequest := &kms.DecryptInput{ 219 // KeyId is not set 220 CiphertextBlob: encResponse.CiphertextBlob, 221 EncryptionContext: context, 222 } 223 decResponse, err := fakeKMS.Decrypt(decRequest) 224 if err != nil { 225 t.Fatalf("fakeKMS.Decrypt(decRequest) err = %s, want nil", err) 226 } 227 if !bytes.Equal(decResponse.Plaintext, plaintext) { 228 t.Fatalf("decResponse.Plaintext = %q, want %q", decResponse.Plaintext, plaintext) 229 } 230 if strings.Compare(*decResponse.KeyId, validKeyID) != 0 { 231 t.Fatalf("decResponse.KeyId = %q, want %q", *decResponse.KeyId, validKeyID) 232 } 233 234 decRequest2 := &kms.DecryptInput{ 235 // KeyId is not set 236 CiphertextBlob: encResponse2.CiphertextBlob, 237 EncryptionContext: context, 238 } 239 decResponse2, err := fakeKMS.Decrypt(decRequest2) 240 if err != nil { 241 t.Fatalf("fakeKMS.Decrypt(decRequest2) err = %s, want nil", err) 242 } 243 if !bytes.Equal(decResponse2.Plaintext, plaintext2) { 244 t.Fatalf("decResponse.Plaintext = %q, want %q", decResponse.Plaintext, plaintext2) 245 } 246 if strings.Compare(*decResponse2.KeyId, validKeyID2) != 0 { 247 t.Fatalf("decResponse2.KeyId = %q, want %q", *decResponse2.KeyId, validKeyID2) 248 } 249} 250 251func TestSerializeContext(t *testing.T) { 252 uvw := "uvw" 253 xyz := "xyz" 254 rst := "rst" 255 context := map[string]*string{"def": &uvw, "abc": &xyz, "ghi": &rst} 256 257 got := string(serializeContext(context)) 258 want := "{\"abc\":\"xyz\",\"def\":\"uvw\",\"ghi\":\"rst\"}" 259 if got != want { 260 t.Fatalf("SerializeContext(context) = %s, want %s", got, want) 261 } 262 263 gotEscaped := string(serializeContext(map[string]*string{"a\"b": &xyz})) 264 wantEscaped := "{\"a\\\"b\":\"xyz\"}" 265 if gotEscaped != wantEscaped { 266 t.Fatalf("SerializeContext(context) = %s, want %s", gotEscaped, wantEscaped) 267 } 268 269 gotEmpty := string(serializeContext(map[string]*string{})) 270 if gotEmpty != "{}" { 271 t.Fatalf("SerializeContext(context) = %s, want %s", gotEmpty, "{}") 272 } 273} 274