xref: /aosp_15_r20/external/tink/go/streamingaead/decrypt_reader_test.go (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1// Copyright 2022 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15////////////////////////////////////////////////////////////////////////////////
16
17package streamingaead
18
19import (
20	"bytes"
21	"crypto/rand"
22	"fmt"
23	"io"
24	"strings"
25	"testing"
26
27	"github.com/google/tink/go/subtle/random"
28	"github.com/google/tink/go/testkeyset"
29	"github.com/google/tink/go/testutil"
30	tinkpb "github.com/google/tink/go/proto/tink_go_proto"
31)
32
33func BenchmarkDecryptReader(b *testing.B) {
34	b.ReportAllocs()
35
36	// Create a Streaming AEAD primitive using a full keyset.
37	decKeyset := testutil.NewTestAESGCMHKDFKeyset()
38	decKeysetHandle, err := testkeyset.NewHandle(decKeyset)
39	if err != nil {
40		b.Fatalf("Failed creating keyset handle: %v", err)
41	}
42	decCipher, err := New(decKeysetHandle)
43	if err != nil {
44		b.Errorf("streamingaead.New failed: %v", err)
45	}
46
47	// Extract the raw key from the keyset and create a Streaming AEAD primitive
48	// using only that key.
49	//
50	// testutil.NewTestAESGCMHKDFKeyset() places a raw key at position 1.
51	rawKey := decKeyset.Key[1]
52	if rawKey.OutputPrefixType != tinkpb.OutputPrefixType_RAW {
53		b.Fatalf("Expected a raw key.")
54	}
55	encKeyset := testutil.NewKeyset(rawKey.KeyId, []*tinkpb.Keyset_Key{rawKey})
56	encKeysetHandle, err := testkeyset.NewHandle(encKeyset)
57	if err != nil {
58		b.Fatalf("Failed creating keyset handle: %v", err)
59	}
60	encCipher, err := New(encKeysetHandle)
61	if err != nil {
62		b.Fatalf("streamingaead.New failed: %v", err)
63	}
64
65	plaintext := random.GetRandomBytes(8)
66	associatedData := random.GetRandomBytes(32)
67
68	b.ResetTimer()
69	for i := 0; i < b.N; i++ {
70		// Create a pipe for communication between the encrypting writer and
71		// decrypting reader.
72		r, w := io.Pipe()
73		defer r.Close()
74
75		// Repeatedly encrypt the plaintext and write the ciphertext to a pipe.
76		go func() {
77			const writeAtLeast = 1 << 30 // 1 GiB
78
79			enc, err := encCipher.NewEncryptingWriter(w, associatedData)
80			if err != nil {
81				b.Errorf("Cannot create encrypt writer: %v", err)
82				return
83			}
84
85			for i := 0; i < writeAtLeast; i += len(plaintext) {
86				if _, err := enc.Write(plaintext); err != nil {
87					b.Errorf("Error encrypting data: %v", err)
88					return
89				}
90			}
91			if err := enc.Close(); err != nil {
92				b.Errorf("Error closing encrypting writer: %v", err)
93				return
94			}
95			if err := w.Close(); err != nil {
96				b.Errorf("Error closing pipe: %v", err)
97				return
98			}
99		}()
100
101		// Decrypt the ciphertext in small chunks.
102		dec, err := decCipher.NewDecryptingReader(r, associatedData)
103		if err != nil {
104			b.Fatalf("Cannot create decrypt reader: %v", err)
105		}
106		buf := make([]byte, 16384) // 16 KiB
107		for {
108			_, err := dec.Read(buf)
109			if err == io.EOF {
110				break
111			}
112			if err != nil {
113				b.Fatalf("Error decrypting data: %v", err)
114			}
115		}
116	}
117}
118
119func TestUnreaderUnread(t *testing.T) {
120	original := make([]byte, 4096)
121	if _, err := io.ReadFull(rand.Reader, original); err != nil {
122		t.Fatalf("Failed to fill buffer with random bytes: %v", err)
123	}
124
125	u := &unreader{r: bytes.NewReader(original)}
126	got, err := io.ReadAll(u)
127	if err != nil {
128		t.Errorf("First io.ReadAll(%T) failed unexpectedly: %v", u, err)
129	}
130	if !bytes.Equal(got, original) {
131		t.Errorf("First io.ReadAll(%T) got %d bytes, want %d bytes that match the original random data.\nGot: %X\nWant: %X", u, len(got), len(original), got, original)
132	}
133
134	u.unread()
135	got, err = io.ReadAll(u)
136	if err != nil {
137		t.Errorf("After %T.unread(), io.ReadAll(%T) failed unexpectedly: %v", u, u, err)
138	}
139	if !bytes.Equal(got, original) {
140		t.Errorf("After %T.unread(), io.ReadAll(%T) got %d bytes, want %d bytes that match the original random data.\nGot: %X\nWant: %X", u, u, len(got), len(original), got, original)
141	}
142}
143
144func TestUnreader(t *testing.T) {
145	// Repeating sequence of characters '0' through '9' makes it easy to see
146	// holes or repeated data.
147	original := make([]byte, 100)
148	for i := range original {
149		original[i] = '0' + byte(i%10)
150	}
151
152	type step struct {
153		read    int  // If set, read the given number of bytes exactly.
154		unread  bool // If true, call unread().
155		disable bool // If true, call disable().
156	}
157	tcs := []struct {
158		name  string
159		steps []step
160	}{
161		{"Read2UnreadRead4Unread", []step{{read: 2}, {unread: true}, {read: 4}, {unread: true}}},
162		{"Read4UnreadRead2Unread", []step{{read: 4}, {unread: true}, {read: 2}, {unread: true}}},
163		{"Read3UnreadRead3Unread", []step{{read: 3}, {unread: true}, {read: 3}, {unread: true}}},
164		{"Read3Disable", []step{{read: 3}, {disable: true}}},
165		{"Read2UnreadRead4Disable", []step{{read: 2}, {unread: true}, {read: 4}, {disable: true}}},
166		{"Read4UnreadRead2Disable", []step{{read: 4}, {unread: true}, {read: 2}, {disable: true}}},
167		{"Read3UnreadRead3Disable", []step{{read: 3}, {unread: true}, {read: 3}, {disable: true}}},
168		{"Read2UnreadDisable", []step{{read: 2}, {unread: true}, {disable: true}}},
169		{"Read4UnreadDisable", []step{{read: 4}, {unread: true}, {disable: true}}},
170		{"ReadAllUnread", []step{{read: len(original)}, {unread: true}}},
171		{"ReadAllDisable", []step{{read: len(original)}, {disable: true}}},
172		{"Unread", []step{{unread: true}}},
173		{"Disable", []step{{disable: true}}},
174		{"UnreadDisable", []step{{unread: true}, {disable: true}}},
175	}
176
177	for _, tc := range tcs {
178		t.Run(tc.name, func(t *testing.T) {
179			u := &unreader{r: bytes.NewReader(original)}
180			var (
181				after []string
182				pos   int
183			)
184			// Explains what happened before the failure.
185			prefix := func() string {
186				if after == nil {
187					return ""
188				}
189				return fmt.Sprintf("After %s, ", strings.Join(after, "+"))
190			}
191			for _, s := range tc.steps {
192				if s.read != 0 {
193					buf := make([]byte, s.read)
194					if _, err := io.ReadFull(u, buf); err != nil {
195						t.Fatalf("%sio.ReadFull(%T, %d byte buffer) failed unexpectedly: %v", prefix(), u, s.read, err)
196					}
197					if want := original[pos : pos+s.read]; !bytes.Equal(buf, want) {
198						t.Fatalf("%sio.ReadFull(%T, %d byte buffer) got %q, want %q", prefix(), u, s.read, buf, want)
199					}
200					after = append(after, fmt.Sprintf("Read(%d bytes)", s.read))
201					pos += s.read
202				}
203				if s.disable {
204					u.disable()
205					after = append(after, "disable()")
206				}
207				if s.unread {
208					u.unread()
209					after = append(after, "unread()")
210					pos = 0
211				}
212			}
213			got, err := io.ReadAll(u)
214			if err != nil {
215				t.Fatalf("%sio.ReadAll(%T) failed unexpectedly: %v", prefix(), u, err)
216			}
217			if want := original[pos:]; !bytes.Equal(want, got) {
218				t.Errorf("%sio.ReadAll(%T) got %q, want %q", prefix(), u, got, want)
219			}
220		})
221	}
222}
223