1// Copyright 2022 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15//////////////////////////////////////////////////////////////////////////////// 16 17// Package fakeawskms provides a partial fake implementation of kmsiface.KMSAPI. 18package fakeawskms 19 20import ( 21 "bytes" 22 "errors" 23 "fmt" 24 "sort" 25 26 "github.com/aws/aws-sdk-go/service/kms" 27 "github.com/aws/aws-sdk-go/service/kms/kmsiface" 28 "github.com/google/tink/go/aead" 29 "github.com/google/tink/go/keyset" 30 "github.com/google/tink/go/tink" 31) 32 33type fakeAWSKMS struct { 34 kmsiface.KMSAPI 35 aeads map[string]tink.AEAD 36 keyIDs []string 37} 38 39// serializeContext serializes the context map in a canonical way into a byte array. 40func serializeContext(context map[string]*string) []byte { 41 names := make([]string, 0, len(context)) 42 for name := range context { 43 names = append(names, name) 44 } 45 sort.Strings(names) 46 b := new(bytes.Buffer) 47 b.WriteString("{") 48 for i, name := range names { 49 if i > 0 { 50 b.WriteString(",") 51 } 52 fmt.Fprintf(b, "%q:%q", name, *context[name]) 53 } 54 b.WriteString("}") 55 return b.Bytes() 56} 57 58// New returns a new fake AWS KMS API. 59func New(validKeyIDs []string) (kmsiface.KMSAPI, error) { 60 aeads := make(map[string]tink.AEAD) 61 for _, keyID := range validKeyIDs { 62 handle, err := keyset.NewHandle(aead.AES256GCMKeyTemplate()) 63 if err != nil { 64 return nil, err 65 } 66 a, err := aead.New(handle) 67 if err != nil { 68 return nil, err 69 } 70 aeads[keyID] = a 71 } 72 return &fakeAWSKMS{ 73 aeads: aeads, 74 keyIDs: validKeyIDs, 75 }, nil 76} 77 78func (f *fakeAWSKMS) Encrypt(request *kms.EncryptInput) (*kms.EncryptOutput, error) { 79 a, ok := f.aeads[*request.KeyId] 80 if !ok { 81 return nil, fmt.Errorf("Unknown keyID: %q not in %q", *request.KeyId, f.keyIDs) 82 } 83 serializedContext := serializeContext(request.EncryptionContext) 84 ciphertext, err := a.Encrypt(request.Plaintext, serializedContext) 85 if err != nil { 86 return nil, err 87 } 88 return &kms.EncryptOutput{ 89 CiphertextBlob: ciphertext, 90 KeyId: request.KeyId, 91 }, nil 92} 93 94func (f *fakeAWSKMS) Decrypt(request *kms.DecryptInput) (*kms.DecryptOutput, error) { 95 serializedContext := serializeContext(request.EncryptionContext) 96 if request.KeyId != nil { 97 a, ok := f.aeads[*request.KeyId] 98 if !ok { 99 return nil, fmt.Errorf("Unknown keyID: %q not in %q", *request.KeyId, f.keyIDs) 100 } 101 plaintext, err := a.Decrypt(request.CiphertextBlob, serializedContext) 102 if err != nil { 103 return nil, fmt.Errorf("Decryption with keyID %q failed", *request.KeyId) 104 } 105 return &kms.DecryptOutput{ 106 Plaintext: plaintext, 107 KeyId: request.KeyId, 108 }, nil 109 } 110 // When KeyId is not set, try out all AEADs. 111 for keyID, a := range f.aeads { 112 plaintext, err := a.Decrypt(request.CiphertextBlob, serializedContext) 113 if err == nil { 114 return &kms.DecryptOutput{ 115 Plaintext: plaintext, 116 KeyId: &keyID, 117 }, nil 118 } 119 } 120 return nil, errors.New("unable to decrypt message") 121} 122