xref: /aosp_15_r20/external/tink/go/streamingaead/decrypt_reader.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1*e7b1675dSTing-Kang Chang// Copyright 2020 Google LLC
2*e7b1675dSTing-Kang Chang//
3*e7b1675dSTing-Kang Chang// Licensed under the Apache License, Version 2.0 (the "License");
4*e7b1675dSTing-Kang Chang// you may not use this file except in compliance with the License.
5*e7b1675dSTing-Kang Chang// You may obtain a copy of the License at
6*e7b1675dSTing-Kang Chang//
7*e7b1675dSTing-Kang Chang//      http://www.apache.org/licenses/LICENSE-2.0
8*e7b1675dSTing-Kang Chang//
9*e7b1675dSTing-Kang Chang// Unless required by applicable law or agreed to in writing, software
10*e7b1675dSTing-Kang Chang// distributed under the License is distributed on an "AS IS" BASIS,
11*e7b1675dSTing-Kang Chang// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*e7b1675dSTing-Kang Chang// See the License for the specific language governing permissions and
13*e7b1675dSTing-Kang Chang// limitations under the License.
14*e7b1675dSTing-Kang Chang//
15*e7b1675dSTing-Kang Chang////////////////////////////////////////////////////////////////////////////////
16*e7b1675dSTing-Kang Chang
17*e7b1675dSTing-Kang Changpackage streamingaead
18*e7b1675dSTing-Kang Chang
19*e7b1675dSTing-Kang Changimport (
20*e7b1675dSTing-Kang Chang	"errors"
21*e7b1675dSTing-Kang Chang	"io"
22*e7b1675dSTing-Kang Chang
23*e7b1675dSTing-Kang Chang	"github.com/google/tink/go/core/primitiveset"
24*e7b1675dSTing-Kang Chang	"github.com/google/tink/go/tink"
25*e7b1675dSTing-Kang Chang)
26*e7b1675dSTing-Kang Chang
27*e7b1675dSTing-Kang Changvar (
28*e7b1675dSTing-Kang Chang	_              io.Reader = &decryptReader{}
29*e7b1675dSTing-Kang Chang	errKeyNotFound           = errors.New("no matching key found for the ciphertext in the stream")
30*e7b1675dSTing-Kang Chang)
31*e7b1675dSTing-Kang Chang
32*e7b1675dSTing-Kang Chang// decryptReader is a reader that tries to find the right key to decrypt ciphertext from the given primitive set.
33*e7b1675dSTing-Kang Changtype decryptReader struct {
34*e7b1675dSTing-Kang Chang	wrapped *wrappedStreamingAEAD
35*e7b1675dSTing-Kang Chang	// cr is a source Reader which provides ciphertext to be decrypted.
36*e7b1675dSTing-Kang Chang	cr  io.Reader
37*e7b1675dSTing-Kang Chang	aad []byte
38*e7b1675dSTing-Kang Chang
39*e7b1675dSTing-Kang Chang	matchAttempted bool
40*e7b1675dSTing-Kang Chang	// mr is a matched decrypting reader initialized with a proper key to decrypt ciphertext.
41*e7b1675dSTing-Kang Chang	mr io.Reader
42*e7b1675dSTing-Kang Chang}
43*e7b1675dSTing-Kang Chang
44*e7b1675dSTing-Kang Changfunc (dr *decryptReader) Read(p []byte) (n int, err error) {
45*e7b1675dSTing-Kang Chang	if dr.mr != nil {
46*e7b1675dSTing-Kang Chang		return dr.mr.Read(p)
47*e7b1675dSTing-Kang Chang	}
48*e7b1675dSTing-Kang Chang	if dr.matchAttempted {
49*e7b1675dSTing-Kang Chang		return 0, errKeyNotFound
50*e7b1675dSTing-Kang Chang	}
51*e7b1675dSTing-Kang Chang
52*e7b1675dSTing-Kang Chang	// For legacy reasons (Tink always encrypted with non-RAW keys) we use all
53*e7b1675dSTing-Kang Chang	// primitives, even those which have output_prefix_type != RAW.
54*e7b1675dSTing-Kang Chang	var allEntries []*primitiveset.Entry
55*e7b1675dSTing-Kang Chang	for _, entryList := range dr.wrapped.ps.Entries {
56*e7b1675dSTing-Kang Chang		allEntries = append(allEntries, entryList...)
57*e7b1675dSTing-Kang Chang	}
58*e7b1675dSTing-Kang Chang	if err != nil {
59*e7b1675dSTing-Kang Chang		return 0, err
60*e7b1675dSTing-Kang Chang	}
61*e7b1675dSTing-Kang Chang
62*e7b1675dSTing-Kang Chang	dr.matchAttempted = true
63*e7b1675dSTing-Kang Chang	ur := &unreader{r: dr.cr}
64*e7b1675dSTing-Kang Chang
65*e7b1675dSTing-Kang Chang	// find proper key to decrypt ciphertext
66*e7b1675dSTing-Kang Chang	for _, e := range allEntries {
67*e7b1675dSTing-Kang Chang		sa, ok := e.Primitive.(tink.StreamingAEAD)
68*e7b1675dSTing-Kang Chang		if !ok {
69*e7b1675dSTing-Kang Chang			continue
70*e7b1675dSTing-Kang Chang		}
71*e7b1675dSTing-Kang Chang
72*e7b1675dSTing-Kang Chang		read := func() (io.Reader, int, error) {
73*e7b1675dSTing-Kang Chang			r, err := sa.NewDecryptingReader(ur, dr.aad)
74*e7b1675dSTing-Kang Chang			if err != nil {
75*e7b1675dSTing-Kang Chang				return nil, 0, err
76*e7b1675dSTing-Kang Chang			}
77*e7b1675dSTing-Kang Chang			n, err := r.Read(p)
78*e7b1675dSTing-Kang Chang			if err != nil {
79*e7b1675dSTing-Kang Chang				return nil, 0, err
80*e7b1675dSTing-Kang Chang			}
81*e7b1675dSTing-Kang Chang			return r, n, nil
82*e7b1675dSTing-Kang Chang		}
83*e7b1675dSTing-Kang Chang
84*e7b1675dSTing-Kang Chang		r, n, err := read()
85*e7b1675dSTing-Kang Chang		if err == nil {
86*e7b1675dSTing-Kang Chang			dr.mr = r
87*e7b1675dSTing-Kang Chang			ur.disable()
88*e7b1675dSTing-Kang Chang			return n, nil
89*e7b1675dSTing-Kang Chang		}
90*e7b1675dSTing-Kang Chang
91*e7b1675dSTing-Kang Chang		ur.unread()
92*e7b1675dSTing-Kang Chang	}
93*e7b1675dSTing-Kang Chang	return 0, errKeyNotFound
94*e7b1675dSTing-Kang Chang}
95*e7b1675dSTing-Kang Chang
// unreader wraps a reader and keeps a copy of everything that's read so it can
// be unread and read again. Once no further unreads are needed, disable stops
// the copying; the buffer is then released for garbage collection as soon as
// any bytes still pending replay have been read back.
type unreader struct {
	// r is the wrapped source reader.
	r io.Reader
	// buf holds the bytes read from r that may still need to be replayed.
	buf []byte
	// pos is the offset in buf of the next byte Read returns; pos < len(buf)
	// means replayed data is still pending.
	pos int
	// disabled stops new reads from being recorded in buf; set by disable.
	disabled bool
}

// Read returns pending replayed bytes from the buffer first. Only once those
// are exhausted does it read from the wrapped reader. While the unreader is
// still enabled, freshly read bytes are appended to the buffer so a later
// unread can replay them.
func (u *unreader) Read(buf []byte) (int, error) {
	if u.pos < len(u.buf) {
		n := copy(buf, u.buf[u.pos:])
		u.pos += n
		if u.disabled && u.pos == len(u.buf) {
			// Replay is finished and no more unreads will follow: drop the copy now
			// instead of waiting for the next read from the wrapped reader.
			u.buf = nil
			u.pos = 0
		}
		return n, nil
	}
	n, err := u.r.Read(buf)
	if u.disabled {
		u.buf = nil
		u.pos = 0
	} else {
		u.buf = append(u.buf, buf[:n]...)
		u.pos = len(u.buf)
	}
	return n, err
}

// unread starts the reader over again. A copy of all read data will be returned
// by Read before the wrapped reader is read from again. It must not be called
// after disable, because by then the buffer may already have been released.
func (u *unreader) unread() {
	u.pos = 0
}

// disable ends the unread phase. Subsequent reads from the wrapped reader are
// no longer copied. The buffer is released immediately if nothing is left to
// replay; otherwise it is released as soon as the pending bytes are read back.
func (u *unreader) disable() {
	u.disabled = true
	if u.pos == len(u.buf) {
		u.buf = nil
		u.pos = 0
	}
}
134