1 // Copyright 2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ///////////////////////////////////////////////////////////////////////////////
16
17 #include "tink/prf/prf_set.h"
18
19 #include <map>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27 #include "absl/memory/memory.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/strings/string_view.h"
30 #include "tink/keyset_handle.h"
31 #include "tink/keyset_manager.h"
32 #include "tink/prf/prf_config.h"
33 #include "tink/prf/prf_key_templates.h"
34 #include "tink/util/statusor.h"
35 #include "tink/util/test_matchers.h"
36 #include "tink/util/test_util.h"
37
38 namespace crypto {
39 namespace tink {
40 namespace {
41
42 using ::crypto::tink::test::IsOk;
43 using ::testing::_;
44 using ::testing::Eq;
45 using ::testing::Pair;
46 using ::testing::SizeIs;
47 using ::testing::StrEq;
48 using ::testing::UnorderedElementsAre;
49
50 class DummyPrf : public Prf {
Compute(absl::string_view input,size_t output_length) const51 util::StatusOr<std::string> Compute(absl::string_view input,
52 size_t output_length) const override {
53 return std::string("DummyPRF");
54 }
55 };
56
57 class DummyPrfSet : public PrfSet {
58 public:
GetPrimaryId() const59 uint32_t GetPrimaryId() const override { return 1; }
GetPrfs() const60 const std::map<uint32_t, Prf*>& GetPrfs() const override {
61 static const std::map<uint32_t, Prf*>* prfs =
62 new std::map<uint32_t, Prf*>({{1, dummy_.get()}});
63 return *prfs;
64 }
65
66 private:
67 std::unique_ptr<Prf> dummy_ = absl::make_unique<DummyPrf>();
68 };
69
70 class BrokenDummyPrfSet : public PrfSet {
71 public:
GetPrimaryId() const72 uint32_t GetPrimaryId() const override { return 1; }
GetPrfs() const73 const std::map<uint32_t, Prf*>& GetPrfs() const override {
74 static const std::map<uint32_t, Prf*>* prfs =
75 new std::map<uint32_t, Prf*>();
76 return *prfs;
77 }
78 };
79
// Expects ComputePrimary() to succeed on DummyPrfSet (whose map has an entry
// for the primary ID) and to fail on BrokenDummyPrfSet (whose map is empty).
TEST(PrfSetTest, ComputePrimary) {
  DummyPrfSet working_set;
  util::StatusOr<std::string> working_result =
      working_set.ComputePrimary("DummyInput", 16);
  EXPECT_TRUE(working_result.ok()) << working_result.status();

  BrokenDummyPrfSet empty_set;
  util::StatusOr<std::string> empty_result =
      empty_set.ComputePrimary("DummyInput", 16);
  EXPECT_FALSE(empty_result.ok())
      << "Expected broken PrfSet to not be able to compute the primary PRF";
}
89
// End-to-end test of the PrfSet primitive built from a real keyset.
//
// Builds a keyset with an HKDF-SHA256 key (the one the keyset reports as
// primary) plus HMAC-SHA256, HMAC-SHA512 and AES-CMAC keys, obtains a PrfSet
// from it, and checks that:
//   * the primary ID and the set of key IDs match the keyset;
//   * each PRF accepts or rejects each requested output length according to
//     the per-algorithm limits encoded below;
//   * recomputing on the same input yields the same output;
//   * every produced output has the requested length and passes the
//     statistical uniformity / correlation checks from test_util.
TEST(PrfSetWrapperTest, TestPrimitivesEndToEnd) {
  auto status = PrfConfig::Register();
  ASSERT_TRUE(status.ok()) << status;

  // Assemble the keyset: the key from the initial template plus three more.
  auto keyset_manager_result =
      KeysetManager::New(PrfKeyTemplates::HkdfSha256());
  ASSERT_TRUE(keyset_manager_result.ok()) << keyset_manager_result.status();
  auto keyset_manager = std::move(keyset_manager_result.value());
  auto id_result = keyset_manager->Add(PrfKeyTemplates::HmacSha256());
  ASSERT_TRUE(id_result.ok()) << id_result.status();
  const uint32_t hmac_sha256_id = id_result.value();
  id_result = keyset_manager->Add(PrfKeyTemplates::HmacSha512());
  ASSERT_TRUE(id_result.ok()) << id_result.status();
  const uint32_t hmac_sha512_id = id_result.value();
  id_result = keyset_manager->Add(PrfKeyTemplates::AesCmac());
  ASSERT_TRUE(id_result.ok()) << id_result.status();
  const uint32_t aes_cmac_id = id_result.value();
  auto keyset_handle = keyset_manager->GetKeysetHandle();
  const uint32_t hkdf_id = keyset_handle->GetKeysetInfo().primary_key_id();

  auto prf_set_result = keyset_handle->GetPrimitive<PrfSet>();
  ASSERT_TRUE(prf_set_result.ok()) << prf_set_result.status();
  auto prf_set = std::move(prf_set_result.value());
  EXPECT_THAT(prf_set->GetPrimaryId(), Eq(hkdf_id));

  // GetPrfs() returns a reference; bind to it instead of copying the map.
  // prf_set outlives every use of prf_map below.
  const auto& prf_map = prf_set->GetPrfs();
  EXPECT_THAT(prf_map, UnorderedElementsAre(Pair(Eq(hkdf_id), _),
                                            Pair(Eq(hmac_sha256_id), _),
                                            Pair(Eq(hmac_sha512_id), _),
                                            Pair(Eq(aes_cmac_id), _)));

  const std::string input = "This is an input string";
  const std::string input2 = "This is a second input string";
  // Lengths straddle the 16-, 32- and 64-byte limits checked below.
  const std::vector<size_t> output_lengths = {15, 16, 17, 31, 32,
                                              33, 63, 64, 65, 100};
  for (size_t output_length : output_lengths) {
    // Whether each algorithm is expected to accept this output length.
    const bool aes_cmac_ok = output_length <= 16;
    const bool hmac_sha256_ok = output_length <= 32;
    const bool hmac_sha512_ok = output_length <= 64;
    const bool hkdf_sha256_ok = output_length <= 8192;

    // Successful outputs for this length: the first `input` result and the
    // `input2` result of each PRF that accepts the length.
    std::vector<std::string> results;
    for (const auto& prf : prf_map) {
      SCOPED_TRACE(absl::StrCat("Computing prf ", prf.first,
                                " with output_length ", output_length));
      const bool ok = (prf.first == aes_cmac_id && aes_cmac_ok) ||
                      (prf.first == hmac_sha256_id && hmac_sha256_ok) ||
                      (prf.first == hmac_sha512_id && hmac_sha512_ok) ||
                      (prf.first == hkdf_id && hkdf_sha256_ok);
      auto output_result = prf.second->Compute(input, output_length);
      EXPECT_THAT(output_result.ok(), Eq(ok)) << output_result.status();
      if (!ok) {
        // A rejection was expected; nothing further to check for this PRF.
        continue;
      }
      std::string output;
      if (output_result.ok()) {
        output = output_result.value();
        results.push_back(output);
      }
      output_result = prf.second->Compute(input2, output_length);
      EXPECT_TRUE(output_result.ok()) << output_result.status();
      if (output_result.ok()) {
        results.push_back(output_result.value());
      }
      // Recompute on the first input and compare with the first output.
      output_result = prf.second->Compute(input, output_length);
      EXPECT_TRUE(output_result.ok()) << output_result.status();
      if (output_result.ok()) {
        EXPECT_THAT(output_result.value(), StrEq(output));
      }
    }

    // size_t indices avoid a signed/unsigned comparison with results.size().
    for (size_t i = 0; i < results.size(); ++i) {
      EXPECT_THAT(results[i], SizeIs(output_length));
      EXPECT_THAT(test::ZTestUniformString(results[i]), IsOk());
      EXPECT_THAT(test::ZTestAutocorrelationUniformString(results[i]), IsOk());
      // Each unordered pair of outputs is cross-checked exactly once.
      for (size_t j = i + 1; j < results.size(); ++j) {
        EXPECT_THAT(
            test::ZTestCrosscorrelationUniformStrings(results[i], results[j]),
            IsOk());
      }
    }
  }
}
167
168 } // namespace
169 } // namespace tink
170 } // namespace crypto
171