xref: /aosp_15_r20/external/tink/go/streamingaead/decrypt_reader.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package streamingaead
18
19import (
20	"errors"
21	"io"
22
23	"github.com/google/tink/go/core/primitiveset"
24	"github.com/google/tink/go/tink"
25)
26
27var (
28	_              io.Reader = &decryptReader{}
29	errKeyNotFound           = errors.New("no matching key found for the ciphertext in the stream")
30)
31
32// decryptReader is a reader that tries to find the right key to decrypt ciphertext from the given primitive set.
33type decryptReader struct {
34	wrapped *wrappedStreamingAEAD
35	// cr is a source Reader which provides ciphertext to be decrypted.
36	cr  io.Reader
37	aad []byte
38
39	matchAttempted bool
40	// mr is a matched decrypting reader initialized with a proper key to decrypt ciphertext.
41	mr io.Reader
42}
43
44func (dr *decryptReader) Read(p []byte) (n int, err error) {
45	if dr.mr != nil {
46		return dr.mr.Read(p)
47	}
48	if dr.matchAttempted {
49		return 0, errKeyNotFound
50	}
51
52	// For legacy reasons (Tink always encrypted with non-RAW keys) we use all
53	// primitives, even those which have output_prefix_type != RAW.
54	var allEntries []*primitiveset.Entry
55	for _, entryList := range dr.wrapped.ps.Entries {
56		allEntries = append(allEntries, entryList...)
57	}
58	if err != nil {
59		return 0, err
60	}
61
62	dr.matchAttempted = true
63	ur := &unreader{r: dr.cr}
64
65	// find proper key to decrypt ciphertext
66	for _, e := range allEntries {
67		sa, ok := e.Primitive.(tink.StreamingAEAD)
68		if !ok {
69			continue
70		}
71
72		read := func() (io.Reader, int, error) {
73			r, err := sa.NewDecryptingReader(ur, dr.aad)
74			if err != nil {
75				return nil, 0, err
76			}
77			n, err := r.Read(p)
78			if err != nil {
79				return nil, 0, err
80			}
81			return r, n, nil
82		}
83
84		r, n, err := read()
85		if err == nil {
86			dr.mr = r
87			ur.disable()
88			return n, nil
89		}
90
91		ur.unread()
92	}
93	return 0, errKeyNotFound
94}
95
// unreader wraps an io.Reader and records the bytes read through it, so the
// stream can be rewound with unread and read again from the start. Once no
// further rewinds are needed, disable stops the recording and lets the buffer
// be dropped.
type unreader struct {
	r        io.Reader // source of the data
	buf      []byte    // bytes recorded from r so far
	pos      int       // offset in buf of the next byte to replay
	disabled bool      // set by disable; new reads from r are no longer recorded
}

// Read replays recorded bytes that have not yet been returned since the last
// unread. Only when none remain does it read from the wrapped reader. That new
// data is appended to the record; if recording is disabled, the record is
// discarded instead.
func (u *unreader) Read(p []byte) (int, error) {
	if pending := u.buf[u.pos:]; len(pending) > 0 {
		copied := copy(p, pending)
		u.pos += copied
		return copied, nil
	}

	n, err := u.r.Read(p)
	if !u.disabled {
		u.buf = append(u.buf, p[:n]...)
		u.pos = len(u.buf)
		return n, err
	}
	u.buf, u.pos = nil, 0
	return n, err
}

// unread rewinds the reader: subsequent Reads return the recorded bytes, if
// any, before anything new is read from the wrapped reader.
func (u *unreader) unread() { u.pos = 0 }

// disable turns recording off. Bytes already recorded are still replayed, and
// the next read from the wrapped reader drops the buffer so it can be garbage
// collected.
func (u *unreader) disable() { u.disabled = true }
134