// Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // //////////////////////////////////////////////////////////////////////////////// package streamingaead import ( "bytes" "crypto/rand" "fmt" "io" "strings" "testing" "github.com/google/tink/go/subtle/random" "github.com/google/tink/go/testkeyset" "github.com/google/tink/go/testutil" tinkpb "github.com/google/tink/go/proto/tink_go_proto" ) func BenchmarkDecryptReader(b *testing.B) { b.ReportAllocs() // Create a Streaming AEAD primitive using a full keyset. decKeyset := testutil.NewTestAESGCMHKDFKeyset() decKeysetHandle, err := testkeyset.NewHandle(decKeyset) if err != nil { b.Fatalf("Failed creating keyset handle: %v", err) } decCipher, err := New(decKeysetHandle) if err != nil { b.Errorf("streamingaead.New failed: %v", err) } // Extract the raw key from the keyset and create a Streaming AEAD primitive // using only that key. // // testutil.NewTestAESGCMHKDFKeyset() places a raw key at position 1. rawKey := decKeyset.Key[1] if rawKey.OutputPrefixType != tinkpb.OutputPrefixType_RAW { b.Fatalf("Expected a raw key.") } encKeyset := testutil.NewKeyset(rawKey.KeyId, []*tinkpb.Keyset_Key{rawKey}) encKeysetHandle, err := testkeyset.NewHandle(encKeyset) if err != nil { b.Fatalf("Failed creating keyset handle: %v", err) } encCipher, err := New(encKeysetHandle) if err != nil { b.Fatalf("streamingaead.New failed: %v", err) } plaintext := random.GetRandomBytes(8) associatedData := random.GetRandomBytes(32) b.ResetTimer() for i := 0; i < b.N; i++ { // Create a pipe for communication between the encrypting writer and // decrypting reader. r, w := io.Pipe() defer r.Close() // Repeatedly encrypt the plaintext and write the ciphertext to a pipe. go func() { const writeAtLeast = 1 << 30 // 1 GiB enc, err := encCipher.NewEncryptingWriter(w, associatedData) if err != nil { b.Errorf("Cannot create encrypt writer: %v", err) return } for i := 0; i < writeAtLeast; i += len(plaintext) { if _, err := enc.Write(plaintext); err != nil { b.Errorf("Error encrypting data: %v", err) return } } if err := enc.Close(); err != nil { b.Errorf("Error closing encrypting writer: %v", err) return } if err := w.Close(); err != nil { b.Errorf("Error closing pipe: %v", err) return } }() // Decrypt the ciphertext in small chunks. dec, err := decCipher.NewDecryptingReader(r, associatedData) if err != nil { b.Fatalf("Cannot create decrypt reader: %v", err) } buf := make([]byte, 16384) // 16 KiB for { _, err := dec.Read(buf) if err == io.EOF { break } if err != nil { b.Fatalf("Error decrypting data: %v", err) } } } } func TestUnreaderUnread(t *testing.T) { original := make([]byte, 4096) if _, err := io.ReadFull(rand.Reader, original); err != nil { t.Fatalf("Failed to fill buffer with random bytes: %v", err) } u := &unreader{r: bytes.NewReader(original)} got, err := io.ReadAll(u) if err != nil { t.Errorf("First io.ReadAll(%T) failed unexpectedly: %v", u, err) } if !bytes.Equal(got, original) { t.Errorf("First io.ReadAll(%T) got %d bytes, want %d bytes that match the original random data.\nGot: %X\nWant: %X", u, len(got), len(original), got, original) } u.unread() got, err = io.ReadAll(u) if err != nil { t.Errorf("After %T.unread(), io.ReadAll(%T) failed unexpectedly: %v", u, u, err) } if !bytes.Equal(got, original) { t.Errorf("After %T.unread(), io.ReadAll(%T) got %d bytes, want %d bytes that match the original random data.\nGot: %X\nWant: %X", u, u, len(got), len(original), got, original) } } func TestUnreader(t *testing.T) { // Repeating sequence of characters '0' through '9' makes it easy to see // holes or repeated data. original := make([]byte, 100) for i := range original { original[i] = '0' + byte(i%10) } type step struct { read int // If set, read the given number of bytes exactly. unread bool // If true, call unread(). disable bool // If true, call disable(). } tcs := []struct { name string steps []step }{ {"Read2UnreadRead4Unread", []step{{read: 2}, {unread: true}, {read: 4}, {unread: true}}}, {"Read4UnreadRead2Unread", []step{{read: 4}, {unread: true}, {read: 2}, {unread: true}}}, {"Read3UnreadRead3Unread", []step{{read: 3}, {unread: true}, {read: 3}, {unread: true}}}, {"Read3Disable", []step{{read: 3}, {disable: true}}}, {"Read2UnreadRead4Disable", []step{{read: 2}, {unread: true}, {read: 4}, {disable: true}}}, {"Read4UnreadRead2Disable", []step{{read: 4}, {unread: true}, {read: 2}, {disable: true}}}, {"Read3UnreadRead3Disable", []step{{read: 3}, {unread: true}, {read: 3}, {disable: true}}}, {"Read2UnreadDisable", []step{{read: 2}, {unread: true}, {disable: true}}}, {"Read4UnreadDisable", []step{{read: 4}, {unread: true}, {disable: true}}}, {"ReadAllUnread", []step{{read: len(original)}, {unread: true}}}, {"ReadAllDisable", []step{{read: len(original)}, {disable: true}}}, {"Unread", []step{{unread: true}}}, {"Disable", []step{{disable: true}}}, {"UnreadDisable", []step{{unread: true}, {disable: true}}}, } for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { u := &unreader{r: bytes.NewReader(original)} var ( after []string pos int ) // Explains what happened before the failure. prefix := func() string { if after == nil { return "" } return fmt.Sprintf("After %s, ", strings.Join(after, "+")) } for _, s := range tc.steps { if s.read != 0 { buf := make([]byte, s.read) if _, err := io.ReadFull(u, buf); err != nil { t.Fatalf("%sio.ReadFull(%T, %d byte buffer) failed unexpectedly: %v", prefix(), u, s.read, err) } if want := original[pos : pos+s.read]; !bytes.Equal(buf, want) { t.Fatalf("%sio.ReadFull(%T, %d byte buffer) got %q, want %q", prefix(), u, s.read, buf, want) } after = append(after, fmt.Sprintf("Read(%d bytes)", s.read)) pos += s.read } if s.disable { u.disable() after = append(after, "disable()") } if s.unread { u.unread() after = append(after, "unread()") pos = 0 } } got, err := io.ReadAll(u) if err != nil { t.Fatalf("%sio.ReadAll(%T) failed unexpectedly: %v", prefix(), u, err) } if want := original[pos:]; !bytes.Equal(want, got) { t.Errorf("%sio.ReadAll(%T) got %q, want %q", prefix(), u, got, want) } }) } }