xref: /aosp_15_r20/external/tink/go/daead/daead_factory.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package daead
18
19import (
20	"fmt"
21
22	"github.com/google/tink/go/core/cryptofmt"
23	"github.com/google/tink/go/core/primitiveset"
24	"github.com/google/tink/go/internal/internalregistry"
25	"github.com/google/tink/go/internal/monitoringutil"
26	"github.com/google/tink/go/keyset"
27	"github.com/google/tink/go/monitoring"
28	"github.com/google/tink/go/tink"
29)
30
31// New returns a DeterministicAEAD primitive from the given keyset handle.
32func New(handle *keyset.Handle) (tink.DeterministicAEAD, error) {
33	ps, err := handle.Primitives()
34	if err != nil {
35		return nil, fmt.Errorf("daead_factory: cannot obtain primitive set: %s", err)
36	}
37	return newWrappedDeterministicAEAD(ps)
38}
39
// wrappedDeterministicAEAD is a DeterministicAEAD implementation that uses an underlying primitive set
// for deterministic encryption and decryption.
//
// Encryption always uses the primary entry of ps. Decryption first tries the
// entries whose prefix matches the ciphertext, then falls back to raw entries.
type wrappedDeterministicAEAD struct {
	// ps holds the candidate primitives; newWrappedDeterministicAEAD checks
	// that each of them is a tink.DeterministicAEAD.
	ps        *primitiveset.PrimitiveSet
	// encLogger records the outcome of EncryptDeterministically calls.
	encLogger monitoring.Logger
	// decLogger records the outcome of DecryptDeterministically calls.
	decLogger monitoring.Logger
}

// Asserts that wrappedDeterministicAEAD implements the DeterministicAEAD interface.
var _ tink.DeterministicAEAD = (*wrappedDeterministicAEAD)(nil)
50
51func newWrappedDeterministicAEAD(ps *primitiveset.PrimitiveSet) (*wrappedDeterministicAEAD, error) {
52	if _, ok := (ps.Primary.Primitive).(tink.DeterministicAEAD); !ok {
53		return nil, fmt.Errorf("daead_factory: not a DeterministicAEAD primitive")
54	}
55
56	for _, primitives := range ps.Entries {
57		for _, p := range primitives {
58			if _, ok := (p.Primitive).(tink.DeterministicAEAD); !ok {
59				return nil, fmt.Errorf("daead_factory: not a DeterministicAEAD primitive")
60			}
61		}
62	}
63	encLogger, decLogger, err := createLoggers(ps)
64	if err != nil {
65		return nil, err
66	}
67	return &wrappedDeterministicAEAD{
68		ps:        ps,
69		encLogger: encLogger,
70		decLogger: decLogger,
71	}, nil
72}
73
74func createLoggers(ps *primitiveset.PrimitiveSet) (monitoring.Logger, monitoring.Logger, error) {
75	if len(ps.Annotations) == 0 {
76		return &monitoringutil.DoNothingLogger{}, &monitoringutil.DoNothingLogger{}, nil
77	}
78	client := internalregistry.GetMonitoringClient()
79	keysetInfo, err := monitoringutil.KeysetInfoFromPrimitiveSet(ps)
80	if err != nil {
81		return nil, nil, err
82	}
83	encLogger, err := client.NewLogger(&monitoring.Context{
84		Primitive:   "daead",
85		APIFunction: "encrypt",
86		KeysetInfo:  keysetInfo,
87	})
88	if err != nil {
89		return nil, nil, err
90	}
91	decLogger, err := client.NewLogger(&monitoring.Context{
92		Primitive:   "daead",
93		APIFunction: "decrypt",
94		KeysetInfo:  keysetInfo,
95	})
96	if err != nil {
97		return nil, nil, err
98	}
99	return encLogger, decLogger, nil
100}
101
102// EncryptDeterministically deterministically encrypts plaintext with additionalData as additional authenticated data.
103// It returns the concatenation of the primary's identifier and the ciphertext.
104func (d *wrappedDeterministicAEAD) EncryptDeterministically(pt, aad []byte) ([]byte, error) {
105	primary := d.ps.Primary
106	p, ok := (primary.Primitive).(tink.DeterministicAEAD)
107	if !ok {
108		return nil, fmt.Errorf("daead_factory: not a DeterministicAEAD primitive")
109	}
110
111	ct, err := p.EncryptDeterministically(pt, aad)
112	if err != nil {
113		d.encLogger.LogFailure()
114		return nil, err
115	}
116	d.encLogger.Log(primary.KeyID, len(pt))
117	if len(primary.Prefix) == 0 {
118		return ct, nil
119	}
120	output := make([]byte, 0, len(primary.Prefix)+len(ct))
121	output = append(output, primary.Prefix...)
122	output = append(output, ct...)
123	return output, nil
124}
125
126// DecryptDeterministically deterministically decrypts ciphertext with additionalData as
127// additional authenticated data. It returns the corresponding plaintext if the
128// ciphertext is authenticated.
129func (d *wrappedDeterministicAEAD) DecryptDeterministically(ct, aad []byte) ([]byte, error) {
130	// try non-raw keys
131	prefixSize := cryptofmt.NonRawPrefixSize
132	if len(ct) > prefixSize {
133		prefix := ct[:prefixSize]
134		ctNoPrefix := ct[prefixSize:]
135		entries, err := d.ps.EntriesForPrefix(string(prefix))
136		if err == nil {
137			for i := 0; i < len(entries); i++ {
138				p, ok := (entries[i].Primitive).(tink.DeterministicAEAD)
139				if !ok {
140					return nil, fmt.Errorf("daead_factory: not a DeterministicAEAD primitive")
141				}
142
143				pt, err := p.DecryptDeterministically(ctNoPrefix, aad)
144				if err == nil {
145					d.decLogger.Log(entries[i].KeyID, len(ctNoPrefix))
146					return pt, nil
147				}
148			}
149		}
150	}
151
152	// try raw keys
153	entries, err := d.ps.RawEntries()
154	if err == nil {
155		for i := 0; i < len(entries); i++ {
156			p, ok := (entries[i].Primitive).(tink.DeterministicAEAD)
157			if !ok {
158				return nil, fmt.Errorf("daead_factory: not a DeterministicAEAD primitive")
159			}
160
161			pt, err := p.DecryptDeterministically(ct, aad)
162			if err == nil {
163				d.decLogger.Log(entries[i].KeyID, len(ct))
164				return pt, nil
165			}
166		}
167	}
168	// nothing worked
169	d.decLogger.LogFailure()
170	return nil, fmt.Errorf("daead_factory: decryption failed")
171}
172