1*e7b1675dSTing-Kang Chang// Copyright 2020 Google LLC 2*e7b1675dSTing-Kang Chang// 3*e7b1675dSTing-Kang Chang// Licensed under the Apache License, Version 2.0 (the "License"); 4*e7b1675dSTing-Kang Chang// you may not use this file except in compliance with the License. 5*e7b1675dSTing-Kang Chang// You may obtain a copy of the License at 6*e7b1675dSTing-Kang Chang// 7*e7b1675dSTing-Kang Chang// http://www.apache.org/licenses/LICENSE-2.0 8*e7b1675dSTing-Kang Chang// 9*e7b1675dSTing-Kang Chang// Unless required by applicable law or agreed to in writing, software 10*e7b1675dSTing-Kang Chang// distributed under the License is distributed on an "AS IS" BASIS, 11*e7b1675dSTing-Kang Chang// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12*e7b1675dSTing-Kang Chang// See the License for the specific language governing permissions and 13*e7b1675dSTing-Kang Chang// limitations under the License. 14*e7b1675dSTing-Kang Chang// 15*e7b1675dSTing-Kang Chang//////////////////////////////////////////////////////////////////////////////// 16*e7b1675dSTing-Kang Chang 17*e7b1675dSTing-Kang Changpackage streamingaead 18*e7b1675dSTing-Kang Chang 19*e7b1675dSTing-Kang Changimport ( 20*e7b1675dSTing-Kang Chang "errors" 21*e7b1675dSTing-Kang Chang "io" 22*e7b1675dSTing-Kang Chang 23*e7b1675dSTing-Kang Chang "github.com/google/tink/go/core/primitiveset" 24*e7b1675dSTing-Kang Chang "github.com/google/tink/go/tink" 25*e7b1675dSTing-Kang Chang) 26*e7b1675dSTing-Kang Chang 27*e7b1675dSTing-Kang Changvar ( 28*e7b1675dSTing-Kang Chang _ io.Reader = &decryptReader{} 29*e7b1675dSTing-Kang Chang errKeyNotFound = errors.New("no matching key found for the ciphertext in the stream") 30*e7b1675dSTing-Kang Chang) 31*e7b1675dSTing-Kang Chang 32*e7b1675dSTing-Kang Chang// decryptReader is a reader that tries to find the right key to decrypt ciphertext from the given primitive set. 
33*e7b1675dSTing-Kang Changtype decryptReader struct { 34*e7b1675dSTing-Kang Chang wrapped *wrappedStreamingAEAD 35*e7b1675dSTing-Kang Chang // cr is a source Reader which provides ciphertext to be decrypted. 36*e7b1675dSTing-Kang Chang cr io.Reader 37*e7b1675dSTing-Kang Chang aad []byte 38*e7b1675dSTing-Kang Chang 39*e7b1675dSTing-Kang Chang matchAttempted bool 40*e7b1675dSTing-Kang Chang // mr is a matched decrypting reader initialized with a proper key to decrypt ciphertext. 41*e7b1675dSTing-Kang Chang mr io.Reader 42*e7b1675dSTing-Kang Chang} 43*e7b1675dSTing-Kang Chang 44*e7b1675dSTing-Kang Changfunc (dr *decryptReader) Read(p []byte) (n int, err error) { 45*e7b1675dSTing-Kang Chang if dr.mr != nil { 46*e7b1675dSTing-Kang Chang return dr.mr.Read(p) 47*e7b1675dSTing-Kang Chang } 48*e7b1675dSTing-Kang Chang if dr.matchAttempted { 49*e7b1675dSTing-Kang Chang return 0, errKeyNotFound 50*e7b1675dSTing-Kang Chang } 51*e7b1675dSTing-Kang Chang 52*e7b1675dSTing-Kang Chang // For legacy reasons (Tink always encrypted with non-RAW keys) we use all 53*e7b1675dSTing-Kang Chang // primitives, even those which have output_prefix_type != RAW. 54*e7b1675dSTing-Kang Chang var allEntries []*primitiveset.Entry 55*e7b1675dSTing-Kang Chang for _, entryList := range dr.wrapped.ps.Entries { 56*e7b1675dSTing-Kang Chang allEntries = append(allEntries, entryList...) 
57*e7b1675dSTing-Kang Chang } 58*e7b1675dSTing-Kang Chang if err != nil { 59*e7b1675dSTing-Kang Chang return 0, err 60*e7b1675dSTing-Kang Chang } 61*e7b1675dSTing-Kang Chang 62*e7b1675dSTing-Kang Chang dr.matchAttempted = true 63*e7b1675dSTing-Kang Chang ur := &unreader{r: dr.cr} 64*e7b1675dSTing-Kang Chang 65*e7b1675dSTing-Kang Chang // find proper key to decrypt ciphertext 66*e7b1675dSTing-Kang Chang for _, e := range allEntries { 67*e7b1675dSTing-Kang Chang sa, ok := e.Primitive.(tink.StreamingAEAD) 68*e7b1675dSTing-Kang Chang if !ok { 69*e7b1675dSTing-Kang Chang continue 70*e7b1675dSTing-Kang Chang } 71*e7b1675dSTing-Kang Chang 72*e7b1675dSTing-Kang Chang read := func() (io.Reader, int, error) { 73*e7b1675dSTing-Kang Chang r, err := sa.NewDecryptingReader(ur, dr.aad) 74*e7b1675dSTing-Kang Chang if err != nil { 75*e7b1675dSTing-Kang Chang return nil, 0, err 76*e7b1675dSTing-Kang Chang } 77*e7b1675dSTing-Kang Chang n, err := r.Read(p) 78*e7b1675dSTing-Kang Chang if err != nil { 79*e7b1675dSTing-Kang Chang return nil, 0, err 80*e7b1675dSTing-Kang Chang } 81*e7b1675dSTing-Kang Chang return r, n, nil 82*e7b1675dSTing-Kang Chang } 83*e7b1675dSTing-Kang Chang 84*e7b1675dSTing-Kang Chang r, n, err := read() 85*e7b1675dSTing-Kang Chang if err == nil { 86*e7b1675dSTing-Kang Chang dr.mr = r 87*e7b1675dSTing-Kang Chang ur.disable() 88*e7b1675dSTing-Kang Chang return n, nil 89*e7b1675dSTing-Kang Chang } 90*e7b1675dSTing-Kang Chang 91*e7b1675dSTing-Kang Chang ur.unread() 92*e7b1675dSTing-Kang Chang } 93*e7b1675dSTing-Kang Chang return 0, errKeyNotFound 94*e7b1675dSTing-Kang Chang} 95*e7b1675dSTing-Kang Chang 96*e7b1675dSTing-Kang Chang// unreader wraps a reader and keeps a copy of everything that's read so it can 97*e7b1675dSTing-Kang Chang// be unread and read again. When no additional unreads are needed, the buffer 98*e7b1675dSTing-Kang Chang// can be disabled and the memory released. 
type unreader struct {
	// r is the wrapped reader.
	r io.Reader
	// buf is a copy of every byte read from r while buffering is enabled. It
	// is dropped once buffering is disabled and nothing is left to replay.
	buf []byte
	// pos is the offset in buf of the next byte to replay; when
	// pos == len(buf), reads go to r.
	pos int
	// disabled stops buf from growing; it is set by disable.
	disabled bool
}

// Read first replays bytes from buf that have not been returned since the last
// unread, and otherwise reads from the wrapped reader. While buffering is
// enabled, bytes obtained from the wrapped reader are appended to buf so a
// later unread can replay them. Errors from the wrapped reader are returned
// unchanged and are not recorded for replay.
func (u *unreader) Read(p []byte) (int, error) {
	if u.pos < len(u.buf) {
		n := copy(p, u.buf[u.pos:])
		u.pos += n
		// Once disabled, the replay copy is useless after its last byte is
		// served; release it now rather than on the next underlying read.
		u.releaseIfDrained()
		return n, nil
	}
	n, err := u.r.Read(p)
	if u.disabled {
		u.buf = nil
		u.pos = 0
	} else {
		u.buf = append(u.buf, p[:n]...)
		u.pos = len(u.buf)
	}
	return n, err
}

// unread starts the reader over again. A copy of all read data will be returned
// by `Read()` before the wrapped reader is read from again. It must not be
// called after disable, because the buffered copy may already be released.
func (u *unreader) unread() {
	u.pos = 0
}

// disable turns off buffering because no further unreads are needed. The
// buffer is released for garbage collection as soon as everything in it has
// been replayed — immediately, if that is already the case.
func (u *unreader) disable() {
	u.disabled = true
	u.releaseIfDrained()
}

// releaseIfDrained drops buf once buffering is disabled and every buffered
// byte has been replayed.
func (u *unreader) releaseIfDrained() {
	if u.disabled && u.pos == len(u.buf) {
		u.buf = nil
		u.pos = 0
	}
}