1// Copyright 2020 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15//////////////////////////////////////////////////////////////////////////////// 16 17package streamingaead 18 19import ( 20 "errors" 21 "io" 22 23 "github.com/google/tink/go/core/primitiveset" 24 "github.com/google/tink/go/tink" 25) 26 27var ( 28 _ io.Reader = &decryptReader{} 29 errKeyNotFound = errors.New("no matching key found for the ciphertext in the stream") 30) 31 32// decryptReader is a reader that tries to find the right key to decrypt ciphertext from the given primitive set. 33type decryptReader struct { 34 wrapped *wrappedStreamingAEAD 35 // cr is a source Reader which provides ciphertext to be decrypted. 36 cr io.Reader 37 aad []byte 38 39 matchAttempted bool 40 // mr is a matched decrypting reader initialized with a proper key to decrypt ciphertext. 41 mr io.Reader 42} 43 44func (dr *decryptReader) Read(p []byte) (n int, err error) { 45 if dr.mr != nil { 46 return dr.mr.Read(p) 47 } 48 if dr.matchAttempted { 49 return 0, errKeyNotFound 50 } 51 52 // For legacy reasons (Tink always encrypted with non-RAW keys) we use all 53 // primitives, even those which have output_prefix_type != RAW. 54 var allEntries []*primitiveset.Entry 55 for _, entryList := range dr.wrapped.ps.Entries { 56 allEntries = append(allEntries, entryList...) 57 } 58 if err != nil { 59 return 0, err 60 } 61 62 dr.matchAttempted = true 63 ur := &unreader{r: dr.cr} 64 65 // find proper key to decrypt ciphertext 66 for _, e := range allEntries { 67 sa, ok := e.Primitive.(tink.StreamingAEAD) 68 if !ok { 69 continue 70 } 71 72 read := func() (io.Reader, int, error) { 73 r, err := sa.NewDecryptingReader(ur, dr.aad) 74 if err != nil { 75 return nil, 0, err 76 } 77 n, err := r.Read(p) 78 if err != nil { 79 return nil, 0, err 80 } 81 return r, n, nil 82 } 83 84 r, n, err := read() 85 if err == nil { 86 dr.mr = r 87 ur.disable() 88 return n, nil 89 } 90 91 ur.unread() 92 } 93 return 0, errKeyNotFound 94} 95 96// unreader wraps a reader and keeps a copy of everything that's read so it can 97// be unread and read again. When no additional unreads are needed, the buffer 98// can be disabled and the memory released. 99type unreader struct { 100 r io.Reader 101 buf []byte 102 pos int 103 disabled bool 104} 105 106func (u *unreader) Read(buf []byte) (int, error) { 107 if len(u.buf) != u.pos { 108 n := copy(buf, u.buf[u.pos:]) 109 u.pos += n 110 return n, nil 111 } 112 n, err := u.r.Read(buf) 113 if u.disabled { 114 u.buf = nil 115 u.pos = 0 116 } else { 117 u.buf = append(u.buf, buf[:n]...) 118 u.pos = len(u.buf) 119 } 120 return n, err 121} 122 123// unread starts the reader over again. A copy of all read data will be returned 124// by `Read()` before the wrapped reader is read from again. 125func (u *unreader) unread() { 126 u.pos = 0 127} 128 129// disable ensures the buffer is released for garbage collection once it's no 130// longer needed. 131func (u *unreader) disable() { 132 u.disabled = true 133} 134