xref: /aosp_15_r20/external/tink/cc/streamingaead/decrypting_random_access_stream_test.cc (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1 // Copyright 2019 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ///////////////////////////////////////////////////////////////////////////////
16 
17 #include "tink/streamingaead/decrypting_random_access_stream.h"
18 
19 #include <memory>
20 #include <sstream>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27 #include "absl/memory/memory.h"
28 #include "absl/status/status.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/string_view.h"
31 #include "tink/internal/test_random_access_stream.h"
32 #include "tink/output_stream.h"
33 #include "tink/primitive_set.h"
34 #include "tink/random_access_stream.h"
35 #include "tink/streaming_aead.h"
36 #include "tink/subtle/random.h"
37 #include "tink/subtle/test_util.h"
38 #include "tink/util/ostream_output_stream.h"
39 #include "tink/util/status.h"
40 #include "tink/util/test_matchers.h"
41 #include "tink/util/test_util.h"
42 #include "proto/tink.pb.h"
43 
44 namespace crypto {
45 namespace tink {
46 namespace streamingaead {
47 namespace {
48 
49 using crypto::tink::test::DummyStreamingAead;
50 using crypto::tink::test::IsOk;
51 using crypto::tink::test::StatusIs;
52 using google::crypto::tink::KeysetInfo;
53 using google::crypto::tink::KeyStatusType;
54 using google::crypto::tink::OutputPrefixType;
55 using subtle::test::WriteToStream;
56 using testing::HasSubstr;
57 
58 // Creates an RandomAccessStream that contains ciphertext resulting
59 // from encryption of 'pt' with 'aad' as associated data, using 'saead'.
GetCiphertextSource(StreamingAead * saead,absl::string_view pt,absl::string_view aad)60 std::unique_ptr<RandomAccessStream> GetCiphertextSource(
61     StreamingAead* saead, absl::string_view pt, absl::string_view aad) {
62   // Prepare ciphertext destination stream.
63   auto ct_stream = absl::make_unique<std::stringstream>();
64   // A reference to the ciphertext buffer.
65   auto ct_buf = ct_stream->rdbuf();
66   std::unique_ptr<OutputStream> ct_destination(
67       absl::make_unique<util::OstreamOutputStream>(std::move(ct_stream)));
68 
69   // Compute the ciphertext.
70   auto enc_stream_result =
71       saead->NewEncryptingStream(std::move(ct_destination), aad);
72   EXPECT_THAT(enc_stream_result, IsOk());
73   EXPECT_THAT(WriteToStream(enc_stream_result.value().get(), pt), IsOk());
74 
75   // Return the ciphertext as RandomAccessStream.
76   return std::make_unique<internal::TestRandomAccessStream>(ct_buf->str());
77 }
78 
// A container for the specification of a DummyStreamingAead instance to be
// created for testing: the key id under which it is added to the primitive
// set, and the name passed to the DummyStreamingAead constructor.
// Members have default initializers so a default-constructed spec is never
// left with an indeterminate key id; it remains an aggregate, so brace
// initialization such as {key_id, name} keeps working.
struct StreamingAeadSpec {
  uint32_t key_id = 0;
  std::string saead_name;
};
85 
86 // Generates a PrimitiveSet<StreamingAead> with DummyStreamingAead
87 // instances according to the specification in 'spec'.
88 // The last entry in 'spec' will be the primary primitive in the returned set.
GetTestStreamingAeadSet(const std::vector<StreamingAeadSpec> & spec)89 std::shared_ptr<PrimitiveSet<StreamingAead>> GetTestStreamingAeadSet(
90     const std::vector<StreamingAeadSpec>& spec) {
91   std::shared_ptr<PrimitiveSet<StreamingAead>> saead_set =
92       std::make_shared<PrimitiveSet<StreamingAead>>();
93   int i = 0;
94   for (auto& s : spec) {
95     KeysetInfo::KeyInfo key_info;
96     key_info.set_output_prefix_type(OutputPrefixType::RAW);
97     key_info.set_key_id(s.key_id);
98     key_info.set_status(KeyStatusType::ENABLED);
99     std::unique_ptr<StreamingAead> saead =
100         absl::make_unique<DummyStreamingAead>(s.saead_name);
101     auto entry_result = saead_set->AddPrimitive(std::move(saead), key_info);
102     EXPECT_TRUE(entry_result.ok());
103     if (i + 1 == spec.size()) {
104       EXPECT_THAT(saead_set->set_primary(entry_result.value()), IsOk());
105     }
106     i++;
107   }
108   return saead_set;
109 }
110 
// Checks that every ciphertext produced by a primitive of the set -- the
// primary one as well as the non-primary ones -- decrypts back to the
// original plaintext when the DecryptingRandomAccessStream is read in full.
TEST(DecryptingRandomAccessStreamTest, BasicDecryption) {
  uint32_t key_id_0 = 1234543;
  uint32_t key_id_1 = 726329;
  uint32_t key_id_2 = 7213743;
  std::string saead_name_0 = "streaming_aead0";
  std::string saead_name_1 = "streaming_aead1";
  std::string saead_name_2 = "streaming_aead2";

  auto saead_set = GetTestStreamingAeadSet(
      {{key_id_0, saead_name_0}, {key_id_1, saead_name_1},
       {key_id_2, saead_name_2}});

  for (int pt_size : {0, 1, 10, 100, 10000}) {
    std::string plaintext = subtle::Random::GetRandomBytes(pt_size);
    for (std::string aad : {"some_aad", "", "some other aad"}) {
      SCOPED_TRACE(absl::StrCat("pt_size = ", pt_size,
                                ", aad = '", aad, "'"));
      // Pre-compute ciphertexts. We create one ciphertext for each primitive
      // in the primitive set, so that we can test decryption with both
      // the primary primitive, and the non-primary ones.
      std::vector<std::unique_ptr<RandomAccessStream>> ciphertexts;
      for (const auto& p : *(saead_set->get_raw_primitives().value())) {
        ciphertexts.push_back(
            GetCiphertextSource(&(p->get_primitive()), plaintext, aad));
      }
      EXPECT_EQ(ciphertexts.size(), 3u);

      // Check the decryption of each of the pre-computed ciphertexts.
      for (auto& ct : ciphertexts) {
        // Wrap the primitive set and test the resulting
        // DecryptingRandomAccessStream. ASSERT (not EXPECT) so that a failed
        // New() ends the test instead of value() being called on an error.
        auto dec_stream_result =
            DecryptingRandomAccessStream::New(saead_set, std::move(ct), aad);
        ASSERT_THAT(dec_stream_result, IsOk());
        auto dec_stream = std::move(dec_stream_result.value());
        std::string decrypted;
        auto status = internal::ReadAllFromRandomAccessStream(dec_stream.get(),
                                                              decrypted);
        // Reading everything ends with an OutOfRange "EOF" status.
        EXPECT_THAT(status, StatusIs(absl::StatusCode::kOutOfRange,
                                     HasSubstr("EOF")));
        auto size_result = dec_stream->size();
        ASSERT_THAT(size_result, IsOk());
        EXPECT_EQ(pt_size, size_result.value());
        EXPECT_EQ(plaintext, decrypted);
      }
    }
  }
}
157 
TEST(DecryptingRandomAccessStreamTest,SelectiveDecryption)158 TEST(DecryptingRandomAccessStreamTest, SelectiveDecryption) {
159   uint32_t key_id_0 = 1234543;
160   uint32_t key_id_1 = 726329;
161   uint32_t key_id_2 = 7213743;
162   std::string saead_name_0 = "streaming_aead0";
163   std::string saead_name_1 = "streaming_aead1";
164   std::string saead_name_2 = "streaming_aead2";
165 
166   auto saead_set = GetTestStreamingAeadSet(
167       {{key_id_0, saead_name_0}, {key_id_1, saead_name_1},
168        {key_id_2, saead_name_2}});
169 
170   for (int pt_size : {5, 10, 100, 10000}) {
171     std::string plaintext = subtle::Random::GetRandomBytes(pt_size);
172     for (std::string aad : {"some_aad", "", "some other aad"}) {
173       SCOPED_TRACE(absl::StrCat("pt_size = ", pt_size,
174                                 ", aad = '", aad, "'"));
175       // Pre-compute ciphertexts. We create one ciphertext for each primitive
176       // in the primitive set, so that we can test decryption with both
177       // the primary primitive, and the non-primary ones.
178       std::vector<std::unique_ptr<RandomAccessStream>> ciphertexts;
179       for (const auto& p : *(saead_set->get_raw_primitives().value())) {
180         ciphertexts.push_back(
181             GetCiphertextSource(&(p->get_primitive()), plaintext, aad));
182       }
183       EXPECT_EQ(3, ciphertexts.size());
184 
185       // Check the decryption of each of the pre-computed ciphertexts.
186       int ct_number = 0;
187       for (auto& ct : ciphertexts) {
188         // Wrap the primitive set and test the resulting
189         // DecryptingRandomAccessStream.
190         auto dec_stream_result =
191             DecryptingRandomAccessStream::New(saead_set, std::move(ct), aad);
192         EXPECT_THAT(dec_stream_result, IsOk());
193         auto dec_stream = std::move(dec_stream_result.value());
194         for (int position : {0, 1, 2, pt_size/2, pt_size-1}) {
195           for (int chunk_size : {1, pt_size/2, pt_size}) {
196             SCOPED_TRACE(absl::StrCat("ct_number = ", ct_number,
197                                       ", position = ", position,
198                                       ", chunk_size = ", chunk_size));
199             auto buffer = std::move(util::Buffer::New(chunk_size).value());
200             util::Status status =
201                 dec_stream->PRead(position, chunk_size, buffer.get());
202             EXPECT_THAT(status,
203                         testing::AnyOf(
204                             IsOk(), StatusIs(absl::StatusCode::kOutOfRange)));
205             EXPECT_EQ(std::min(chunk_size, pt_size - position), buffer->size());
206             EXPECT_EQ(0, std::memcmp(plaintext.data() + position,
207                                      buffer->get_mem_block(), buffer->size()));
208           }
209         }
210         ct_number++;
211       }
212     }
213   }
214 }
215 
TEST(DecryptingRandomAccessStreamTest,OutOfRangeDecryption)216 TEST(DecryptingRandomAccessStreamTest, OutOfRangeDecryption) {
217   uint32_t key_id_0 = 1234543;
218   uint32_t key_id_1 = 726329;
219   uint32_t key_id_2 = 7213743;
220   std::string saead_name_0 = "streaming_aead0";
221   std::string saead_name_1 = "streaming_aead1";
222   std::string saead_name_2 = "streaming_aead2";
223 
224   auto saead_set = GetTestStreamingAeadSet(
225       {{key_id_0, saead_name_0}, {key_id_1, saead_name_1},
226        {key_id_2, saead_name_2}});
227 
228   for (int pt_size : {1, 5, 10, 100, 10000}) {
229     std::string plaintext = subtle::Random::GetRandomBytes(pt_size);
230     for (std::string aad : {"some_aad", "", "some other aad"}) {
231       SCOPED_TRACE(absl::StrCat("pt_size = ", pt_size,
232                                 ", aad = '", aad, "'"));
233       // Pre-compute ciphertexts. We create one ciphertext for each primitive
234       // in the primitive set, so that we can test decryption with both
235       // the primary primitive, and the non-primary ones.
236       std::vector<std::unique_ptr<RandomAccessStream>> ciphertexts;
237       for (const auto& p : *(saead_set->get_raw_primitives().value())) {
238         ciphertexts.push_back(
239             GetCiphertextSource(&(p->get_primitive()), plaintext, aad));
240       }
241       EXPECT_EQ(3, ciphertexts.size());
242 
243       // Check the decryption of each of the pre-computed ciphertexts.
244       int ct_number = 0;
245       for (auto& ct : ciphertexts) {
246         // Wrap the primitive set and test the resulting
247         // DecryptingRandomAccessStream.
248         auto dec_stream_result =
249             DecryptingRandomAccessStream::New(saead_set, std::move(ct), aad);
250         EXPECT_THAT(dec_stream_result, IsOk());
251         auto dec_stream = std::move(dec_stream_result.value());
252         int chunk_size = 1;
253         auto buffer = std::move(util::Buffer::New(chunk_size).value());
254         for (int position : {pt_size, pt_size + 1}) {
255           SCOPED_TRACE(absl::StrCat("ct_number = ", ct_number,
256                                     ", position = ", position));
257           // Negative chunk size.
258           auto status = dec_stream->PRead(position, -1, buffer.get());
259           EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument));
260 
261           // Negative position.
262           status = dec_stream->PRead(-1, chunk_size, buffer.get());
263           EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument));
264 
265           // Reading past EOF.
266           status = dec_stream->PRead(position, chunk_size, buffer.get());
267           EXPECT_THAT(status, StatusIs(absl::StatusCode::kOutOfRange));
268         }
269         ct_number++;
270       }
271     }
272   }
273 }
274 
// Checks that decrypting with associated data different from the one used for
// encryption fails. Creating the stream succeeds; the InvalidArgument error is
// reported when the plaintext is read.
TEST(DecryptingRandomAccessStreamTest, WrongAssociatedData) {
  uint32_t key_id_0 = 1234543;
  uint32_t key_id_1 = 726329;
  uint32_t key_id_2 = 7213743;
  std::string saead_name_0 = "streaming_aead0";
  std::string saead_name_1 = "streaming_aead1";
  std::string saead_name_2 = "streaming_aead2";

  auto saead_set = GetTestStreamingAeadSet(
      {{key_id_0, saead_name_0}, {key_id_1, saead_name_1},
       {key_id_2, saead_name_2}});

  for (int pt_size : {0, 1, 10, 100, 10000}) {
    std::string plaintext = subtle::Random::GetRandomBytes(pt_size);
    for (std::string aad : {"some_aad", "", "some other aad"}) {
      SCOPED_TRACE(absl::StrCat("pt_size = ", pt_size,
                                ", aad = '", aad, "'"));
      // Compute a ciphertext with the primary primitive.
      auto ct = GetCiphertextSource(
          &(saead_set->get_primary()->get_primitive()), plaintext, aad);
      auto dec_stream_result = DecryptingRandomAccessStream::New(
          saead_set, std::move(ct), "wrong aad");
      // ASSERT so that value() below is never called on an error result.
      ASSERT_THAT(dec_stream_result, IsOk());
      std::string decrypted;
      auto status = internal::ReadAllFromRandomAccessStream(
          dec_stream_result.value().get(), decrypted);
      EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument));
    }
  }
}
305 
// Checks that random bytes are not accepted as a ciphertext by any primitive
// of the set. Creating the stream succeeds; the InvalidArgument error is
// reported when the plaintext is read.
TEST(DecryptingRandomAccessStreamTest, WrongCiphertext) {
  uint32_t key_id_0 = 1234543;
  uint32_t key_id_1 = 726329;
  uint32_t key_id_2 = 7213743;
  std::string saead_name_0 = "streaming_aead0";
  std::string saead_name_1 = "streaming_aead1";
  std::string saead_name_2 = "streaming_aead2";

  auto saead_set = GetTestStreamingAeadSet(
      {{key_id_0, saead_name_0}, {key_id_1, saead_name_1},
       {key_id_2, saead_name_2}});

  for (int pt_size : {0, 1, 10, 100, 10000}) {
    std::string plaintext = subtle::Random::GetRandomBytes(pt_size);
    for (std::string aad : {"some_aad", "", "some other aad"}) {
      SCOPED_TRACE(absl::StrCat("pt_size = ", pt_size,
                                ", aad = '", aad, "'"));
      // Try decrypting a wrong ciphertext.
      auto wrong_ct = std::make_unique<internal::TestRandomAccessStream>(
          subtle::Random::GetRandomBytes(pt_size));
      auto dec_stream_result = DecryptingRandomAccessStream::New(
          saead_set, std::move(wrong_ct), aad);
      // ASSERT so that value() below is never called on an error result.
      ASSERT_THAT(dec_stream_result, IsOk());
      std::string decrypted;
      auto status = internal::ReadAllFromRandomAccessStream(
          dec_stream_result.value().get(), decrypted);
      EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument));
    }
  }
}
336 
// DecryptingRandomAccessStream::New() must refuse a null primitive set and
// report it as InvalidArgument.
TEST(DecryptingRandomAccessStreamTest, NullPrimitiveSet) {
  std::unique_ptr<RandomAccessStream> ciphertext =
      std::make_unique<internal::TestRandomAccessStream>(
          "some ciphertext contents");

  auto result = DecryptingRandomAccessStream::New(
      nullptr, std::move(ciphertext), "some aad");

  EXPECT_THAT(result.status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("primitives must be non-null")));
}
346 
// DecryptingRandomAccessStream::New() must refuse a null ciphertext source and
// report it as InvalidArgument.
TEST(DecryptingRandomAccessStreamTest, NullCiphertextSource) {
  constexpr uint32_t kKeyId = 1234543;
  const std::string kSaeadName = "streaming_aead";
  auto primitives = GetTestStreamingAeadSet({{kKeyId, kSaeadName}});

  auto result =
      DecryptingRandomAccessStream::New(primitives, nullptr, "some aad");

  EXPECT_THAT(result.status(),
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("ciphertext_source must be non-null")));
}
358 
359 }  // namespace
360 }  // namespace streamingaead
361 }  // namespace tink
362 }  // namespace crypto
363