xref: /aosp_15_r20/external/tink/cc/streamingaead/buffered_input_stream_test.cc (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1 // Copyright 2019 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ///////////////////////////////////////////////////////////////////////////////
16 
17 #include "tink/streamingaead/buffered_input_stream.h"
18 
19 #include <algorithm>
20 #include <memory>
21 #include <sstream>
22 #include <string>
23 #include <utility>
24 
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27 #include "absl/memory/memory.h"
28 #include "absl/status/status.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/string_view.h"
31 #include "tink/input_stream.h"
32 #include "tink/subtle/random.h"
33 #include "tink/subtle/test_util.h"
34 #include "tink/util/istream_input_stream.h"
35 #include "tink/util/status.h"
36 #include "tink/util/test_matchers.h"
37 
38 namespace crypto {
39 namespace tink {
40 namespace streamingaead {
41 namespace {
42 
43 using crypto::tink::test::IsOk;
44 using crypto::tink::test::StatusIs;
45 using subtle::test::ReadFromStream;
46 using testing::HasSubstr;
47 
48 
// Buffer size handed to util::IstreamInputStream by GetInputStream(); the
// SingleBackup test also uses it as the upper bound on a single Next() chunk.
// constexpr: this is a constant, not mutable global state (and `static` is
// redundant inside the anonymous namespace).
constexpr int kBufferSize = 4096;
50 
51 // Creates an InputStream with the specified contents.
GetInputStream(absl::string_view contents)52 std::unique_ptr<InputStream> GetInputStream(absl::string_view contents) {
53   // Prepare ciphertext source stream.
54   auto string_stream =
55       absl::make_unique<std::stringstream>(std::string(contents));
56   std::unique_ptr<InputStream> input_stream(
57       absl::make_unique<util::IstreamInputStream>(
58           std::move(string_stream), kBufferSize));
59   return input_stream;
60 }
61 
62 // Attempts to read 'count' bytes from 'input_stream', and writes the read
63 // bytes to 'output'.
ReadFromStream(InputStream * input_stream,int count,std::string * output)64 util::Status ReadFromStream(InputStream* input_stream, int count,
65                             std::string* output) {
66   if (input_stream == nullptr || output == nullptr || count < 0) {
67     return util::Status(absl::StatusCode::kInternal,
68                         "Illegal read from a stream");
69   }
70   const void* buffer;
71   output->clear();
72   int bytes_to_read = count;
73   while (bytes_to_read > 0) {
74     auto next_result = input_stream->Next(&buffer);
75     if (next_result.status().code() == absl::StatusCode::kOutOfRange) {
76       // End of stream.
77       return util::OkStatus();
78     }
79     if (!next_result.ok()) return next_result.status();
80     auto read_bytes = next_result.value();
81     auto used_bytes = std::min(read_bytes, bytes_to_read);
82     if (used_bytes > 0) {
83       output->append(
84           std::string(reinterpret_cast<const char*>(buffer), used_bytes));
85       bytes_to_read -= used_bytes;
86       if (bytes_to_read == 0) input_stream->BackUp(read_bytes - used_bytes);
87     }
88   }
89   return util::OkStatus();
90 }
91 
// A single BufferedInputStream per input is reused for every prefix size.
// Each iteration does the following:
//   - reads a prefix of `read_size` bytes, then the remainder;
//   - checks that another read yields nothing;
//   - rewinds and repeats the prefix + remainder read;
//   - rewinds once more so the next iteration starts at offset 0.
TEST(BufferedInputStreamTest, ReadingAndRewinding) {
  for (int input_size : {0, 1, 10, 100, 1000, 10000, 100000}) {
    const std::string data = subtle::Random::GetRandomBytes(input_size);
    auto buffered =
        absl::make_unique<BufferedInputStream>(GetInputStream(data));
    for (int read_size : {0, 1, 10, 123, 300}) {
      SCOPED_TRACE(absl::StrCat("input_size = ", input_size,
                                ", read_size = ", read_size));
      const int prefix_len = std::min(read_size, input_size);
      const std::string expected_head = data.substr(0, read_size);
      std::string head;
      std::string tail;

      // Pass 1: the prefix, then everything that is left.
      EXPECT_THAT(ReadFromStream(buffered.get(), read_size, &head), IsOk());
      EXPECT_EQ(prefix_len, buffered->Position());
      EXPECT_EQ(expected_head, head);
      EXPECT_THAT(ReadFromStream(buffered.get(), &tail), IsOk());
      EXPECT_EQ(input_size, buffered->Position());
      EXPECT_EQ(data, head + tail);

      // The stream is exhausted, so one more read must produce nothing.
      EXPECT_THAT(ReadFromStream(buffered.get(), &tail), IsOk());
      EXPECT_EQ("", tail);

      // Pass 2: rewind to the start and repeat the two-part read.
      auto rewind_status = buffered->Rewind();
      EXPECT_EQ(0, buffered->Position());
      EXPECT_THAT(rewind_status, IsOk());
      EXPECT_THAT(ReadFromStream(buffered.get(), read_size, &head), IsOk());
      EXPECT_EQ(prefix_len, buffered->Position());
      EXPECT_EQ(expected_head, head);
      EXPECT_THAT(ReadFromStream(buffered.get(), &tail), IsOk());
      EXPECT_EQ(input_size, buffered->Position());
      EXPECT_EQ(data, head + tail);

      // Reset for the next `read_size`.
      rewind_status = buffered->Rewind();
      EXPECT_THAT(rewind_status, IsOk());
      EXPECT_EQ(0, buffered->Position());
    }
  }
}
140 
TEST(BufferedInputStreamTest,SingleBackup)141 TEST(BufferedInputStreamTest, SingleBackup) {
142   for (auto input_size : {0, 1, 10, 100, 1000, 10000, 100000}) {
143     std::string contents = subtle::Random::GetRandomBytes(input_size);
144     for (auto read_size : {0, 1, 10, 123, 300, 1024}) {
145       SCOPED_TRACE(absl::StrCat("input_size = ", input_size,
146                                 ", read_size = ", read_size));
147       auto input_stream = GetInputStream(contents);
148       auto buf_stream = absl::make_unique<BufferedInputStream>(
149           std::move(input_stream));
150 
151       // Read a part of the stream.
152       std::string prefix;
153       auto status = ReadFromStream(buf_stream.get(), read_size, &prefix);
154       EXPECT_THAT(status, IsOk());
155       EXPECT_EQ(std::min(read_size, input_size), buf_stream->Position());
156       EXPECT_EQ(contents.substr(0, read_size), prefix);
157 
158       // Read the next block of the stream, and then back it up.
159       const void* buf;
160       int pos = buf_stream->Position();
161       auto next_result = buf_stream->Next(&buf);
162       if (read_size < input_size) {
163         EXPECT_THAT(next_result, IsOk());
164         auto next_size = next_result.value();
165         EXPECT_LE(next_size, kBufferSize);
166         EXPECT_EQ(pos + next_size, buf_stream->Position());
167         buf_stream->BackUp(next_size);
168         EXPECT_EQ(pos, buf_stream->Position());
169         buf_stream->BackUp(input_size);
170         EXPECT_EQ(pos, buf_stream->Position());
171       } else {
172         EXPECT_EQ(absl::StatusCode::kOutOfRange, next_result.status().code());
173       }
174 
175       // Read the rest of the input.
176       std::string rest;
177       status = ReadFromStream(buf_stream.get(), &rest);
178       EXPECT_THAT(status, IsOk());
179       EXPECT_EQ(input_size, buf_stream->Position());
180       EXPECT_EQ(contents, prefix + rest);
181 
182       // Rewind and read prefix again.
183       status = buf_stream->Rewind();
184       EXPECT_EQ(0, buf_stream->Position());
185       EXPECT_THAT(status, IsOk());
186       status = ReadFromStream(buf_stream.get(), read_size, &prefix);
187       EXPECT_THAT(status, IsOk());
188       EXPECT_EQ(std::min(read_size, input_size), buf_stream->Position());
189       EXPECT_EQ(contents.substr(0, read_size), prefix);
190 
191       // The next buffer should contain the rest of the input, if any.
192       pos = buf_stream->Position();
193       next_result = buf_stream->Next(&buf);
194       if (read_size < input_size) {
195         EXPECT_THAT(next_result, IsOk());
196         auto next_size = next_result.value();
197         EXPECT_EQ(input_size - pos, next_size);
198         EXPECT_EQ(input_size, buf_stream->Position());
199         buf_stream->BackUp(next_size);
200         EXPECT_EQ(pos, buf_stream->Position());
201         buf_stream->BackUp(input_size);
202         EXPECT_EQ(pos, buf_stream->Position());
203       } else {
204         EXPECT_EQ(absl::StatusCode::kOutOfRange, next_result.status().code());
205       }
206 
207       // Read the rest of the input.
208       status = ReadFromStream(buf_stream.get(), &rest);
209       EXPECT_THAT(status, IsOk());
210       EXPECT_EQ(input_size, buf_stream->Position());
211       EXPECT_EQ(contents, prefix + rest);
212     }
213   }
214 }
215 
// The test calls Next() once, then issues several BackUp() calls whose
// positive sizes add up to less than the returned chunk. Zero and negative
// sizes are expected to leave the position unchanged. A following Next() must
// return exactly the backed-up bytes, and those bytes must match the input.
TEST(BufferedInputStreamTest, MultipleBackups) {
  int input_size = 70000;
  std::string contents = subtle::Random::GetRandomBytes(input_size);
  auto input_stream = GetInputStream(contents);
  auto buf_stream = absl::make_unique<BufferedInputStream>(
      std::move(input_stream));
  const void* buffer;

  EXPECT_EQ(0, buf_stream->Position());
  auto next_result = buf_stream->Next(&buffer);
  // ASSERT rather than EXPECT: value() must only be read from an OK result.
  ASSERT_THAT(next_result, IsOk());
  auto next_size = next_result.value();
  EXPECT_EQ(contents.substr(0, next_size),
            std::string(static_cast<const char*>(buffer), next_size));

  // BackUp several times, but in total fewer bytes than returned by Next().
  // Non-positive sizes are counted as 0, i.e. expected to be no-ops.
  int total_backup_size = 0;
  for (auto backup_size : {0, 1, 5, 0, 10, 100, -42, 400, 20, -100}) {
    buf_stream->BackUp(backup_size);
    total_backup_size += std::max(0, backup_size);
    EXPECT_EQ(next_size - total_backup_size, buf_stream->Position());
  }
  // Fatal: otherwise the substr() offset below would be negative and, once
  // converted to size_t, make substr() throw std::out_of_range.
  ASSERT_GT(next_size, total_backup_size);

  // Call Next(), it should return exactly the backed up bytes.
  next_result = buf_stream->Next(&buffer);
  ASSERT_THAT(next_result, IsOk());
  // Fatal: the comparison below reads `total_backup_size` bytes from `buffer`.
  ASSERT_EQ(total_backup_size, next_result.value());
  EXPECT_EQ(next_size, buf_stream->Position());
  EXPECT_EQ(contents.substr(next_size - total_backup_size, total_backup_size),
            std::string(static_cast<const char*>(buffer), total_backup_size));
}
248 
// Rewinding is disabled before anything is read. Rewind() must then fail with
// kInvalidArgument both immediately and after a prefix has been consumed.
// Meanwhile, plain sequential reading must still return the whole input.
TEST(BufferedInputStreamTest, DisableRewindingInitially) {
  for (auto input_size : {0, 10, 100, 1000, 10000}) {
    std::string contents = subtle::Random::GetRandomBytes(input_size);
    for (auto read_size : {0, 1, 10, 123, 300, 1024}) {
      SCOPED_TRACE(absl::StrCat("input_size = ", input_size,
                                ", read_size = ", read_size));
      auto input_stream = GetInputStream(contents);
      auto buf_stream = absl::make_unique<BufferedInputStream>(
          std::move(input_stream));

      // Disable rewinding, and attempt rewind.
      EXPECT_EQ(0, buf_stream->Position());
      buf_stream->DisableRewinding();
      auto status = buf_stream->Rewind();
      EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
                                   HasSubstr("rewinding is disabled")));

      // Read a prefix of the stream.
      std::string prefix;
      status = ReadFromStream(buf_stream.get(), read_size, &prefix);
      EXPECT_THAT(status, IsOk());
      EXPECT_EQ(std::min(read_size, input_size), buf_stream->Position());
      EXPECT_EQ(contents.substr(0, read_size), prefix);

      // Attempt rewinding again; it must still be rejected.
      status = buf_stream->Rewind();
      EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
                                   HasSubstr("rewinding is disabled")));

      // Read the rest of the input.
      std::string rest;
      status = ReadFromStream(buf_stream.get(), &rest);
      EXPECT_THAT(status, IsOk());
      EXPECT_EQ(input_size, buf_stream->Position());
      EXPECT_EQ(contents, prefix + rest);
    }
  }
}
287 
// The test reads a prefix and performs one rewind while it is still allowed.
// It then disables rewinding and checks that a further Rewind() is refused
// with kInvalidArgument. Reading must still continue from offset 0: the same
// prefix again, followed by the rest of the input.
TEST(BufferedInputStreamTest, DisableRewindingAfterRewind) {
  for (int input_size : {0, 10, 100, 1000, 10000}) {
    const std::string data = subtle::Random::GetRandomBytes(input_size);
    for (int read_size : {0, 1, 10, 123, 300, 1024}) {
      SCOPED_TRACE(absl::StrCat("input_size = ", input_size,
                                ", read_size = ", read_size));
      auto stream =
          absl::make_unique<BufferedInputStream>(GetInputStream(data));
      const int prefix_len = std::min(read_size, input_size);
      const std::string expected_head = data.substr(0, read_size);

      // Consume a prefix while rewinding is still permitted.
      std::string head;
      EXPECT_THAT(ReadFromStream(stream.get(), read_size, &head), IsOk());
      EXPECT_EQ(prefix_len, stream->Position());
      EXPECT_EQ(expected_head, head);

      // The first rewind succeeds; after disabling, another one is rejected.
      EXPECT_THAT(stream->Rewind(), IsOk());
      EXPECT_EQ(0, stream->Position());
      stream->DisableRewinding();
      EXPECT_THAT(stream->Rewind(),
                  StatusIs(absl::StatusCode::kInvalidArgument,
                           HasSubstr("rewinding is disabled")));

      // Reading restarts at offset 0: the same prefix, then the remainder.
      EXPECT_THAT(ReadFromStream(stream.get(), read_size, &head), IsOk());
      EXPECT_EQ(prefix_len, stream->Position());
      EXPECT_EQ(expected_head, head);

      std::string tail;
      EXPECT_THAT(ReadFromStream(stream.get(), &tail), IsOk());
      EXPECT_EQ(input_size, stream->Position());
      EXPECT_EQ(data, head + tail);
    }
  }
}
328 
329 }  // namespace
330 }  // namespace streamingaead
331 }  // namespace tink
332 }  // namespace crypto
333