xref: /aosp_15_r20/external/tink/cc/prf/prf_set_test.cc (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1 // Copyright 2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ///////////////////////////////////////////////////////////////////////////////
16 
17 #include "tink/prf/prf_set.h"
18 
19 #include <map>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27 #include "absl/memory/memory.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/string_view.h"
30 #include "tink/keyset_handle.h"
31 #include "tink/keyset_manager.h"
32 #include "tink/prf/prf_config.h"
33 #include "tink/prf/prf_key_templates.h"
34 #include "tink/util/statusor.h"
35 #include "tink/util/test_matchers.h"
36 #include "tink/util/test_util.h"
37 
38 namespace crypto {
39 namespace tink {
40 namespace {
41 
42 using ::crypto::tink::test::IsOk;
43 using ::testing::_;
44 using ::testing::Eq;
45 using ::testing::Pair;
46 using ::testing::SizeIs;
47 using ::testing::StrEq;
48 using ::testing::UnorderedElementsAre;
49 
50 class DummyPrf : public Prf {
Compute(absl::string_view input,size_t output_length) const51   util::StatusOr<std::string> Compute(absl::string_view input,
52                                       size_t output_length) const override {
53     return std::string("DummyPRF");
54   }
55 };
56 
57 class DummyPrfSet : public PrfSet {
58  public:
GetPrimaryId() const59   uint32_t GetPrimaryId() const override { return 1; }
GetPrfs() const60   const std::map<uint32_t, Prf*>& GetPrfs() const override {
61     static const std::map<uint32_t, Prf*>* prfs =
62         new std::map<uint32_t, Prf*>({{1, dummy_.get()}});
63     return *prfs;
64   }
65 
66  private:
67   std::unique_ptr<Prf> dummy_ = absl::make_unique<DummyPrf>();
68 };
69 
70 class BrokenDummyPrfSet : public PrfSet {
71  public:
GetPrimaryId() const72   uint32_t GetPrimaryId() const override { return 1; }
GetPrfs() const73   const std::map<uint32_t, Prf*>& GetPrfs() const override {
74     static const std::map<uint32_t, Prf*>* prfs =
75         new std::map<uint32_t, Prf*>();
76     return *prfs;
77   }
78 };
79 
// Checks PrfSet::ComputePrimary. It is expected to:
//   - succeed and return the primary PRF's output when the primary id has an
//     entry in GetPrfs() (DummyPrfSet);
//   - fail when it does not (BrokenDummyPrfSet).
TEST(PrfSetTest, ComputePrimary) {
  DummyPrfSet prfset;
  auto output = prfset.ComputePrimary("DummyInput", 16);
  EXPECT_TRUE(output.ok()) << output.status();
  if (output.ok()) {
    // DummyPrf always returns this exact string, so seeing it confirms that
    // ComputePrimary dispatched to the PRF registered under the primary id.
    EXPECT_THAT(output.value(), StrEq("DummyPRF"));
  }

  BrokenDummyPrfSet broken_prfset;
  auto broken_output = broken_prfset.ComputePrimary("DummyInput", 16);
  EXPECT_FALSE(broken_output.ok())
      << "Expected broken PrfSet to not be able to compute the primary PRF";
}
89 
TEST(PrfSetWrapperTest,TestPrimitivesEndToEnd)90 TEST(PrfSetWrapperTest, TestPrimitivesEndToEnd) {
91   auto status = PrfConfig::Register();
92   ASSERT_TRUE(status.ok()) << status;
93   auto keyset_manager_result =
94       KeysetManager::New(PrfKeyTemplates::HkdfSha256());
95   ASSERT_TRUE(keyset_manager_result.ok()) << keyset_manager_result.status();
96   auto keyset_manager = std::move(keyset_manager_result.value());
97   auto id_result = keyset_manager->Add(PrfKeyTemplates::HmacSha256());
98   ASSERT_TRUE(id_result.ok()) << id_result.status();
99   uint32_t hmac_sha256_id = id_result.value();
100   id_result = keyset_manager->Add(PrfKeyTemplates::HmacSha512());
101   ASSERT_TRUE(id_result.ok()) << id_result.status();
102   uint32_t hmac_sha512_id = id_result.value();
103   id_result = keyset_manager->Add(PrfKeyTemplates::AesCmac());
104   ASSERT_TRUE(id_result.ok()) << id_result.status();
105   uint32_t aes_cmac_id = id_result.value();
106   auto keyset_handle = keyset_manager->GetKeysetHandle();
107   uint32_t hkdf_id = keyset_handle->GetKeysetInfo().primary_key_id();
108   auto prf_set_result = keyset_handle->GetPrimitive<PrfSet>();
109   ASSERT_TRUE(prf_set_result.ok()) << prf_set_result.status();
110   auto prf_set = std::move(prf_set_result.value());
111   EXPECT_THAT(prf_set->GetPrimaryId(), Eq(hkdf_id));
112   auto prf_map = prf_set->GetPrfs();
113   EXPECT_THAT(prf_map, UnorderedElementsAre(Pair(Eq(hkdf_id), _),
114                                             Pair(Eq(hmac_sha256_id), _),
115                                             Pair(Eq(hmac_sha512_id), _),
116                                             Pair(Eq(aes_cmac_id), _)));
117   std::string input = "This is an input string";
118   std::string input2 = "This is a second input string";
119   std::vector<size_t> output_lengths = {15, 16, 17, 31, 32,
120                                         33, 63, 64, 65, 100};
121   for (size_t output_length : output_lengths) {
122     bool aes_cmac_ok = output_length <= 16;
123     bool hmac_sha256_ok = output_length <= 32;
124     bool hmac_sha512_ok = output_length <= 64;
125     bool hkdf_sha256_ok = output_length <= 8192;
126     std::vector<std::string> results;
127     for (auto prf : prf_map) {
128       SCOPED_TRACE(absl::StrCat("Computing prf ", prf.first,
129                                 " with output_length ", output_length));
130       bool ok = (prf.first == aes_cmac_id && aes_cmac_ok) ||
131                 (prf.first == hmac_sha256_id && hmac_sha256_ok) ||
132                 (prf.first == hmac_sha512_id && hmac_sha512_ok) ||
133                 (prf.first == hkdf_id && hkdf_sha256_ok);
134       auto output_result = prf.second->Compute(input, output_length);
135       EXPECT_THAT(output_result.ok(), Eq(ok)) << output_result.status();
136       if (!ok) {
137         continue;
138       }
139       std::string output;
140       if (output_result.ok()) {
141         output = output_result.value();
142         results.push_back(output);
143       }
144       output_result = prf.second->Compute(input2, output_length);
145       EXPECT_TRUE(output_result.ok()) << output_result.status();
146       if (output_result.ok()) {
147         results.push_back(output_result.value());
148       }
149       output_result = prf.second->Compute(input, output_length);
150       EXPECT_TRUE(output_result.ok()) << output_result.status();
151       if (output_result.ok()) {
152         EXPECT_THAT(output_result.value(), StrEq(output));
153       }
154     }
155     for (int i = 0; i < results.size(); i++) {
156       EXPECT_THAT(results[i], SizeIs(output_length));
157       EXPECT_THAT(test::ZTestUniformString(results[i]), IsOk());
158       EXPECT_THAT(test::ZTestAutocorrelationUniformString(results[i]), IsOk());
159       for (int j = i + 1; j < results.size(); j++) {
160         EXPECT_THAT(
161             test::ZTestCrosscorrelationUniformStrings(results[i], results[j]),
162             IsOk());
163       }
164     }
165   }
166 }
167 
168 }  // namespace
169 }  // namespace tink
170 }  // namespace crypto
171