1// Copyright 2022 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15//////////////////////////////////////////////////////////////////////////////// 16 17package streamingaead 18 19import ( 20 "bytes" 21 "crypto/rand" 22 "fmt" 23 "io" 24 "strings" 25 "testing" 26 27 "github.com/google/tink/go/subtle/random" 28 "github.com/google/tink/go/testkeyset" 29 "github.com/google/tink/go/testutil" 30 tinkpb "github.com/google/tink/go/proto/tink_go_proto" 31) 32 33func BenchmarkDecryptReader(b *testing.B) { 34 b.ReportAllocs() 35 36 // Create a Streaming AEAD primitive using a full keyset. 37 decKeyset := testutil.NewTestAESGCMHKDFKeyset() 38 decKeysetHandle, err := testkeyset.NewHandle(decKeyset) 39 if err != nil { 40 b.Fatalf("Failed creating keyset handle: %v", err) 41 } 42 decCipher, err := New(decKeysetHandle) 43 if err != nil { 44 b.Errorf("streamingaead.New failed: %v", err) 45 } 46 47 // Extract the raw key from the keyset and create a Streaming AEAD primitive 48 // using only that key. 49 // 50 // testutil.NewTestAESGCMHKDFKeyset() places a raw key at position 1. 51 rawKey := decKeyset.Key[1] 52 if rawKey.OutputPrefixType != tinkpb.OutputPrefixType_RAW { 53 b.Fatalf("Expected a raw key.") 54 } 55 encKeyset := testutil.NewKeyset(rawKey.KeyId, []*tinkpb.Keyset_Key{rawKey}) 56 encKeysetHandle, err := testkeyset.NewHandle(encKeyset) 57 if err != nil { 58 b.Fatalf("Failed creating keyset handle: %v", err) 59 } 60 encCipher, err := New(encKeysetHandle) 61 if err != nil { 62 b.Fatalf("streamingaead.New failed: %v", err) 63 } 64 65 plaintext := random.GetRandomBytes(8) 66 associatedData := random.GetRandomBytes(32) 67 68 b.ResetTimer() 69 for i := 0; i < b.N; i++ { 70 // Create a pipe for communication between the encrypting writer and 71 // decrypting reader. 72 r, w := io.Pipe() 73 defer r.Close() 74 75 // Repeatedly encrypt the plaintext and write the ciphertext to a pipe. 76 go func() { 77 const writeAtLeast = 1 << 30 // 1 GiB 78 79 enc, err := encCipher.NewEncryptingWriter(w, associatedData) 80 if err != nil { 81 b.Errorf("Cannot create encrypt writer: %v", err) 82 return 83 } 84 85 for i := 0; i < writeAtLeast; i += len(plaintext) { 86 if _, err := enc.Write(plaintext); err != nil { 87 b.Errorf("Error encrypting data: %v", err) 88 return 89 } 90 } 91 if err := enc.Close(); err != nil { 92 b.Errorf("Error closing encrypting writer: %v", err) 93 return 94 } 95 if err := w.Close(); err != nil { 96 b.Errorf("Error closing pipe: %v", err) 97 return 98 } 99 }() 100 101 // Decrypt the ciphertext in small chunks. 102 dec, err := decCipher.NewDecryptingReader(r, associatedData) 103 if err != nil { 104 b.Fatalf("Cannot create decrypt reader: %v", err) 105 } 106 buf := make([]byte, 16384) // 16 KiB 107 for { 108 _, err := dec.Read(buf) 109 if err == io.EOF { 110 break 111 } 112 if err != nil { 113 b.Fatalf("Error decrypting data: %v", err) 114 } 115 } 116 } 117} 118 119func TestUnreaderUnread(t *testing.T) { 120 original := make([]byte, 4096) 121 if _, err := io.ReadFull(rand.Reader, original); err != nil { 122 t.Fatalf("Failed to fill buffer with random bytes: %v", err) 123 } 124 125 u := &unreader{r: bytes.NewReader(original)} 126 got, err := io.ReadAll(u) 127 if err != nil { 128 t.Errorf("First io.ReadAll(%T) failed unexpectedly: %v", u, err) 129 } 130 if !bytes.Equal(got, original) { 131 t.Errorf("First io.ReadAll(%T) got %d bytes, want %d bytes that match the original random data.\nGot: %X\nWant: %X", u, len(got), len(original), got, original) 132 } 133 134 u.unread() 135 got, err = io.ReadAll(u) 136 if err != nil { 137 t.Errorf("After %T.unread(), io.ReadAll(%T) failed unexpectedly: %v", u, u, err) 138 } 139 if !bytes.Equal(got, original) { 140 t.Errorf("After %T.unread(), io.ReadAll(%T) got %d bytes, want %d bytes that match the original random data.\nGot: %X\nWant: %X", u, u, len(got), len(original), got, original) 141 } 142} 143 144func TestUnreader(t *testing.T) { 145 // Repeating sequence of characters '0' through '9' makes it easy to see 146 // holes or repeated data. 147 original := make([]byte, 100) 148 for i := range original { 149 original[i] = '0' + byte(i%10) 150 } 151 152 type step struct { 153 read int // If set, read the given number of bytes exactly. 154 unread bool // If true, call unread(). 155 disable bool // If true, call disable(). 156 } 157 tcs := []struct { 158 name string 159 steps []step 160 }{ 161 {"Read2UnreadRead4Unread", []step{{read: 2}, {unread: true}, {read: 4}, {unread: true}}}, 162 {"Read4UnreadRead2Unread", []step{{read: 4}, {unread: true}, {read: 2}, {unread: true}}}, 163 {"Read3UnreadRead3Unread", []step{{read: 3}, {unread: true}, {read: 3}, {unread: true}}}, 164 {"Read3Disable", []step{{read: 3}, {disable: true}}}, 165 {"Read2UnreadRead4Disable", []step{{read: 2}, {unread: true}, {read: 4}, {disable: true}}}, 166 {"Read4UnreadRead2Disable", []step{{read: 4}, {unread: true}, {read: 2}, {disable: true}}}, 167 {"Read3UnreadRead3Disable", []step{{read: 3}, {unread: true}, {read: 3}, {disable: true}}}, 168 {"Read2UnreadDisable", []step{{read: 2}, {unread: true}, {disable: true}}}, 169 {"Read4UnreadDisable", []step{{read: 4}, {unread: true}, {disable: true}}}, 170 {"ReadAllUnread", []step{{read: len(original)}, {unread: true}}}, 171 {"ReadAllDisable", []step{{read: len(original)}, {disable: true}}}, 172 {"Unread", []step{{unread: true}}}, 173 {"Disable", []step{{disable: true}}}, 174 {"UnreadDisable", []step{{unread: true}, {disable: true}}}, 175 } 176 177 for _, tc := range tcs { 178 t.Run(tc.name, func(t *testing.T) { 179 u := &unreader{r: bytes.NewReader(original)} 180 var ( 181 after []string 182 pos int 183 ) 184 // Explains what happened before the failure. 185 prefix := func() string { 186 if after == nil { 187 return "" 188 } 189 return fmt.Sprintf("After %s, ", strings.Join(after, "+")) 190 } 191 for _, s := range tc.steps { 192 if s.read != 0 { 193 buf := make([]byte, s.read) 194 if _, err := io.ReadFull(u, buf); err != nil { 195 t.Fatalf("%sio.ReadFull(%T, %d byte buffer) failed unexpectedly: %v", prefix(), u, s.read, err) 196 } 197 if want := original[pos : pos+s.read]; !bytes.Equal(buf, want) { 198 t.Fatalf("%sio.ReadFull(%T, %d byte buffer) got %q, want %q", prefix(), u, s.read, buf, want) 199 } 200 after = append(after, fmt.Sprintf("Read(%d bytes)", s.read)) 201 pos += s.read 202 } 203 if s.disable { 204 u.disable() 205 after = append(after, "disable()") 206 } 207 if s.unread { 208 u.unread() 209 after = append(after, "unread()") 210 pos = 0 211 } 212 } 213 got, err := io.ReadAll(u) 214 if err != nil { 215 t.Fatalf("%sio.ReadAll(%T) failed unexpectedly: %v", prefix(), u, err) 216 } 217 if want := original[pos:]; !bytes.Equal(want, got) { 218 t.Errorf("%sio.ReadAll(%T) got %q, want %q", prefix(), u, got, want) 219 } 220 }) 221 } 222} 223