1*e7b1675dSTing-Kang Chang // Copyright 2020 Google LLC
2*e7b1675dSTing-Kang Chang //
3*e7b1675dSTing-Kang Chang // Licensed under the Apache License, Version 2.0 (the "License");
4*e7b1675dSTing-Kang Chang // you may not use this file except in compliance with the License.
5*e7b1675dSTing-Kang Chang // You may obtain a copy of the License at
6*e7b1675dSTing-Kang Chang //
7*e7b1675dSTing-Kang Chang // http://www.apache.org/licenses/LICENSE-2.0
8*e7b1675dSTing-Kang Chang //
9*e7b1675dSTing-Kang Chang // Unless required by applicable law or agreed to in writing, software
10*e7b1675dSTing-Kang Chang // distributed under the License is distributed on an "AS IS" BASIS,
11*e7b1675dSTing-Kang Chang // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*e7b1675dSTing-Kang Chang // See the License for the specific language governing permissions and
13*e7b1675dSTing-Kang Chang // limitations under the License.
14*e7b1675dSTing-Kang Chang //
15*e7b1675dSTing-Kang Chang ///////////////////////////////////////////////////////////////////////////////
16*e7b1675dSTing-Kang Chang
17*e7b1675dSTing-Kang Chang #include "tink/prf/prf_set.h"
18*e7b1675dSTing-Kang Chang
19*e7b1675dSTing-Kang Chang #include <map>
20*e7b1675dSTing-Kang Chang #include <memory>
21*e7b1675dSTing-Kang Chang #include <string>
22*e7b1675dSTing-Kang Chang #include <utility>
23*e7b1675dSTing-Kang Chang #include <vector>
24*e7b1675dSTing-Kang Chang
25*e7b1675dSTing-Kang Chang #include "gmock/gmock.h"
26*e7b1675dSTing-Kang Chang #include "gtest/gtest.h"
27*e7b1675dSTing-Kang Chang #include "absl/memory/memory.h"
28*e7b1675dSTing-Kang Chang #include "absl/strings/str_cat.h"
29*e7b1675dSTing-Kang Chang #include "absl/strings/string_view.h"
30*e7b1675dSTing-Kang Chang #include "tink/keyset_handle.h"
31*e7b1675dSTing-Kang Chang #include "tink/keyset_manager.h"
32*e7b1675dSTing-Kang Chang #include "tink/prf/prf_config.h"
33*e7b1675dSTing-Kang Chang #include "tink/prf/prf_key_templates.h"
34*e7b1675dSTing-Kang Chang #include "tink/util/statusor.h"
35*e7b1675dSTing-Kang Chang #include "tink/util/test_matchers.h"
36*e7b1675dSTing-Kang Chang #include "tink/util/test_util.h"
37*e7b1675dSTing-Kang Chang
38*e7b1675dSTing-Kang Chang namespace crypto {
39*e7b1675dSTing-Kang Chang namespace tink {
40*e7b1675dSTing-Kang Chang namespace {
41*e7b1675dSTing-Kang Chang
42*e7b1675dSTing-Kang Chang using ::crypto::tink::test::IsOk;
43*e7b1675dSTing-Kang Chang using ::testing::_;
44*e7b1675dSTing-Kang Chang using ::testing::Eq;
45*e7b1675dSTing-Kang Chang using ::testing::Pair;
46*e7b1675dSTing-Kang Chang using ::testing::SizeIs;
47*e7b1675dSTing-Kang Chang using ::testing::StrEq;
48*e7b1675dSTing-Kang Chang using ::testing::UnorderedElementsAre;
49*e7b1675dSTing-Kang Chang
50*e7b1675dSTing-Kang Chang class DummyPrf : public Prf {
Compute(absl::string_view input,size_t output_length) const51*e7b1675dSTing-Kang Chang util::StatusOr<std::string> Compute(absl::string_view input,
52*e7b1675dSTing-Kang Chang size_t output_length) const override {
53*e7b1675dSTing-Kang Chang return std::string("DummyPRF");
54*e7b1675dSTing-Kang Chang }
55*e7b1675dSTing-Kang Chang };
56*e7b1675dSTing-Kang Chang
57*e7b1675dSTing-Kang Chang class DummyPrfSet : public PrfSet {
58*e7b1675dSTing-Kang Chang public:
GetPrimaryId() const59*e7b1675dSTing-Kang Chang uint32_t GetPrimaryId() const override { return 1; }
GetPrfs() const60*e7b1675dSTing-Kang Chang const std::map<uint32_t, Prf*>& GetPrfs() const override {
61*e7b1675dSTing-Kang Chang static const std::map<uint32_t, Prf*>* prfs =
62*e7b1675dSTing-Kang Chang new std::map<uint32_t, Prf*>({{1, dummy_.get()}});
63*e7b1675dSTing-Kang Chang return *prfs;
64*e7b1675dSTing-Kang Chang }
65*e7b1675dSTing-Kang Chang
66*e7b1675dSTing-Kang Chang private:
67*e7b1675dSTing-Kang Chang std::unique_ptr<Prf> dummy_ = absl::make_unique<DummyPrf>();
68*e7b1675dSTing-Kang Chang };
69*e7b1675dSTing-Kang Chang
70*e7b1675dSTing-Kang Chang class BrokenDummyPrfSet : public PrfSet {
71*e7b1675dSTing-Kang Chang public:
GetPrimaryId() const72*e7b1675dSTing-Kang Chang uint32_t GetPrimaryId() const override { return 1; }
GetPrfs() const73*e7b1675dSTing-Kang Chang const std::map<uint32_t, Prf*>& GetPrfs() const override {
74*e7b1675dSTing-Kang Chang static const std::map<uint32_t, Prf*>* prfs =
75*e7b1675dSTing-Kang Chang new std::map<uint32_t, Prf*>();
76*e7b1675dSTing-Kang Chang return *prfs;
77*e7b1675dSTing-Kang Chang }
78*e7b1675dSTing-Kang Chang };
79*e7b1675dSTing-Kang Chang
// ComputePrimary() must succeed when the primary id maps to a Prf
// (DummyPrfSet) and fail when it has no entry in GetPrfs()
// (BrokenDummyPrfSet).
TEST(PrfSetTest, ComputePrimary) {
  DummyPrfSet working_set;
  const auto primary_output = working_set.ComputePrimary("DummyInput", 16);
  EXPECT_TRUE(primary_output.ok()) << primary_output.status();

  BrokenDummyPrfSet broken_set;
  const auto missing_primary_output =
      broken_set.ComputePrimary("DummyInput", 16);
  EXPECT_FALSE(missing_primary_output.ok())
      << "Expected broken PrfSet to not be able to compute the primary PRF";
}
89*e7b1675dSTing-Kang Chang
// End-to-end check of the PrfSet wrapper. It builds a keyset with one key per
// PRF template: HKDF-SHA256 (from KeysetManager::New), HMAC-SHA256,
// HMAC-SHA512 and AES-CMAC. It then obtains a PrfSet primitive and, for a
// range of output lengths, verifies that:
//   * each PRF accepts exactly the lengths this test expects it to support,
//   * outputs are deterministic for a fixed input and have the requested
//     length, and
//   * outputs pass the uniformity / autocorrelation / cross-correlation
//     Z-tests from test_util.
TEST(PrfSetWrapperTest, TestPrimitivesEndToEnd) {
  auto status = PrfConfig::Register();
  ASSERT_TRUE(status.ok()) << status;
  auto keyset_manager_result =
      KeysetManager::New(PrfKeyTemplates::HkdfSha256());
  ASSERT_TRUE(keyset_manager_result.ok()) << keyset_manager_result.status();
  auto keyset_manager = std::move(keyset_manager_result.value());
  auto id_result = keyset_manager->Add(PrfKeyTemplates::HmacSha256());
  ASSERT_TRUE(id_result.ok()) << id_result.status();
  uint32_t hmac_sha256_id = id_result.value();
  id_result = keyset_manager->Add(PrfKeyTemplates::HmacSha512());
  ASSERT_TRUE(id_result.ok()) << id_result.status();
  uint32_t hmac_sha512_id = id_result.value();
  id_result = keyset_manager->Add(PrfKeyTemplates::AesCmac());
  ASSERT_TRUE(id_result.ok()) << id_result.status();
  uint32_t aes_cmac_id = id_result.value();
  auto keyset_handle = keyset_manager->GetKeysetHandle();
  // The test relies on the key created by KeysetManager::New() (HKDF-SHA256)
  // still being primary after the Add() calls above.
  uint32_t hkdf_id = keyset_handle->GetKeysetInfo().primary_key_id();
  auto prf_set_result = keyset_handle->GetPrimitive<PrfSet>();
  ASSERT_TRUE(prf_set_result.ok()) << prf_set_result.status();
  auto prf_set = std::move(prf_set_result.value());
  EXPECT_THAT(prf_set->GetPrimaryId(), Eq(hkdf_id));
  // Bind by reference instead of copying: prf_set owns the map and outlives
  // every use below.
  const auto& prf_map = prf_set->GetPrfs();
  EXPECT_THAT(prf_map, UnorderedElementsAre(Pair(Eq(hkdf_id), _),
                                            Pair(Eq(hmac_sha256_id), _),
                                            Pair(Eq(hmac_sha512_id), _),
                                            Pair(Eq(aes_cmac_id), _)));
  const std::string input = "This is an input string";
  const std::string input2 = "This is a second input string";
  const std::vector<size_t> output_lengths = {15, 16, 17, 31, 32,
                                              33, 63, 64, 65, 100};
  for (size_t output_length : output_lengths) {
    // Whether each PRF is expected to accept this output length.
    bool aes_cmac_ok = output_length <= 16;
    bool hmac_sha256_ok = output_length <= 32;
    bool hmac_sha512_ok = output_length <= 64;
    bool hkdf_sha256_ok = output_length <= 8192;
    std::vector<std::string> results;
    for (const auto& id_and_prf : prf_map) {
      const uint32_t prf_id = id_and_prf.first;
      Prf* prf = id_and_prf.second;
      SCOPED_TRACE(absl::StrCat("Computing prf ", prf_id,
                                " with output_length ", output_length));
      bool ok = (prf_id == aes_cmac_id && aes_cmac_ok) ||
                (prf_id == hmac_sha256_id && hmac_sha256_ok) ||
                (prf_id == hmac_sha512_id && hmac_sha512_ok) ||
                (prf_id == hkdf_id && hkdf_sha256_ok);
      auto output_result = prf->Compute(input, output_length);
      EXPECT_THAT(output_result.ok(), Eq(ok)) << output_result.status();
      // Lengths this PRF must reject need no further checks. If a supported
      // length unexpectedly failed, that is already reported above, and
      // comparing later outputs against an empty string would only add a
      // spurious second failure.
      if (!ok || !output_result.ok()) {
        continue;
      }
      const std::string output = output_result.value();
      results.push_back(output);
      output_result = prf->Compute(input2, output_length);
      EXPECT_TRUE(output_result.ok()) << output_result.status();
      if (output_result.ok()) {
        results.push_back(output_result.value());
      }
      // Recomputing on the same input must reproduce the first output.
      output_result = prf->Compute(input, output_length);
      EXPECT_TRUE(output_result.ok()) << output_result.status();
      if (output_result.ok()) {
        EXPECT_THAT(output_result.value(), StrEq(output));
      }
    }
    // Length and statistical sanity checks on every collected output, plus
    // pairwise cross-correlation between distinct outputs.
    for (size_t i = 0; i < results.size(); i++) {
      EXPECT_THAT(results[i], SizeIs(output_length));
      EXPECT_THAT(test::ZTestUniformString(results[i]), IsOk());
      EXPECT_THAT(test::ZTestAutocorrelationUniformString(results[i]), IsOk());
      for (size_t j = i + 1; j < results.size(); j++) {
        EXPECT_THAT(
            test::ZTestCrosscorrelationUniformStrings(results[i], results[j]),
            IsOk());
      }
    }
  }
}
167*e7b1675dSTing-Kang Chang
168*e7b1675dSTing-Kang Chang } // namespace
169*e7b1675dSTing-Kang Chang } // namespace tink
170*e7b1675dSTing-Kang Chang } // namespace crypto
171