xref: /aosp_15_r20/external/tink/go/streamingaead/decrypt_reader_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
16*e7b1675dSTing-Kang Chang
17*e7b1675dSTing-Kang Changpackage streamingaead
18*e7b1675dSTing-Kang Chang
19*e7b1675dSTing-Kang Changimport (
20*e7b1675dSTing-Kang Chang	"bytes"
21*e7b1675dSTing-Kang Chang	"crypto/rand"
22*e7b1675dSTing-Kang Chang	"fmt"
23*e7b1675dSTing-Kang Chang	"io"
24*e7b1675dSTing-Kang Chang	"strings"
25*e7b1675dSTing-Kang Chang	"testing"
26*e7b1675dSTing-Kang Chang
27*e7b1675dSTing-Kang Chang	"github.com/google/tink/go/subtle/random"
28*e7b1675dSTing-Kang Chang	"github.com/google/tink/go/testkeyset"
29*e7b1675dSTing-Kang Chang	"github.com/google/tink/go/testutil"
30*e7b1675dSTing-Kang Chang	tinkpb "github.com/google/tink/go/proto/tink_go_proto"
31*e7b1675dSTing-Kang Chang)
32*e7b1675dSTing-Kang Chang
33*e7b1675dSTing-Kang Changfunc BenchmarkDecryptReader(b *testing.B) {
34*e7b1675dSTing-Kang Chang	b.ReportAllocs()
35*e7b1675dSTing-Kang Chang
36*e7b1675dSTing-Kang Chang	// Create a Streaming AEAD primitive using a full keyset.
37*e7b1675dSTing-Kang Chang	decKeyset := testutil.NewTestAESGCMHKDFKeyset()
38*e7b1675dSTing-Kang Chang	decKeysetHandle, err := testkeyset.NewHandle(decKeyset)
39*e7b1675dSTing-Kang Chang	if err != nil {
40*e7b1675dSTing-Kang Chang		b.Fatalf("Failed creating keyset handle: %v", err)
41*e7b1675dSTing-Kang Chang	}
42*e7b1675dSTing-Kang Chang	decCipher, err := New(decKeysetHandle)
43*e7b1675dSTing-Kang Chang	if err != nil {
44*e7b1675dSTing-Kang Chang		b.Errorf("streamingaead.New failed: %v", err)
45*e7b1675dSTing-Kang Chang	}
46*e7b1675dSTing-Kang Chang
47*e7b1675dSTing-Kang Chang	// Extract the raw key from the keyset and create a Streaming AEAD primitive
48*e7b1675dSTing-Kang Chang	// using only that key.
49*e7b1675dSTing-Kang Chang	//
50*e7b1675dSTing-Kang Chang	// testutil.NewTestAESGCMHKDFKeyset() places a raw key at position 1.
51*e7b1675dSTing-Kang Chang	rawKey := decKeyset.Key[1]
52*e7b1675dSTing-Kang Chang	if rawKey.OutputPrefixType != tinkpb.OutputPrefixType_RAW {
53*e7b1675dSTing-Kang Chang		b.Fatalf("Expected a raw key.")
54*e7b1675dSTing-Kang Chang	}
55*e7b1675dSTing-Kang Chang	encKeyset := testutil.NewKeyset(rawKey.KeyId, []*tinkpb.Keyset_Key{rawKey})
56*e7b1675dSTing-Kang Chang	encKeysetHandle, err := testkeyset.NewHandle(encKeyset)
57*e7b1675dSTing-Kang Chang	if err != nil {
58*e7b1675dSTing-Kang Chang		b.Fatalf("Failed creating keyset handle: %v", err)
59*e7b1675dSTing-Kang Chang	}
60*e7b1675dSTing-Kang Chang	encCipher, err := New(encKeysetHandle)
61*e7b1675dSTing-Kang Chang	if err != nil {
62*e7b1675dSTing-Kang Chang		b.Fatalf("streamingaead.New failed: %v", err)
63*e7b1675dSTing-Kang Chang	}
64*e7b1675dSTing-Kang Chang
65*e7b1675dSTing-Kang Chang	plaintext := random.GetRandomBytes(8)
66*e7b1675dSTing-Kang Chang	associatedData := random.GetRandomBytes(32)
67*e7b1675dSTing-Kang Chang
68*e7b1675dSTing-Kang Chang	b.ResetTimer()
69*e7b1675dSTing-Kang Chang	for i := 0; i < b.N; i++ {
70*e7b1675dSTing-Kang Chang		// Create a pipe for communication between the encrypting writer and
71*e7b1675dSTing-Kang Chang		// decrypting reader.
72*e7b1675dSTing-Kang Chang		r, w := io.Pipe()
73*e7b1675dSTing-Kang Chang		defer r.Close()
74*e7b1675dSTing-Kang Chang
75*e7b1675dSTing-Kang Chang		// Repeatedly encrypt the plaintext and write the ciphertext to a pipe.
76*e7b1675dSTing-Kang Chang		go func() {
77*e7b1675dSTing-Kang Chang			const writeAtLeast = 1 << 30 // 1 GiB
78*e7b1675dSTing-Kang Chang
79*e7b1675dSTing-Kang Chang			enc, err := encCipher.NewEncryptingWriter(w, associatedData)
80*e7b1675dSTing-Kang Chang			if err != nil {
81*e7b1675dSTing-Kang Chang				b.Errorf("Cannot create encrypt writer: %v", err)
82*e7b1675dSTing-Kang Chang				return
83*e7b1675dSTing-Kang Chang			}
84*e7b1675dSTing-Kang Chang
85*e7b1675dSTing-Kang Chang			for i := 0; i < writeAtLeast; i += len(plaintext) {
86*e7b1675dSTing-Kang Chang				if _, err := enc.Write(plaintext); err != nil {
87*e7b1675dSTing-Kang Chang					b.Errorf("Error encrypting data: %v", err)
88*e7b1675dSTing-Kang Chang					return
89*e7b1675dSTing-Kang Chang				}
90*e7b1675dSTing-Kang Chang			}
91*e7b1675dSTing-Kang Chang			if err := enc.Close(); err != nil {
92*e7b1675dSTing-Kang Chang				b.Errorf("Error closing encrypting writer: %v", err)
93*e7b1675dSTing-Kang Chang				return
94*e7b1675dSTing-Kang Chang			}
95*e7b1675dSTing-Kang Chang			if err := w.Close(); err != nil {
96*e7b1675dSTing-Kang Chang				b.Errorf("Error closing pipe: %v", err)
97*e7b1675dSTing-Kang Chang				return
98*e7b1675dSTing-Kang Chang			}
99*e7b1675dSTing-Kang Chang		}()
100*e7b1675dSTing-Kang Chang
101*e7b1675dSTing-Kang Chang		// Decrypt the ciphertext in small chunks.
102*e7b1675dSTing-Kang Chang		dec, err := decCipher.NewDecryptingReader(r, associatedData)
103*e7b1675dSTing-Kang Chang		if err != nil {
104*e7b1675dSTing-Kang Chang			b.Fatalf("Cannot create decrypt reader: %v", err)
105*e7b1675dSTing-Kang Chang		}
106*e7b1675dSTing-Kang Chang		buf := make([]byte, 16384) // 16 KiB
107*e7b1675dSTing-Kang Chang		for {
108*e7b1675dSTing-Kang Chang			_, err := dec.Read(buf)
109*e7b1675dSTing-Kang Chang			if err == io.EOF {
110*e7b1675dSTing-Kang Chang				break
111*e7b1675dSTing-Kang Chang			}
112*e7b1675dSTing-Kang Chang			if err != nil {
113*e7b1675dSTing-Kang Chang				b.Fatalf("Error decrypting data: %v", err)
114*e7b1675dSTing-Kang Chang			}
115*e7b1675dSTing-Kang Chang		}
116*e7b1675dSTing-Kang Chang	}
117*e7b1675dSTing-Kang Chang}
118*e7b1675dSTing-Kang Chang
119*e7b1675dSTing-Kang Changfunc TestUnreaderUnread(t *testing.T) {
120*e7b1675dSTing-Kang Chang	original := make([]byte, 4096)
121*e7b1675dSTing-Kang Chang	if _, err := io.ReadFull(rand.Reader, original); err != nil {
122*e7b1675dSTing-Kang Chang		t.Fatalf("Failed to fill buffer with random bytes: %v", err)
123*e7b1675dSTing-Kang Chang	}
124*e7b1675dSTing-Kang Chang
125*e7b1675dSTing-Kang Chang	u := &unreader{r: bytes.NewReader(original)}
126*e7b1675dSTing-Kang Chang	got, err := io.ReadAll(u)
127*e7b1675dSTing-Kang Chang	if err != nil {
128*e7b1675dSTing-Kang Chang		t.Errorf("First io.ReadAll(%T) failed unexpectedly: %v", u, err)
129*e7b1675dSTing-Kang Chang	}
130*e7b1675dSTing-Kang Chang	if !bytes.Equal(got, original) {
131*e7b1675dSTing-Kang Chang		t.Errorf("First io.ReadAll(%T) got %d bytes, want %d bytes that match the original random data.\nGot: %X\nWant: %X", u, len(got), len(original), got, original)
132*e7b1675dSTing-Kang Chang	}
133*e7b1675dSTing-Kang Chang
134*e7b1675dSTing-Kang Chang	u.unread()
135*e7b1675dSTing-Kang Chang	got, err = io.ReadAll(u)
136*e7b1675dSTing-Kang Chang	if err != nil {
137*e7b1675dSTing-Kang Chang		t.Errorf("After %T.unread(), io.ReadAll(%T) failed unexpectedly: %v", u, u, err)
138*e7b1675dSTing-Kang Chang	}
139*e7b1675dSTing-Kang Chang	if !bytes.Equal(got, original) {
140*e7b1675dSTing-Kang Chang		t.Errorf("After %T.unread(), io.ReadAll(%T) got %d bytes, want %d bytes that match the original random data.\nGot: %X\nWant: %X", u, u, len(got), len(original), got, original)
141*e7b1675dSTing-Kang Chang	}
142*e7b1675dSTing-Kang Chang}
143*e7b1675dSTing-Kang Chang
144*e7b1675dSTing-Kang Changfunc TestUnreader(t *testing.T) {
145*e7b1675dSTing-Kang Chang	// Repeating sequence of characters '0' through '9' makes it easy to see
146*e7b1675dSTing-Kang Chang	// holes or repeated data.
147*e7b1675dSTing-Kang Chang	original := make([]byte, 100)
148*e7b1675dSTing-Kang Chang	for i := range original {
149*e7b1675dSTing-Kang Chang		original[i] = '0' + byte(i%10)
150*e7b1675dSTing-Kang Chang	}
151*e7b1675dSTing-Kang Chang
152*e7b1675dSTing-Kang Chang	type step struct {
153*e7b1675dSTing-Kang Chang		read    int  // If set, read the given number of bytes exactly.
154*e7b1675dSTing-Kang Chang		unread  bool // If true, call unread().
155*e7b1675dSTing-Kang Chang		disable bool // If true, call disable().
156*e7b1675dSTing-Kang Chang	}
157*e7b1675dSTing-Kang Chang	tcs := []struct {
158*e7b1675dSTing-Kang Chang		name  string
159*e7b1675dSTing-Kang Chang		steps []step
160*e7b1675dSTing-Kang Chang	}{
161*e7b1675dSTing-Kang Chang		{"Read2UnreadRead4Unread", []step{{read: 2}, {unread: true}, {read: 4}, {unread: true}}},
162*e7b1675dSTing-Kang Chang		{"Read4UnreadRead2Unread", []step{{read: 4}, {unread: true}, {read: 2}, {unread: true}}},
163*e7b1675dSTing-Kang Chang		{"Read3UnreadRead3Unread", []step{{read: 3}, {unread: true}, {read: 3}, {unread: true}}},
164*e7b1675dSTing-Kang Chang		{"Read3Disable", []step{{read: 3}, {disable: true}}},
165*e7b1675dSTing-Kang Chang		{"Read2UnreadRead4Disable", []step{{read: 2}, {unread: true}, {read: 4}, {disable: true}}},
166*e7b1675dSTing-Kang Chang		{"Read4UnreadRead2Disable", []step{{read: 4}, {unread: true}, {read: 2}, {disable: true}}},
167*e7b1675dSTing-Kang Chang		{"Read3UnreadRead3Disable", []step{{read: 3}, {unread: true}, {read: 3}, {disable: true}}},
168*e7b1675dSTing-Kang Chang		{"Read2UnreadDisable", []step{{read: 2}, {unread: true}, {disable: true}}},
169*e7b1675dSTing-Kang Chang		{"Read4UnreadDisable", []step{{read: 4}, {unread: true}, {disable: true}}},
170*e7b1675dSTing-Kang Chang		{"ReadAllUnread", []step{{read: len(original)}, {unread: true}}},
171*e7b1675dSTing-Kang Chang		{"ReadAllDisable", []step{{read: len(original)}, {disable: true}}},
172*e7b1675dSTing-Kang Chang		{"Unread", []step{{unread: true}}},
173*e7b1675dSTing-Kang Chang		{"Disable", []step{{disable: true}}},
174*e7b1675dSTing-Kang Chang		{"UnreadDisable", []step{{unread: true}, {disable: true}}},
175*e7b1675dSTing-Kang Chang	}
176*e7b1675dSTing-Kang Chang
177*e7b1675dSTing-Kang Chang	for _, tc := range tcs {
178*e7b1675dSTing-Kang Chang		t.Run(tc.name, func(t *testing.T) {
179*e7b1675dSTing-Kang Chang			u := &unreader{r: bytes.NewReader(original)}
180*e7b1675dSTing-Kang Chang			var (
181*e7b1675dSTing-Kang Chang				after []string
182*e7b1675dSTing-Kang Chang				pos   int
183*e7b1675dSTing-Kang Chang			)
184*e7b1675dSTing-Kang Chang			// Explains what happened before the failure.
185*e7b1675dSTing-Kang Chang			prefix := func() string {
186*e7b1675dSTing-Kang Chang				if after == nil {
187*e7b1675dSTing-Kang Chang					return ""
188*e7b1675dSTing-Kang Chang				}
189*e7b1675dSTing-Kang Chang				return fmt.Sprintf("After %s, ", strings.Join(after, "+"))
190*e7b1675dSTing-Kang Chang			}
191*e7b1675dSTing-Kang Chang			for _, s := range tc.steps {
192*e7b1675dSTing-Kang Chang				if s.read != 0 {
193*e7b1675dSTing-Kang Chang					buf := make([]byte, s.read)
194*e7b1675dSTing-Kang Chang					if _, err := io.ReadFull(u, buf); err != nil {
195*e7b1675dSTing-Kang Chang						t.Fatalf("%sio.ReadFull(%T, %d byte buffer) failed unexpectedly: %v", prefix(), u, s.read, err)
196*e7b1675dSTing-Kang Chang					}
197*e7b1675dSTing-Kang Chang					if want := original[pos : pos+s.read]; !bytes.Equal(buf, want) {
198*e7b1675dSTing-Kang Chang						t.Fatalf("%sio.ReadFull(%T, %d byte buffer) got %q, want %q", prefix(), u, s.read, buf, want)
199*e7b1675dSTing-Kang Chang					}
200*e7b1675dSTing-Kang Chang					after = append(after, fmt.Sprintf("Read(%d bytes)", s.read))
201*e7b1675dSTing-Kang Chang					pos += s.read
202*e7b1675dSTing-Kang Chang				}
203*e7b1675dSTing-Kang Chang				if s.disable {
204*e7b1675dSTing-Kang Chang					u.disable()
205*e7b1675dSTing-Kang Chang					after = append(after, "disable()")
206*e7b1675dSTing-Kang Chang				}
207*e7b1675dSTing-Kang Chang				if s.unread {
208*e7b1675dSTing-Kang Chang					u.unread()
209*e7b1675dSTing-Kang Chang					after = append(after, "unread()")
210*e7b1675dSTing-Kang Chang					pos = 0
211*e7b1675dSTing-Kang Chang				}
212*e7b1675dSTing-Kang Chang			}
213*e7b1675dSTing-Kang Chang			got, err := io.ReadAll(u)
214*e7b1675dSTing-Kang Chang			if err != nil {
215*e7b1675dSTing-Kang Chang				t.Fatalf("%sio.ReadAll(%T) failed unexpectedly: %v", prefix(), u, err)
216*e7b1675dSTing-Kang Chang			}
217*e7b1675dSTing-Kang Chang			if want := original[pos:]; !bytes.Equal(want, got) {
218*e7b1675dSTing-Kang Chang				t.Errorf("%sio.ReadAll(%T) got %q, want %q", prefix(), u, got, want)
219*e7b1675dSTing-Kang Chang			}
220*e7b1675dSTing-Kang Chang		})
221*e7b1675dSTing-Kang Chang	}
222*e7b1675dSTing-Kang Chang}
223