xref: /aosp_15_r20/external/tink/go/integration/awskms/internal/fakeawskms/fakeawskms.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17// Package fakeawskms provides a partial fake implementation of kmsiface.KMSAPI.
18package fakeawskms
19
20import (
21	"bytes"
22	"errors"
23	"fmt"
24	"sort"
25
26	"github.com/aws/aws-sdk-go/service/kms"
27	"github.com/aws/aws-sdk-go/service/kms/kmsiface"
28	"github.com/google/tink/go/aead"
29	"github.com/google/tink/go/keyset"
30	"github.com/google/tink/go/tink"
31)
32
33type fakeAWSKMS struct {
34	kmsiface.KMSAPI
35	aeads  map[string]tink.AEAD
36	keyIDs []string
37}
38
// serializeContext serializes an AWS KMS encryption context into a canonical
// byte string, which the fake uses as the AEAD associated data.
//
// Keys are sorted, so the output does not depend on map iteration order. Every
// key and non-nil value is written as a Go-quoted string (%q), so the ','
// and ':' separators can never be confused with content. The result has the
// form {"k1":"v1","k2":"v2"}. A nil or empty context yields {}.
//
// A nil value pointer is written as the bare token null. This avoids a
// nil-pointer panic and keeps a nil value distinct from a pointer to the
// empty string, which is written as "".
func serializeContext(context map[string]*string) []byte {
	names := make([]string, 0, len(context))
	for name := range context {
		names = append(names, name)
	}
	sort.Strings(names)

	var b bytes.Buffer
	b.WriteByte('{')
	for i, name := range names {
		if i > 0 {
			b.WriteByte(',')
		}
		fmt.Fprintf(&b, "%q:", name)
		if v := context[name]; v != nil {
			fmt.Fprintf(&b, "%q", *v)
		} else {
			b.WriteString("null")
		}
	}
	b.WriteByte('}')
	return b.Bytes()
}
57
58// New returns a new fake AWS KMS API.
59func New(validKeyIDs []string) (kmsiface.KMSAPI, error) {
60	aeads := make(map[string]tink.AEAD)
61	for _, keyID := range validKeyIDs {
62		handle, err := keyset.NewHandle(aead.AES256GCMKeyTemplate())
63		if err != nil {
64			return nil, err
65		}
66		a, err := aead.New(handle)
67		if err != nil {
68			return nil, err
69		}
70		aeads[keyID] = a
71	}
72	return &fakeAWSKMS{
73		aeads:  aeads,
74		keyIDs: validKeyIDs,
75	}, nil
76}
77
78func (f *fakeAWSKMS) Encrypt(request *kms.EncryptInput) (*kms.EncryptOutput, error) {
79	a, ok := f.aeads[*request.KeyId]
80	if !ok {
81		return nil, fmt.Errorf("Unknown keyID: %q not in %q", *request.KeyId, f.keyIDs)
82	}
83	serializedContext := serializeContext(request.EncryptionContext)
84	ciphertext, err := a.Encrypt(request.Plaintext, serializedContext)
85	if err != nil {
86		return nil, err
87	}
88	return &kms.EncryptOutput{
89		CiphertextBlob: ciphertext,
90		KeyId:          request.KeyId,
91	}, nil
92}
93
94func (f *fakeAWSKMS) Decrypt(request *kms.DecryptInput) (*kms.DecryptOutput, error) {
95	serializedContext := serializeContext(request.EncryptionContext)
96	if request.KeyId != nil {
97		a, ok := f.aeads[*request.KeyId]
98		if !ok {
99			return nil, fmt.Errorf("Unknown keyID: %q not in %q", *request.KeyId, f.keyIDs)
100		}
101		plaintext, err := a.Decrypt(request.CiphertextBlob, serializedContext)
102		if err != nil {
103			return nil, fmt.Errorf("Decryption with keyID %q failed", *request.KeyId)
104		}
105		return &kms.DecryptOutput{
106			Plaintext: plaintext,
107			KeyId:     request.KeyId,
108		}, nil
109	}
110	// When KeyId is not set, try out all AEADs.
111	for keyID, a := range f.aeads {
112		plaintext, err := a.Decrypt(request.CiphertextBlob, serializedContext)
113		if err == nil {
114			return &kms.DecryptOutput{
115				Plaintext: plaintext,
116				KeyId:     &keyID,
117			}, nil
118		}
119	}
120	return nil, errors.New("unable to decrypt message")
121}
122