1// Copyright 2020 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15//////////////////////////////////////////////////////////////////////////////// 16 17package awskms 18 19import ( 20 "bytes" 21 "encoding/hex" 22 "os" 23 "path/filepath" 24 "strings" 25 "testing" 26 27 "github.com/google/tink/go/integration/awskms/internal/fakeawskms" 28 "github.com/google/tink/go/core/registry" 29 "github.com/aws/aws-sdk-go/aws" 30 "github.com/aws/aws-sdk-go/service/kms" 31) 32 33func TestNewClientWithOptions_URIPrefix(t *testing.T) { 34 srcDir, ok := os.LookupEnv("TEST_SRCDIR") 35 if !ok { 36 t.Skip("TEST_SRCDIR not set") 37 } 38 39 // Necessary for testing deprecated factory functions. 40 credFile := filepath.Join(srcDir, "tink_go/testdata/aws/credentials.csv") 41 keyARN := "arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f" 42 fakekms, err := fakeawskms.New([]string{keyARN}) 43 if err != nil { 44 t.Fatalf("fakeawskms.New() failed: %v", err) 45 } 46 47 tests := []struct { 48 name string 49 uriPrefix string 50 valid bool 51 }{ 52 { 53 name: "AWS partition", 54 uriPrefix: "aws-kms://arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f", 55 valid: true, 56 }, 57 { 58 name: "AWS US government partition", 59 uriPrefix: "aws-kms://arn:aws-us-gov:kms:us-gov-east-1:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f", 60 valid: true, 61 }, 62 { 63 name: "AWS CN partition", 64 uriPrefix: "aws-kms://arn:aws-cn:kms:cn-north-1:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f", 65 valid: true, 66 }, 67 { 68 name: "invalid", 69 uriPrefix: "bad-prefix://arn:aws-cn:kms:cn-north-1:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f", 70 valid: false, 71 }, 72 } 73 74 for _, test := range tests { 75 t.Run(test.name, func(t *testing.T) { 76 _, err := NewClientWithOptions(test.uriPrefix) 77 if test.valid && err != nil { 78 t.Errorf("NewClientWithOptions(%q) err = %v, want nil", test.uriPrefix, err) 79 } 80 if !test.valid && err == nil { 81 t.Errorf("NewClientWithOptions(%q) err = nil, want error", test.uriPrefix) 82 } 83 84 // Test deprecated factory functions. 85 _, err = NewClient(test.uriPrefix) 86 if test.valid && err != nil { 87 t.Errorf("NewClient(%q) err = %v, want nil", test.uriPrefix, err) 88 } 89 if !test.valid && err == nil { 90 t.Errorf("NewClient(%q) err = nil, want error", test.uriPrefix) 91 } 92 93 _, err = NewClientWithCredentials(test.uriPrefix, credFile) 94 if test.valid && err != nil { 95 t.Errorf("NewClientWithCredentialPath(%q, _) err = %v, want nil", test.uriPrefix, err) 96 } 97 if !test.valid && err == nil { 98 t.Errorf("NewClientWithCredentialPath(%q, _) err = nil, want error", test.uriPrefix) 99 } 100 101 _, err = NewClientWithKMS(test.uriPrefix, fakekms) 102 if test.valid && err != nil { 103 t.Errorf("NewClientWithKMS(%q, _) err = %v, want nil", test.uriPrefix, err) 104 } 105 if !test.valid && err == nil { 106 t.Errorf("NewClientWithKMS(%q, _) err = nil, want error", test.uriPrefix) 107 } 108 }) 109 } 110} 111 112func TestNewClientWithOptions_WithCredentialPath(t *testing.T) { 113 srcDir, ok := os.LookupEnv("TEST_SRCDIR") 114 if !ok { 115 t.Skip("TEST_SRCDIR not set") 116 } 117 118 uriPrefix := "aws-kms://arn:aws-us-gov:kms:us-gov-east-1:235739564943:key/" 119 120 tests := []struct { 121 name string 122 credFile string 123 valid bool 124 }{ 125 { 126 name: "valid CSV credentials file", 127 credFile: filepath.Join(srcDir, "tink_go/testdata/aws/credentials.csv"), 128 valid: true, 129 }, 130 { 131 name: "valid INI credentials file", 132 credFile: filepath.Join(srcDir, "tink_go/testdata/aws/credentials.cred"), 133 valid: true, 134 }, 135 { 136 name: "invalid credentials file", 137 credFile: filepath.Join(srcDir, "tink_go/testdata/aws/access_keys_bad.csv"), 138 valid: false, 139 }, 140 } 141 142 for _, test := range tests { 143 t.Run(test.name, func(t *testing.T) { 144 _, err := NewClientWithOptions(uriPrefix, WithCredentialPath(test.credFile)) 145 if test.valid && err != nil { 146 t.Errorf("NewClientWithOptions(uriPrefix, WithCredentialPath(%q)) err = %v, want nil", test.credFile, err) 147 } 148 if !test.valid && err == nil { 149 t.Errorf("NewClientWithOptions(uriPrefix, WithCredentialPath(%q)) err = nil, want error", test.credFile) 150 } 151 152 // Test deprecated factory function. 153 _, err = NewClientWithCredentials(uriPrefix, test.credFile) 154 if test.valid && err != nil { 155 t.Errorf("NewClientWithCredentials(uriPrefix, %q) err = %v, want nil", test.credFile, err) 156 } 157 if !test.valid && err == nil { 158 t.Errorf("NewClientWithCredentials(uriPrefix, %q) err = nil, want error", test.credFile) 159 } 160 161 }) 162 } 163} 164 165func TestNewClientWithOptions_RepeatedWithKMSFails(t *testing.T) { 166 keyARN := "arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f" 167 fakekms, err := fakeawskms.New([]string{keyARN}) 168 if err != nil { 169 t.Fatalf("fakekms.New() failed: %v", err) 170 } 171 172 _, err = NewClientWithOptions("aws-kms://", WithKMS(fakekms), WithKMS(fakekms)) 173 if err == nil { 174 t.Fatalf("NewClientWithOptions(_, WithKMS(_), WithKMS(_)) err = nil, want error") 175 } 176} 177 178func TestNewClientWithOptions_RepeatedWithEncryptionContextNameFails(t *testing.T) { 179 _, err := NewClientWithOptions("aws-kms://", WithEncryptionContextName(LegacyAdditionalData), WithEncryptionContextName(AssociatedData)) 180 if err == nil { 181 t.Fatalf("NewClientWithOptions(_, WithEncryptionContextName(_), WithEncryptionContextName(_)) err = nil, want error") 182 } 183} 184 185func TestSupported(t *testing.T) { 186 uriPrefix := "aws-kms://arn:aws-us-gov:kms:us-gov-east-1:235739564943:key/" 187 supportedKeyURI := "aws-kms://arn:aws-us-gov:kms:us-gov-east-1:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f" 188 nonSupportedKeyURI := "aws-kms://arn:aws-us-gov:kms:us-gov-east-DOES-NOT-EXIST:key/" 189 190 client, err := NewClientWithOptions(uriPrefix) 191 if err != nil { 192 t.Fatalf("NewClientWithOptions() failed: %v", err) 193 } 194 195 if !client.Supported(supportedKeyURI) { 196 t.Errorf("client with URI prefix %q should support key URI %q", uriPrefix, supportedKeyURI) 197 } 198 199 if client.Supported(nonSupportedKeyURI) { 200 t.Errorf("client with URI prefix %q should NOT support key URI %q", uriPrefix, nonSupportedKeyURI) 201 } 202} 203 204func TestGetAEADSupportedURI(t *testing.T) { 205 uriPrefix := "aws-kms://arn:aws-us-gov:kms:us-gov-east-1:235739564943:key/" 206 supportedKeyURI := "aws-kms://arn:aws-us-gov:kms:us-gov-east-1:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f" 207 208 client, err := NewClientWithOptions(uriPrefix) 209 if err != nil { 210 t.Fatalf("NewClientWithOptions() failed: %v", err) 211 } 212 213 _, err = client.GetAEAD(supportedKeyURI) 214 if err != nil { 215 t.Errorf("client with URI prefix %q should support key URI %q", uriPrefix, supportedKeyURI) 216 } 217} 218 219func TestGetAEADEncryptDecrypt(t *testing.T) { 220 keyARN := "arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f" 221 keyURI := "aws-kms://arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f" 222 fakekms, err := fakeawskms.New([]string{keyARN}) 223 if err != nil { 224 t.Fatalf("fakekms.New() failed: %v", err) 225 } 226 227 client, err := NewClientWithOptions("aws-kms://", WithKMS(fakekms)) 228 if err != nil { 229 t.Fatalf("NewClientWithOptions() failed: %v", err) 230 } 231 232 a, err := client.GetAEAD(keyURI) 233 if err != nil { 234 t.Fatalf("client.GetAEAD(keyURI) err = %v, want nil", err) 235 } 236 237 plaintext := []byte("plaintext") 238 associatedData := []byte("associatedData") 239 ciphertext, err := a.Encrypt(plaintext, associatedData) 240 if err != nil { 241 t.Fatalf("a.Encrypt(plaintext, associatedData) err = %v, want nil", err) 242 } 243 decrypted, err := a.Decrypt(ciphertext, associatedData) 244 if err != nil { 245 t.Fatalf("a.Decrypt(ciphertext, associatedData) err = %v, want nil", err) 246 } 247 if !bytes.Equal(decrypted, plaintext) { 248 t.Errorf("decrypted = %q, want %q", decrypted, plaintext) 249 } 250 251 _, err = a.Decrypt(ciphertext, []byte("invalidAssociatedData")) 252 if err == nil { 253 t.Error("a.Decrypt(ciphertext, []byte(\"invalidAssociatedData\")) err = nil, want error") 254 } 255 256 _, err = a.Decrypt([]byte("invalidCiphertext"), associatedData) 257 if err == nil { 258 t.Error("a.Decrypt([]byte(\"invalidCiphertext\"), associatedData) err = nil, want error") 259 } 260} 261 262func TestUsesAdditionalDataAsContextName(t *testing.T) { 263 keyARN := "arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f" 264 keyURI := "aws-kms://arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f" 265 fakekms, err := fakeawskms.New([]string{keyARN}) 266 if err != nil { 267 t.Fatalf("fakeawskms.New() failed: %v", err) 268 } 269 270 client, err := NewClientWithKMS("aws-kms://", fakekms) 271 if err != nil { 272 t.Fatalf("NewClientWithKMS() failed: %v", err) 273 } 274 275 a, err := client.GetAEAD(keyURI) 276 if err != nil { 277 t.Fatalf("client.GetAEAD(keyURI) failed: %s", err) 278 } 279 280 plaintext := []byte("plaintext") 281 associatedData := []byte("associatedData") 282 ciphertext, err := a.Encrypt(plaintext, associatedData) 283 if err != nil { 284 t.Fatalf("a.Encrypt(plaintext, associatedData) err = %v, want nil", err) 285 } 286 287 hexAD := hex.EncodeToString(associatedData) 288 context := map[string]*string{"additionalData": &hexAD} 289 decRequest := &kms.DecryptInput{ 290 KeyId: aws.String(keyARN), 291 CiphertextBlob: ciphertext, 292 EncryptionContext: context, 293 } 294 decResponse, err := fakekms.Decrypt(decRequest) 295 if err != nil { 296 t.Fatalf("fakeKMS.Decrypt(decRequest) err = %s, want nil", err) 297 } 298 if !bytes.Equal(decResponse.Plaintext, plaintext) { 299 t.Errorf("decResponse.Plaintext = %q, want %q", decResponse.Plaintext, plaintext) 300 } 301 if strings.Compare(*decResponse.KeyId, keyARN) != 0 { 302 t.Errorf("decResponse.KeyId = %q, want %q", *decResponse.KeyId, keyARN) 303 } 304} 305 306func TestEncryptionContextName(t *testing.T) { 307 keyARN := "arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f" 308 keyURI := "aws-kms://arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f" 309 fakekms, err := fakeawskms.New([]string{keyARN}) 310 if err != nil { 311 t.Fatalf("fakeawskms.New() failed: %v", err) 312 } 313 314 tests := []struct { 315 contextName EncryptionContextName 316 wantContextName string 317 }{ 318 { 319 contextName: LegacyAdditionalData, 320 wantContextName: "additionalData", 321 }, 322 { 323 contextName: AssociatedData, 324 wantContextName: "associatedData", 325 }, 326 } 327 328 for _, test := range tests { 329 t.Run(test.wantContextName, func(t *testing.T) { 330 client, err := NewClientWithOptions("aws-kms://", WithKMS(fakekms), WithEncryptionContextName(test.contextName)) 331 if err != nil { 332 t.Fatalf("NewClientWithOptions() failed: %v", err) 333 } 334 335 a, err := client.GetAEAD(keyURI) 336 if err != nil { 337 t.Fatalf("client.GetAEAD(keyURI) failed: %s", err) 338 } 339 340 plaintext := []byte("plaintext") 341 associatedData := []byte("associatedData") 342 ciphertext, err := a.Encrypt(plaintext, associatedData) 343 if err != nil { 344 t.Fatalf("a.Encrypt(plaintext, associatedData) err = %v, want nil", err) 345 } 346 347 hexAD := hex.EncodeToString(associatedData) 348 context := map[string]*string{test.wantContextName: &hexAD} 349 decRequest := &kms.DecryptInput{ 350 KeyId: aws.String(keyARN), 351 CiphertextBlob: ciphertext, 352 EncryptionContext: context, 353 } 354 decResponse, err := fakekms.Decrypt(decRequest) 355 if err != nil { 356 t.Fatalf("fakeKMS.Decrypt(decRequest) err = %s, want nil", err) 357 } 358 if !bytes.Equal(decResponse.Plaintext, plaintext) { 359 t.Errorf("decResponse.Plaintext = %q, want %q", decResponse.Plaintext, plaintext) 360 } 361 if strings.Compare(*decResponse.KeyId, keyARN) != 0 { 362 t.Errorf("decResponse.KeyId = %q, want %q", *decResponse.KeyId, keyARN) 363 } 364 }) 365 } 366} 367 368func TestEncryptionContextName_defaultEncryptionContextName(t *testing.T) { 369 keyARN := "arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f" 370 keyURI := "aws-kms://arn:aws:kms:us-east-2:235739564943:key/3ee50705-5a82-4f5b-9753-05c4f473922f" 371 fakekms, err := fakeawskms.New([]string{keyARN}) 372 if err != nil { 373 t.Fatalf("fakeawskms.New() failed: %v", err) 374 } 375 376 tests := []struct { 377 name string 378 client func(t *testing.T) registry.KMSClient 379 wantContextName string 380 }{ 381 { 382 name: "NewClientWithOptions", 383 client: func(t *testing.T) registry.KMSClient{ 384 t.Helper() 385 c, err := NewClientWithOptions(keyURI, WithKMS(fakekms)) 386 if err != nil { 387 t.Fatalf("NewClientWithOptions() failed: %v", err) 388 } 389 return c 390 391 }, 392 wantContextName: "associatedData", 393 }, 394 // Test deprecated factory function. 395 { 396 name: "NewClientWithKMS", 397 client: func(t *testing.T) registry.KMSClient{ 398 t.Helper() 399 c, err := NewClientWithKMS(keyURI, fakekms) 400 if err != nil { 401 t.Fatalf("NewClientWithKMS() failed: %v", err) 402 } 403 return c 404 }, 405 wantContextName: "additionalData", 406 }, 407 } 408 409 for _, test := range tests { 410 t.Run(test.name, func(t *testing.T) { 411 client := test.client(t) 412 a, err := client.GetAEAD(keyURI) 413 if err != nil { 414 t.Fatalf("client.GetAEAD(keyURI) failed: %s", err) 415 } 416 417 plaintext := []byte("plaintext") 418 associatedData := []byte("associatedData") 419 ciphertext, err := a.Encrypt(plaintext, associatedData) 420 if err != nil { 421 t.Fatalf("a.Encrypt(plaintext, associatedData) err = %v, want nil", err) 422 } 423 424 hexAD := hex.EncodeToString(associatedData) 425 context := map[string]*string{test.wantContextName: &hexAD} 426 decRequest := &kms.DecryptInput{ 427 KeyId: aws.String(keyARN), 428 CiphertextBlob: ciphertext, 429 EncryptionContext: context, 430 } 431 decResponse, err := fakekms.Decrypt(decRequest) 432 if err != nil { 433 t.Fatalf("fakeKMS.Decrypt(decRequest) err = %s, want nil", err) 434 } 435 if !bytes.Equal(decResponse.Plaintext, plaintext) { 436 t.Errorf("decResponse.Plaintext = %q, want %q", decResponse.Plaintext, plaintext) 437 } 438 if strings.Compare(*decResponse.KeyId, keyARN) != 0 { 439 t.Errorf("decResponse.KeyId = %q, want %q", *decResponse.KeyId, keyARN) 440 } 441 }) 442 } 443} 444