1 // Copyright 2019 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ///////////////////////////////////////////////////////////////////////////////
16
17 #include "tink/streamingaead/buffered_input_stream.h"
18
19 #include <algorithm>
20 #include <memory>
21 #include <sstream>
22 #include <string>
23 #include <utility>
24
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27 #include "absl/memory/memory.h"
28 #include "absl/status/status.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/string_view.h"
31 #include "tink/input_stream.h"
32 #include "tink/subtle/random.h"
33 #include "tink/subtle/test_util.h"
34 #include "tink/util/istream_input_stream.h"
35 #include "tink/util/status.h"
36 #include "tink/util/test_matchers.h"
37
38 namespace crypto {
39 namespace tink {
40 namespace streamingaead {
41 namespace {
42
43 using crypto::tink::test::IsOk;
44 using crypto::tink::test::StatusIs;
45 using subtle::test::ReadFromStream;
46 using testing::HasSubstr;
47
48
// Buffer size (in bytes) used by the util::IstreamInputStream instances
// created in this test. A compile-time constant: nothing may modify it.
constexpr int kBufferSize = 4096;
50
51 // Creates an InputStream with the specified contents.
GetInputStream(absl::string_view contents)52 std::unique_ptr<InputStream> GetInputStream(absl::string_view contents) {
53 // Prepare ciphertext source stream.
54 auto string_stream =
55 absl::make_unique<std::stringstream>(std::string(contents));
56 std::unique_ptr<InputStream> input_stream(
57 absl::make_unique<util::IstreamInputStream>(
58 std::move(string_stream), kBufferSize));
59 return input_stream;
60 }
61
62 // Attempts to read 'count' bytes from 'input_stream', and writes the read
63 // bytes to 'output'.
ReadFromStream(InputStream * input_stream,int count,std::string * output)64 util::Status ReadFromStream(InputStream* input_stream, int count,
65 std::string* output) {
66 if (input_stream == nullptr || output == nullptr || count < 0) {
67 return util::Status(absl::StatusCode::kInternal,
68 "Illegal read from a stream");
69 }
70 const void* buffer;
71 output->clear();
72 int bytes_to_read = count;
73 while (bytes_to_read > 0) {
74 auto next_result = input_stream->Next(&buffer);
75 if (next_result.status().code() == absl::StatusCode::kOutOfRange) {
76 // End of stream.
77 return util::OkStatus();
78 }
79 if (!next_result.ok()) return next_result.status();
80 auto read_bytes = next_result.value();
81 auto used_bytes = std::min(read_bytes, bytes_to_read);
82 if (used_bytes > 0) {
83 output->append(
84 std::string(reinterpret_cast<const char*>(buffer), used_bytes));
85 bytes_to_read -= used_bytes;
86 if (bytes_to_read == 0) input_stream->BackUp(read_bytes - used_bytes);
87 }
88 }
89 return util::OkStatus();
90 }
91
// For each input size, one BufferedInputStream is shared by all read sizes.
// Per read size: read the input in two parts (a `read_size` prefix, then the
// rest), verify that a further read yields nothing, rewind, repeat the
// two-part read, and finally rewind again so the next read size starts at 0.
TEST(BufferedInputStreamTest, ReadingAndRewinding) {
  for (auto input_size : {0, 1, 10, 100, 1000, 10000, 100000}) {
    const std::string contents = subtle::Random::GetRandomBytes(input_size);
    auto buffered =
        absl::make_unique<BufferedInputStream>(GetInputStream(contents));
    for (auto read_size : {0, 1, 10, 123, 300}) {
      SCOPED_TRACE(absl::StrCat("input_size = ", input_size,
                                ", read_size = ", read_size));
      const auto prefix_end = std::min(read_size, input_size);
      std::string head;
      std::string tail;

      // Reads `read_size` bytes into `head`, then everything left into
      // `tail`, checking positions and contents after each step.
      auto read_in_two_parts = [&]() {
        auto s = ReadFromStream(buffered.get(), read_size, &head);
        EXPECT_THAT(s, IsOk());
        EXPECT_EQ(prefix_end, buffered->Position());
        EXPECT_EQ(contents.substr(0, read_size), head);
        s = ReadFromStream(buffered.get(), &tail);
        EXPECT_THAT(s, IsOk());
        EXPECT_EQ(input_size, buffered->Position());
        EXPECT_EQ(contents, head + tail);
      };

      read_in_two_parts();

      // At end of stream, another read succeeds but produces no bytes.
      EXPECT_THAT(ReadFromStream(buffered.get(), &tail), IsOk());
      EXPECT_EQ("", tail);

      // Go back to the start and do the two-part read once more.
      auto rewind_status = buffered->Rewind();
      EXPECT_EQ(0, buffered->Position());
      EXPECT_THAT(rewind_status, IsOk());
      read_in_two_parts();

      // Leave the stream at position 0 for the next read size.
      EXPECT_THAT(buffered->Rewind(), IsOk());
      EXPECT_EQ(0, buffered->Position());
    }
  }
}
140
// For each (input_size, read_size) pair, on a fresh stream:
//  1. read a prefix of `read_size` bytes;
//  2. call Next() once and back the whole block up again, checking that the
//     position is restored and that an additional oversized BackUp() does not
//     move it;
//  3. read the rest and check that prefix + rest equals the input;
//  4. rewind, re-read the prefix, and check that the following Next() returns
//     all remaining bytes in one buffer, then repeat the backup checks.
TEST(BufferedInputStreamTest, SingleBackup) {
  for (auto input_size : {0, 1, 10, 100, 1000, 10000, 100000}) {
    std::string contents = subtle::Random::GetRandomBytes(input_size);
    for (auto read_size : {0, 1, 10, 123, 300, 1024}) {
      SCOPED_TRACE(absl::StrCat("input_size = ", input_size,
                                ", read_size = ", read_size));
      auto input_stream = GetInputStream(contents);
      auto buf_stream = absl::make_unique<BufferedInputStream>(
          std::move(input_stream));

      // Read a part of the stream.
      std::string prefix;
      auto status = ReadFromStream(buf_stream.get(), read_size, &prefix);
      EXPECT_THAT(status, IsOk());
      EXPECT_EQ(std::min(read_size, input_size), buf_stream->Position());
      EXPECT_EQ(contents.substr(0, read_size), prefix);

      // Read the next block of the stream, and then back it up.
      const void* buf = nullptr;
      int pos = buf_stream->Position();
      auto next_result = buf_stream->Next(&buf);
      if (read_size < input_size) {
        // ASSERT rather than EXPECT: value() on an error result would crash
        // the test binary instead of reporting a failure.
        ASSERT_THAT(next_result, IsOk());
        auto next_size = next_result.value();
        EXPECT_LE(next_size, kBufferSize);
        EXPECT_EQ(pos + next_size, buf_stream->Position());
        buf_stream->BackUp(next_size);
        EXPECT_EQ(pos, buf_stream->Position());
        // Backing up more than the last Next() returned must not move the
        // position any further.
        buf_stream->BackUp(input_size);
        EXPECT_EQ(pos, buf_stream->Position());
      } else {
        EXPECT_EQ(absl::StatusCode::kOutOfRange, next_result.status().code());
      }

      // Read the rest of the input.
      std::string rest;
      status = ReadFromStream(buf_stream.get(), &rest);
      EXPECT_THAT(status, IsOk());
      EXPECT_EQ(input_size, buf_stream->Position());
      EXPECT_EQ(contents, prefix + rest);

      // Rewind and read prefix again.
      status = buf_stream->Rewind();
      EXPECT_EQ(0, buf_stream->Position());
      EXPECT_THAT(status, IsOk());
      status = ReadFromStream(buf_stream.get(), read_size, &prefix);
      EXPECT_THAT(status, IsOk());
      EXPECT_EQ(std::min(read_size, input_size), buf_stream->Position());
      EXPECT_EQ(contents.substr(0, read_size), prefix);

      // The next buffer should contain the rest of the input, if any.
      pos = buf_stream->Position();
      next_result = buf_stream->Next(&buf);
      if (read_size < input_size) {
        // See above: guard value() against an error result.
        ASSERT_THAT(next_result, IsOk());
        auto next_size = next_result.value();
        EXPECT_EQ(input_size - pos, next_size);
        EXPECT_EQ(input_size, buf_stream->Position());
        buf_stream->BackUp(next_size);
        EXPECT_EQ(pos, buf_stream->Position());
        buf_stream->BackUp(input_size);
        EXPECT_EQ(pos, buf_stream->Position());
      } else {
        EXPECT_EQ(absl::StatusCode::kOutOfRange, next_result.status().code());
      }

      // Read the rest of the input.
      status = ReadFromStream(buf_stream.get(), &rest);
      EXPECT_THAT(status, IsOk());
      EXPECT_EQ(input_size, buf_stream->Position());
      EXPECT_EQ(contents, prefix + rest);
    }
  }
}
215
// Backs up the buffer returned by a single Next() call several times. The
// expected positions treat zero and negative BackUp() sizes as no-ops. The
// test then checks that the following Next() returns exactly the backed-up
// bytes.
TEST(BufferedInputStreamTest, MultipleBackups) {
  int input_size = 70000;
  std::string contents = subtle::Random::GetRandomBytes(input_size);
  auto input_stream = GetInputStream(contents);
  auto buf_stream = absl::make_unique<BufferedInputStream>(
      std::move(input_stream));
  const void* buffer = nullptr;

  EXPECT_EQ(0, buf_stream->Position());
  auto next_result = buf_stream->Next(&buffer);
  // ASSERT rather than EXPECT: value() on an error result would crash the
  // test binary instead of reporting a failure.
  ASSERT_THAT(next_result, IsOk());
  auto next_size = next_result.value();
  EXPECT_EQ(contents.substr(0, next_size),
            std::string(static_cast<const char*>(buffer), next_size));

  // BackUp several times, but in total fewer bytes than returned by Next().
  int total_backup_size = 0;
  for (auto backup_size : {0, 1, 5, 0, 10, 100, -42, 400, 20, -100}) {
    buf_stream->BackUp(backup_size);
    total_backup_size += std::max(0, backup_size);
    EXPECT_EQ(next_size - total_backup_size, buf_stream->Position());
  }
  EXPECT_GT(next_size, total_backup_size);

  // Call Next(), it should return exactly the backed up bytes.
  next_result = buf_stream->Next(&buffer);
  ASSERT_THAT(next_result, IsOk());
  EXPECT_EQ(total_backup_size, next_result.value());
  EXPECT_EQ(next_size, buf_stream->Position());
  EXPECT_EQ(contents.substr(next_size - total_backup_size, total_backup_size),
            std::string(static_cast<const char*>(buffer), total_backup_size));
}
248
// Disables rewinding on a fresh stream and checks that Rewind() is rejected
// both before and after reading a prefix, while the whole input remains
// readable in order.
TEST(BufferedInputStreamTest, DisableRewindingInitially) {
  // Expects `stream->Rewind()` to fail because rewinding was disabled.
  auto expect_rewind_rejected = [](BufferedInputStream* stream) {
    EXPECT_THAT(stream->Rewind(),
                StatusIs(absl::StatusCode::kInvalidArgument,
                         HasSubstr("rewinding is disabled")));
  };

  for (auto input_size : {0, 10, 100, 1000, 10000}) {
    const std::string contents = subtle::Random::GetRandomBytes(input_size);
    for (auto read_size : {0, 1, 10, 123, 300, 1024}) {
      SCOPED_TRACE(absl::StrCat("input_size = ", input_size,
                                ", read_size = ", read_size));
      auto buffered =
          absl::make_unique<BufferedInputStream>(GetInputStream(contents));

      // Disabling at position 0 already makes Rewind() fail.
      EXPECT_EQ(0, buffered->Position());
      buffered->DisableRewinding();
      expect_rewind_rejected(buffered.get());

      // A prefix can still be read normally.
      std::string prefix;
      EXPECT_THAT(ReadFromStream(buffered.get(), read_size, &prefix), IsOk());
      EXPECT_EQ(std::min(read_size, input_size), buffered->Position());
      EXPECT_EQ(contents.substr(0, read_size), prefix);

      // Rewinding stays disabled once data has been consumed.
      expect_rewind_rejected(buffered.get());

      // The remainder of the input follows the prefix.
      std::string rest;
      EXPECT_THAT(ReadFromStream(buffered.get(), &rest), IsOk());
      EXPECT_EQ(input_size, buffered->Position());
      EXPECT_EQ(contents, prefix + rest);
    }
  }
}
287
// Reads a prefix, rewinds once (allowed), then disables rewinding and checks
// that a further Rewind() is rejected while the complete input can still be
// read from the start.
TEST(BufferedInputStreamTest, DisableRewindingAfterRewind) {
  for (auto input_size : {0, 10, 100, 1000, 10000}) {
    const std::string contents = subtle::Random::GetRandomBytes(input_size);
    for (auto read_size : {0, 1, 10, 123, 300, 1024}) {
      SCOPED_TRACE(absl::StrCat("input_size = ", input_size,
                                ", read_size = ", read_size));
      auto buffered =
          absl::make_unique<BufferedInputStream>(GetInputStream(contents));
      const auto prefix_end = std::min(read_size, input_size);
      const std::string expected_prefix = contents.substr(0, read_size);

      // Consume a prefix while rewinding is still allowed.
      std::string prefix;
      EXPECT_THAT(ReadFromStream(buffered.get(), read_size, &prefix), IsOk());
      EXPECT_EQ(prefix_end, buffered->Position());
      EXPECT_EQ(expected_prefix, prefix);

      // The first Rewind() succeeds; after disabling, the next one fails.
      EXPECT_THAT(buffered->Rewind(), IsOk());
      EXPECT_EQ(0, buffered->Position());
      buffered->DisableRewinding();
      EXPECT_THAT(buffered->Rewind(),
                  StatusIs(absl::StatusCode::kInvalidArgument,
                           HasSubstr("rewinding is disabled")));

      // The prefix is served again from the beginning...
      EXPECT_THAT(ReadFromStream(buffered.get(), read_size, &prefix), IsOk());
      EXPECT_EQ(prefix_end, buffered->Position());
      EXPECT_EQ(expected_prefix, prefix);

      // ...followed by the remainder of the input.
      std::string rest;
      EXPECT_THAT(ReadFromStream(buffered.get(), &rest), IsOk());
      EXPECT_EQ(input_size, buffered->Position());
      EXPECT_EQ(contents, prefix + rest);
    }
  }
}
328
329 } // namespace
330 } // namespace streamingaead
331 } // namespace tink
332 } // namespace crypto
333