1 // Copyright 2019 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ///////////////////////////////////////////////////////////////////////////////
16
17 #include "tink/streamingaead/buffered_input_stream.h"
18
19 #include <algorithm>
20 #include <cstring>
21 #include <memory>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/memory/memory.h"
26 #include "absl/status/status.h"
27 #include "tink/input_stream.h"
28 #include "tink/util/errors.h"
29 #include "tink/util/status.h"
30 #include "tink/util/statusor.h"
31
32 namespace crypto {
33 namespace tink {
34 namespace streamingaead {
35
36 using util::Status;
37 using util::StatusOr;
38
BufferedInputStream(std::unique_ptr<crypto::tink::InputStream> input_stream)39 BufferedInputStream::BufferedInputStream(
40 std::unique_ptr<crypto::tink::InputStream> input_stream) {
41 input_stream_ = std::move(input_stream);
42 count_in_buffer_ = 0;
43 count_backedup_ = 0;
44 position_ = 0;
45 buffer_.resize(4 * 1024); // 4 KB
46 buffer_offset_ = 0;
47 after_rewind_ = false;
48 rewinding_enabled_ = true;
49 direct_access_ = false;
50 status_ = util::OkStatus();
51 }
52
Next(const void ** data)53 crypto::tink::util::StatusOr<int> BufferedInputStream::Next(const void** data) {
54 if (direct_access_) return input_stream_->Next(data);
55 if (!status_.ok()) return status_;
56
57 // We're just after rewind, so return all the data in the buffer, if any.
58 if (after_rewind_ && count_in_buffer_ > 0) {
59 after_rewind_ = false;
60 *data = buffer_.data();
61 position_ = count_in_buffer_;
62 return count_in_buffer_;
63 }
64 if (count_backedup_ > 0) { // Return the backed-up bytes.
65 buffer_offset_ = count_in_buffer_ - count_backedup_;
66 *data = buffer_.data() + buffer_offset_;
67 int backedup = count_backedup_;
68 count_backedup_ = 0;
69 position_ = count_in_buffer_;
70 return backedup;
71 }
72
73 // Read new bytes from input_stream_.
74 //
75 // If we don't allow rewind any more, all the data buffered so far
76 // can be discarded, and from now on we go directly to input_stream_
77 if (!rewinding_enabled_) {
78 direct_access_ = true;
79 buffer_.resize(0);
80 return input_stream_->Next(data);
81 }
82
83 // Otherwise, we read from input_stream_ the next chunk of data,
84 // and append it to buffer_.
85 after_rewind_ = false;
86 const void* buf;
87 auto next_result = input_stream_->Next(&buf);
88 if (!next_result.ok()) {
89 status_ = next_result.status();
90 return status_;
91 }
92 size_t count_read = next_result.value();
93 if (buffer_.size() < count_in_buffer_ + count_read) {
94 buffer_.resize(buffer_.size() + std::max(buffer_.size(), count_read));
95 }
96 memcpy(buffer_.data() + count_in_buffer_, buf, count_read);
97 buffer_offset_ = count_in_buffer_;
98 count_backedup_ = 0;
99 count_in_buffer_ += count_read;
100 position_ = position_ + count_read;
101 *data = buffer_.data() + buffer_offset_;
102 return count_read;
103 }
104
BackUp(int count)105 void BufferedInputStream::BackUp(int count) {
106 if (direct_access_) {
107 input_stream_->BackUp(count);
108 return;
109 }
110 if (!status_.ok() || count < 1 ||
111 count_backedup_ == (count_in_buffer_ - buffer_offset_)) {
112 return;
113 }
114 int actual_count = std::min(
115 count, count_in_buffer_ - buffer_offset_ - count_backedup_);
116 count_backedup_ += actual_count;
117 position_ = position_ - actual_count;
118 }
119
DisableRewinding()120 void BufferedInputStream::DisableRewinding() {
121 rewinding_enabled_ = false;
122 }
123
Rewind()124 crypto::tink::util::Status BufferedInputStream::Rewind() {
125 if (!rewinding_enabled_) {
126 return util::Status(absl::StatusCode::kInvalidArgument,
127 "rewinding is disabled");
128 }
129 if (status_.ok() || status_.code() == absl::StatusCode::kOutOfRange) {
130 status_ = util::OkStatus();
131 position_ = 0;
132 count_backedup_ = 0;
133 buffer_offset_ = 0;
134 after_rewind_ = true;
135 }
136 return status_;
137 }
138
// Defaulted out of line; member cleanup (including input_stream_) is
// handled by the members' own destructors.
BufferedInputStream::~BufferedInputStream() = default;
140
Position() const141 int64_t BufferedInputStream::Position() const {
142 if (direct_access_) return input_stream_->Position();
143 return position_;
144 }
145
146 } // namespace streamingaead
147 } // namespace tink
148 } // namespace crypto
149