xref: /aosp_15_r20/external/tink/cc/streamingaead/buffered_input_stream.cc (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1 // Copyright 2019 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ///////////////////////////////////////////////////////////////////////////////
16 
17 #include "tink/streamingaead/buffered_input_stream.h"
18 
19 #include <algorithm>
20 #include <cstring>
21 #include <memory>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/memory/memory.h"
26 #include "absl/status/status.h"
27 #include "tink/input_stream.h"
28 #include "tink/util/errors.h"
29 #include "tink/util/status.h"
30 #include "tink/util/statusor.h"
31 
32 namespace crypto {
33 namespace tink {
34 namespace streamingaead {
35 
36 using util::Status;
37 using util::StatusOr;
38 
BufferedInputStream(std::unique_ptr<crypto::tink::InputStream> input_stream)39 BufferedInputStream::BufferedInputStream(
40     std::unique_ptr<crypto::tink::InputStream> input_stream) {
41   input_stream_ = std::move(input_stream);
42   count_in_buffer_ = 0;
43   count_backedup_ = 0;
44   position_ = 0;
45   buffer_.resize(4 * 1024);  // 4 KB
46   buffer_offset_ = 0;
47   after_rewind_ = false;
48   rewinding_enabled_ = true;
49   direct_access_ = false;
50   status_ = util::OkStatus();
51 }
52 
Next(const void ** data)53 crypto::tink::util::StatusOr<int> BufferedInputStream::Next(const void** data) {
54   if (direct_access_) return input_stream_->Next(data);
55   if (!status_.ok()) return status_;
56 
57   // We're just after rewind, so return all the data in the buffer, if any.
58   if (after_rewind_ && count_in_buffer_ > 0) {
59     after_rewind_ = false;
60     *data = buffer_.data();
61     position_ = count_in_buffer_;
62     return count_in_buffer_;
63   }
64   if (count_backedup_ > 0) {  // Return the backed-up bytes.
65     buffer_offset_ = count_in_buffer_ - count_backedup_;
66     *data = buffer_.data() + buffer_offset_;
67     int backedup = count_backedup_;
68     count_backedup_ = 0;
69     position_ = count_in_buffer_;
70     return backedup;
71   }
72 
73   // Read new bytes from input_stream_.
74   //
75   // If we don't allow rewind any more, all the data buffered so far
76   // can be discarded, and from now on we go directly to input_stream_
77   if (!rewinding_enabled_) {
78     direct_access_ = true;
79     buffer_.resize(0);
80     return input_stream_->Next(data);
81   }
82 
83   // Otherwise, we read from input_stream_ the next chunk of data,
84   // and append it to buffer_.
85   after_rewind_ = false;
86   const void* buf;
87   auto next_result = input_stream_->Next(&buf);
88   if (!next_result.ok()) {
89     status_ = next_result.status();
90     return status_;
91   }
92   size_t count_read = next_result.value();
93   if (buffer_.size() < count_in_buffer_ + count_read) {
94     buffer_.resize(buffer_.size() + std::max(buffer_.size(), count_read));
95   }
96   memcpy(buffer_.data() + count_in_buffer_, buf, count_read);
97   buffer_offset_ = count_in_buffer_;
98   count_backedup_ = 0;
99   count_in_buffer_ += count_read;
100   position_ = position_ + count_read;
101   *data = buffer_.data() + buffer_offset_;
102   return count_read;
103 }
104 
BackUp(int count)105 void BufferedInputStream::BackUp(int count) {
106   if (direct_access_) {
107     input_stream_->BackUp(count);
108     return;
109   }
110   if (!status_.ok() || count < 1 ||
111       count_backedup_ == (count_in_buffer_ - buffer_offset_)) {
112     return;
113   }
114   int actual_count = std::min(
115       count, count_in_buffer_ - buffer_offset_ - count_backedup_);
116   count_backedup_ += actual_count;
117   position_ = position_ - actual_count;
118 }
119 
DisableRewinding()120 void BufferedInputStream::DisableRewinding() {
121   rewinding_enabled_ = false;
122 }
123 
Rewind()124 crypto::tink::util::Status BufferedInputStream::Rewind() {
125   if (!rewinding_enabled_) {
126     return util::Status(absl::StatusCode::kInvalidArgument,
127                         "rewinding is disabled");
128   }
129   if (status_.ok() || status_.code() == absl::StatusCode::kOutOfRange) {
130     status_ = util::OkStatus();
131     position_ = 0;
132     count_backedup_ = 0;
133     buffer_offset_ = 0;
134     after_rewind_ = true;
135   }
136   return status_;
137 }
138 
139 BufferedInputStream::~BufferedInputStream() = default;
140 
Position() const141 int64_t BufferedInputStream::Position() const {
142   if (direct_access_) return input_stream_->Position();
143   return position_;
144 }
145 
146 }  // namespace streamingaead
147 }  // namespace tink
148 }  // namespace crypto
149