xref: /aosp_15_r20/external/tink/go/aead/aead_factory.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package aead
18
19import (
20	"fmt"
21
22	"github.com/google/tink/go/core/cryptofmt"
23	"github.com/google/tink/go/core/primitiveset"
24	"github.com/google/tink/go/internal/internalregistry"
25	"github.com/google/tink/go/internal/monitoringutil"
26	"github.com/google/tink/go/keyset"
27	"github.com/google/tink/go/monitoring"
28	"github.com/google/tink/go/tink"
29)
30
31// New returns an AEAD primitive from the given keyset handle.
32func New(handle *keyset.Handle) (tink.AEAD, error) {
33	ps, err := handle.Primitives()
34	if err != nil {
35		return nil, fmt.Errorf("aead_factory: cannot obtain primitive set: %s", err)
36	}
37	return newWrappedAead(ps)
38}
39
40// wrappedAead is an AEAD implementation that uses the underlying primitive set for encryption
41// and decryption.
42type wrappedAead struct {
43	ps        *primitiveset.PrimitiveSet
44	encLogger monitoring.Logger
45	decLogger monitoring.Logger
46}
47
48func newWrappedAead(ps *primitiveset.PrimitiveSet) (*wrappedAead, error) {
49	if _, ok := (ps.Primary.Primitive).(tink.AEAD); !ok {
50		return nil, fmt.Errorf("aead_factory: not an AEAD primitive")
51	}
52
53	for _, primitives := range ps.Entries {
54		for _, p := range primitives {
55			if _, ok := (p.Primitive).(tink.AEAD); !ok {
56				return nil, fmt.Errorf("aead_factory: not an AEAD primitive")
57			}
58		}
59	}
60	encLogger, decLogger, err := createLoggers(ps)
61	if err != nil {
62		return nil, err
63	}
64	return &wrappedAead{
65		ps:        ps,
66		encLogger: encLogger,
67		decLogger: decLogger,
68	}, nil
69}
70
71func createLoggers(ps *primitiveset.PrimitiveSet) (monitoring.Logger, monitoring.Logger, error) {
72	if len(ps.Annotations) == 0 {
73		return &monitoringutil.DoNothingLogger{}, &monitoringutil.DoNothingLogger{}, nil
74	}
75	client := internalregistry.GetMonitoringClient()
76	keysetInfo, err := monitoringutil.KeysetInfoFromPrimitiveSet(ps)
77	if err != nil {
78		return nil, nil, err
79	}
80	encLogger, err := client.NewLogger(&monitoring.Context{
81		Primitive:   "aead",
82		APIFunction: "encrypt",
83		KeysetInfo:  keysetInfo,
84	})
85	if err != nil {
86		return nil, nil, err
87	}
88	decLogger, err := client.NewLogger(&monitoring.Context{
89		Primitive:   "aead",
90		APIFunction: "decrypt",
91		KeysetInfo:  keysetInfo,
92	})
93	if err != nil {
94		return nil, nil, err
95	}
96	return encLogger, decLogger, nil
97}
98
99// Encrypt encrypts the given plaintext with the given associatedData.
100// It returns the concatenation of the primary's identifier and the ciphertext.
101func (a *wrappedAead) Encrypt(plaintext, associatedData []byte) ([]byte, error) {
102	primary := a.ps.Primary
103	p, ok := (primary.Primitive).(tink.AEAD)
104	if !ok {
105		return nil, fmt.Errorf("aead_factory: not an AEAD primitive")
106	}
107	ct, err := p.Encrypt(plaintext, associatedData)
108	if err != nil {
109		a.encLogger.LogFailure()
110		return nil, err
111	}
112	a.encLogger.Log(primary.KeyID, len(plaintext))
113	if len(primary.Prefix) == 0 {
114		return ct, nil
115	}
116	output := make([]byte, 0, len(primary.Prefix)+len(ct))
117	output = append(output, primary.Prefix...)
118	output = append(output, ct...)
119	return output, nil
120}
121
122// Decrypt decrypts the given ciphertext and authenticates it with the given
123// associatedData. It returns the corresponding plaintext if the
124// ciphertext is authenticated.
125func (a *wrappedAead) Decrypt(ciphertext, associatedData []byte) ([]byte, error) {
126	// try non-raw keys
127	prefixSize := cryptofmt.NonRawPrefixSize
128	if len(ciphertext) > prefixSize {
129		prefix := ciphertext[:prefixSize]
130		ctNoPrefix := ciphertext[prefixSize:]
131		entries, err := a.ps.EntriesForPrefix(string(prefix))
132		if err == nil {
133			for i := 0; i < len(entries); i++ {
134				p, ok := (entries[i].Primitive).(tink.AEAD)
135				if !ok {
136					return nil, fmt.Errorf("aead_factory: not an AEAD primitive")
137				}
138
139				pt, err := p.Decrypt(ctNoPrefix, associatedData)
140				if err == nil {
141					a.decLogger.Log(entries[i].KeyID, len(ctNoPrefix))
142					return pt, nil
143				}
144			}
145		}
146	}
147	// try raw keys
148	entries, err := a.ps.RawEntries()
149	if err == nil {
150		for i := 0; i < len(entries); i++ {
151			p, ok := (entries[i].Primitive).(tink.AEAD)
152			if !ok {
153				return nil, fmt.Errorf("aead_factory: not an AEAD primitive")
154			}
155
156			pt, err := p.Decrypt(ciphertext, associatedData)
157			if err == nil {
158				a.decLogger.Log(entries[i].KeyID, len(ciphertext))
159				return pt, nil
160			}
161		}
162	}
163	// nothing worked
164	a.decLogger.LogFailure()
165	return nil, fmt.Errorf("aead_factory: decryption failed")
166}
167