1*e7b1675dSTing-Kang Chang// Copyright 2022 Google LLC 2*e7b1675dSTing-Kang Chang// 3*e7b1675dSTing-Kang Chang// Licensed under the Apache License, Version 2.0 (the "License"); 4*e7b1675dSTing-Kang Chang// you may not use this file except in compliance with the License. 5*e7b1675dSTing-Kang Chang// You may obtain a copy of the License at 6*e7b1675dSTing-Kang Chang// 7*e7b1675dSTing-Kang Chang// http://www.apache.org/licenses/LICENSE-2.0 8*e7b1675dSTing-Kang Chang// 9*e7b1675dSTing-Kang Chang// Unless required by applicable law or agreed to in writing, software 10*e7b1675dSTing-Kang Chang// distributed under the License is distributed on an "AS IS" BASIS, 11*e7b1675dSTing-Kang Chang// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12*e7b1675dSTing-Kang Chang// See the License for the specific language governing permissions and 13*e7b1675dSTing-Kang Chang// limitations under the License. 14*e7b1675dSTing-Kang Chang// 15*e7b1675dSTing-Kang Chang//////////////////////////////////////////////////////////////////////////////// 16*e7b1675dSTing-Kang Chang 17*e7b1675dSTing-Kang Changpackage streamingaead 18*e7b1675dSTing-Kang Chang 19*e7b1675dSTing-Kang Changimport ( 20*e7b1675dSTing-Kang Chang "bytes" 21*e7b1675dSTing-Kang Chang "crypto/rand" 22*e7b1675dSTing-Kang Chang "fmt" 23*e7b1675dSTing-Kang Chang "io" 24*e7b1675dSTing-Kang Chang "strings" 25*e7b1675dSTing-Kang Chang "testing" 26*e7b1675dSTing-Kang Chang 27*e7b1675dSTing-Kang Chang "github.com/google/tink/go/subtle/random" 28*e7b1675dSTing-Kang Chang "github.com/google/tink/go/testkeyset" 29*e7b1675dSTing-Kang Chang "github.com/google/tink/go/testutil" 30*e7b1675dSTing-Kang Chang tinkpb "github.com/google/tink/go/proto/tink_go_proto" 31*e7b1675dSTing-Kang Chang) 32*e7b1675dSTing-Kang Chang 33*e7b1675dSTing-Kang Changfunc BenchmarkDecryptReader(b *testing.B) { 34*e7b1675dSTing-Kang Chang b.ReportAllocs() 35*e7b1675dSTing-Kang Chang 36*e7b1675dSTing-Kang Chang // Create a Streaming AEAD primitive using a full keyset. 37*e7b1675dSTing-Kang Chang decKeyset := testutil.NewTestAESGCMHKDFKeyset() 38*e7b1675dSTing-Kang Chang decKeysetHandle, err := testkeyset.NewHandle(decKeyset) 39*e7b1675dSTing-Kang Chang if err != nil { 40*e7b1675dSTing-Kang Chang b.Fatalf("Failed creating keyset handle: %v", err) 41*e7b1675dSTing-Kang Chang } 42*e7b1675dSTing-Kang Chang decCipher, err := New(decKeysetHandle) 43*e7b1675dSTing-Kang Chang if err != nil { 44*e7b1675dSTing-Kang Chang b.Errorf("streamingaead.New failed: %v", err) 45*e7b1675dSTing-Kang Chang } 46*e7b1675dSTing-Kang Chang 47*e7b1675dSTing-Kang Chang // Extract the raw key from the keyset and create a Streaming AEAD primitive 48*e7b1675dSTing-Kang Chang // using only that key. 49*e7b1675dSTing-Kang Chang // 50*e7b1675dSTing-Kang Chang // testutil.NewTestAESGCMHKDFKeyset() places a raw key at position 1. 51*e7b1675dSTing-Kang Chang rawKey := decKeyset.Key[1] 52*e7b1675dSTing-Kang Chang if rawKey.OutputPrefixType != tinkpb.OutputPrefixType_RAW { 53*e7b1675dSTing-Kang Chang b.Fatalf("Expected a raw key.") 54*e7b1675dSTing-Kang Chang } 55*e7b1675dSTing-Kang Chang encKeyset := testutil.NewKeyset(rawKey.KeyId, []*tinkpb.Keyset_Key{rawKey}) 56*e7b1675dSTing-Kang Chang encKeysetHandle, err := testkeyset.NewHandle(encKeyset) 57*e7b1675dSTing-Kang Chang if err != nil { 58*e7b1675dSTing-Kang Chang b.Fatalf("Failed creating keyset handle: %v", err) 59*e7b1675dSTing-Kang Chang } 60*e7b1675dSTing-Kang Chang encCipher, err := New(encKeysetHandle) 61*e7b1675dSTing-Kang Chang if err != nil { 62*e7b1675dSTing-Kang Chang b.Fatalf("streamingaead.New failed: %v", err) 63*e7b1675dSTing-Kang Chang } 64*e7b1675dSTing-Kang Chang 65*e7b1675dSTing-Kang Chang plaintext := random.GetRandomBytes(8) 66*e7b1675dSTing-Kang Chang associatedData := random.GetRandomBytes(32) 67*e7b1675dSTing-Kang Chang 68*e7b1675dSTing-Kang Chang b.ResetTimer() 69*e7b1675dSTing-Kang Chang for i := 0; i < b.N; i++ { 70*e7b1675dSTing-Kang Chang // Create a pipe for communication between the encrypting writer and 71*e7b1675dSTing-Kang Chang // decrypting reader. 72*e7b1675dSTing-Kang Chang r, w := io.Pipe() 73*e7b1675dSTing-Kang Chang defer r.Close() 74*e7b1675dSTing-Kang Chang 75*e7b1675dSTing-Kang Chang // Repeatedly encrypt the plaintext and write the ciphertext to a pipe. 76*e7b1675dSTing-Kang Chang go func() { 77*e7b1675dSTing-Kang Chang const writeAtLeast = 1 << 30 // 1 GiB 78*e7b1675dSTing-Kang Chang 79*e7b1675dSTing-Kang Chang enc, err := encCipher.NewEncryptingWriter(w, associatedData) 80*e7b1675dSTing-Kang Chang if err != nil { 81*e7b1675dSTing-Kang Chang b.Errorf("Cannot create encrypt writer: %v", err) 82*e7b1675dSTing-Kang Chang return 83*e7b1675dSTing-Kang Chang } 84*e7b1675dSTing-Kang Chang 85*e7b1675dSTing-Kang Chang for i := 0; i < writeAtLeast; i += len(plaintext) { 86*e7b1675dSTing-Kang Chang if _, err := enc.Write(plaintext); err != nil { 87*e7b1675dSTing-Kang Chang b.Errorf("Error encrypting data: %v", err) 88*e7b1675dSTing-Kang Chang return 89*e7b1675dSTing-Kang Chang } 90*e7b1675dSTing-Kang Chang } 91*e7b1675dSTing-Kang Chang if err := enc.Close(); err != nil { 92*e7b1675dSTing-Kang Chang b.Errorf("Error closing encrypting writer: %v", err) 93*e7b1675dSTing-Kang Chang return 94*e7b1675dSTing-Kang Chang } 95*e7b1675dSTing-Kang Chang if err := w.Close(); err != nil { 96*e7b1675dSTing-Kang Chang b.Errorf("Error closing pipe: %v", err) 97*e7b1675dSTing-Kang Chang return 98*e7b1675dSTing-Kang Chang } 99*e7b1675dSTing-Kang Chang }() 100*e7b1675dSTing-Kang Chang 101*e7b1675dSTing-Kang Chang // Decrypt the ciphertext in small chunks. 102*e7b1675dSTing-Kang Chang dec, err := decCipher.NewDecryptingReader(r, associatedData) 103*e7b1675dSTing-Kang Chang if err != nil { 104*e7b1675dSTing-Kang Chang b.Fatalf("Cannot create decrypt reader: %v", err) 105*e7b1675dSTing-Kang Chang } 106*e7b1675dSTing-Kang Chang buf := make([]byte, 16384) // 16 KiB 107*e7b1675dSTing-Kang Chang for { 108*e7b1675dSTing-Kang Chang _, err := dec.Read(buf) 109*e7b1675dSTing-Kang Chang if err == io.EOF { 110*e7b1675dSTing-Kang Chang break 111*e7b1675dSTing-Kang Chang } 112*e7b1675dSTing-Kang Chang if err != nil { 113*e7b1675dSTing-Kang Chang b.Fatalf("Error decrypting data: %v", err) 114*e7b1675dSTing-Kang Chang } 115*e7b1675dSTing-Kang Chang } 116*e7b1675dSTing-Kang Chang } 117*e7b1675dSTing-Kang Chang} 118*e7b1675dSTing-Kang Chang 119*e7b1675dSTing-Kang Changfunc TestUnreaderUnread(t *testing.T) { 120*e7b1675dSTing-Kang Chang original := make([]byte, 4096) 121*e7b1675dSTing-Kang Chang if _, err := io.ReadFull(rand.Reader, original); err != nil { 122*e7b1675dSTing-Kang Chang t.Fatalf("Failed to fill buffer with random bytes: %v", err) 123*e7b1675dSTing-Kang Chang } 124*e7b1675dSTing-Kang Chang 125*e7b1675dSTing-Kang Chang u := &unreader{r: bytes.NewReader(original)} 126*e7b1675dSTing-Kang Chang got, err := io.ReadAll(u) 127*e7b1675dSTing-Kang Chang if err != nil { 128*e7b1675dSTing-Kang Chang t.Errorf("First io.ReadAll(%T) failed unexpectedly: %v", u, err) 129*e7b1675dSTing-Kang Chang } 130*e7b1675dSTing-Kang Chang if !bytes.Equal(got, original) { 131*e7b1675dSTing-Kang Chang t.Errorf("First io.ReadAll(%T) got %d bytes, want %d bytes that match the original random data.\nGot: %X\nWant: %X", u, len(got), len(original), got, original) 132*e7b1675dSTing-Kang Chang } 133*e7b1675dSTing-Kang Chang 134*e7b1675dSTing-Kang Chang u.unread() 135*e7b1675dSTing-Kang Chang got, err = io.ReadAll(u) 136*e7b1675dSTing-Kang Chang if err != nil { 137*e7b1675dSTing-Kang Chang t.Errorf("After %T.unread(), io.ReadAll(%T) failed unexpectedly: %v", u, u, err) 138*e7b1675dSTing-Kang Chang } 139*e7b1675dSTing-Kang Chang if !bytes.Equal(got, original) { 140*e7b1675dSTing-Kang Chang t.Errorf("After %T.unread(), io.ReadAll(%T) got %d bytes, want %d bytes that match the original random data.\nGot: %X\nWant: %X", u, u, len(got), len(original), got, original) 141*e7b1675dSTing-Kang Chang } 142*e7b1675dSTing-Kang Chang} 143*e7b1675dSTing-Kang Chang 144*e7b1675dSTing-Kang Changfunc TestUnreader(t *testing.T) { 145*e7b1675dSTing-Kang Chang // Repeating sequence of characters '0' through '9' makes it easy to see 146*e7b1675dSTing-Kang Chang // holes or repeated data. 147*e7b1675dSTing-Kang Chang original := make([]byte, 100) 148*e7b1675dSTing-Kang Chang for i := range original { 149*e7b1675dSTing-Kang Chang original[i] = '0' + byte(i%10) 150*e7b1675dSTing-Kang Chang } 151*e7b1675dSTing-Kang Chang 152*e7b1675dSTing-Kang Chang type step struct { 153*e7b1675dSTing-Kang Chang read int // If set, read the given number of bytes exactly. 154*e7b1675dSTing-Kang Chang unread bool // If true, call unread(). 155*e7b1675dSTing-Kang Chang disable bool // If true, call disable(). 156*e7b1675dSTing-Kang Chang } 157*e7b1675dSTing-Kang Chang tcs := []struct { 158*e7b1675dSTing-Kang Chang name string 159*e7b1675dSTing-Kang Chang steps []step 160*e7b1675dSTing-Kang Chang }{ 161*e7b1675dSTing-Kang Chang {"Read2UnreadRead4Unread", []step{{read: 2}, {unread: true}, {read: 4}, {unread: true}}}, 162*e7b1675dSTing-Kang Chang {"Read4UnreadRead2Unread", []step{{read: 4}, {unread: true}, {read: 2}, {unread: true}}}, 163*e7b1675dSTing-Kang Chang {"Read3UnreadRead3Unread", []step{{read: 3}, {unread: true}, {read: 3}, {unread: true}}}, 164*e7b1675dSTing-Kang Chang {"Read3Disable", []step{{read: 3}, {disable: true}}}, 165*e7b1675dSTing-Kang Chang {"Read2UnreadRead4Disable", []step{{read: 2}, {unread: true}, {read: 4}, {disable: true}}}, 166*e7b1675dSTing-Kang Chang {"Read4UnreadRead2Disable", []step{{read: 4}, {unread: true}, {read: 2}, {disable: true}}}, 167*e7b1675dSTing-Kang Chang {"Read3UnreadRead3Disable", []step{{read: 3}, {unread: true}, {read: 3}, {disable: true}}}, 168*e7b1675dSTing-Kang Chang {"Read2UnreadDisable", []step{{read: 2}, {unread: true}, {disable: true}}}, 169*e7b1675dSTing-Kang Chang {"Read4UnreadDisable", []step{{read: 4}, {unread: true}, {disable: true}}}, 170*e7b1675dSTing-Kang Chang {"ReadAllUnread", []step{{read: len(original)}, {unread: true}}}, 171*e7b1675dSTing-Kang Chang {"ReadAllDisable", []step{{read: len(original)}, {disable: true}}}, 172*e7b1675dSTing-Kang Chang {"Unread", []step{{unread: true}}}, 173*e7b1675dSTing-Kang Chang {"Disable", []step{{disable: true}}}, 174*e7b1675dSTing-Kang Chang {"UnreadDisable", []step{{unread: true}, {disable: true}}}, 175*e7b1675dSTing-Kang Chang } 176*e7b1675dSTing-Kang Chang 177*e7b1675dSTing-Kang Chang for _, tc := range tcs { 178*e7b1675dSTing-Kang Chang t.Run(tc.name, func(t *testing.T) { 179*e7b1675dSTing-Kang Chang u := &unreader{r: bytes.NewReader(original)} 180*e7b1675dSTing-Kang Chang var ( 181*e7b1675dSTing-Kang Chang after []string 182*e7b1675dSTing-Kang Chang pos int 183*e7b1675dSTing-Kang Chang ) 184*e7b1675dSTing-Kang Chang // Explains what happened before the failure. 185*e7b1675dSTing-Kang Chang prefix := func() string { 186*e7b1675dSTing-Kang Chang if after == nil { 187*e7b1675dSTing-Kang Chang return "" 188*e7b1675dSTing-Kang Chang } 189*e7b1675dSTing-Kang Chang return fmt.Sprintf("After %s, ", strings.Join(after, "+")) 190*e7b1675dSTing-Kang Chang } 191*e7b1675dSTing-Kang Chang for _, s := range tc.steps { 192*e7b1675dSTing-Kang Chang if s.read != 0 { 193*e7b1675dSTing-Kang Chang buf := make([]byte, s.read) 194*e7b1675dSTing-Kang Chang if _, err := io.ReadFull(u, buf); err != nil { 195*e7b1675dSTing-Kang Chang t.Fatalf("%sio.ReadFull(%T, %d byte buffer) failed unexpectedly: %v", prefix(), u, s.read, err) 196*e7b1675dSTing-Kang Chang } 197*e7b1675dSTing-Kang Chang if want := original[pos : pos+s.read]; !bytes.Equal(buf, want) { 198*e7b1675dSTing-Kang Chang t.Fatalf("%sio.ReadFull(%T, %d byte buffer) got %q, want %q", prefix(), u, s.read, buf, want) 199*e7b1675dSTing-Kang Chang } 200*e7b1675dSTing-Kang Chang after = append(after, fmt.Sprintf("Read(%d bytes)", s.read)) 201*e7b1675dSTing-Kang Chang pos += s.read 202*e7b1675dSTing-Kang Chang } 203*e7b1675dSTing-Kang Chang if s.disable { 204*e7b1675dSTing-Kang Chang u.disable() 205*e7b1675dSTing-Kang Chang after = append(after, "disable()") 206*e7b1675dSTing-Kang Chang } 207*e7b1675dSTing-Kang Chang if s.unread { 208*e7b1675dSTing-Kang Chang u.unread() 209*e7b1675dSTing-Kang Chang after = append(after, "unread()") 210*e7b1675dSTing-Kang Chang pos = 0 211*e7b1675dSTing-Kang Chang } 212*e7b1675dSTing-Kang Chang } 213*e7b1675dSTing-Kang Chang got, err := io.ReadAll(u) 214*e7b1675dSTing-Kang Chang if err != nil { 215*e7b1675dSTing-Kang Chang t.Fatalf("%sio.ReadAll(%T) failed unexpectedly: %v", prefix(), u, err) 216*e7b1675dSTing-Kang Chang } 217*e7b1675dSTing-Kang Chang if want := original[pos:]; !bytes.Equal(want, got) { 218*e7b1675dSTing-Kang Chang t.Errorf("%sio.ReadAll(%T) got %q, want %q", prefix(), u, got, want) 219*e7b1675dSTing-Kang Chang } 220*e7b1675dSTing-Kang Chang }) 221*e7b1675dSTing-Kang Chang } 222*e7b1675dSTing-Kang Chang} 223