1 //
2 //
3 // Copyright 2020 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18
19 #include "src/core/ext/xds/certificate_provider_store.h"
20
21 #include <algorithm>
22 #include <memory>
23 #include <thread>
24 #include <vector>
25
26 #include "gtest/gtest.h"
27
28 #include <grpc/grpc.h>
29 #include <grpc/support/log.h>
30
31 #include "src/core/lib/config/core_configuration.h"
32 #include "src/core/lib/gprpp/unique_type_name.h"
33 #include "test/core/util/test_config.h"
34
35 namespace grpc_core {
36 namespace testing {
37 namespace {
38
39 class CertificateProviderStoreTest : public ::testing::Test {
40 public:
CertificateProviderStoreTest()41 CertificateProviderStoreTest() { grpc_init(); }
42
~CertificateProviderStoreTest()43 ~CertificateProviderStoreTest() override { grpc_shutdown_blocking(); }
44 };
45
46 class FakeCertificateProvider : public grpc_tls_certificate_provider {
47 public:
distributor() const48 RefCountedPtr<grpc_tls_certificate_distributor> distributor() const override {
49 // never called
50 GPR_ASSERT(0);
51 return nullptr;
52 }
53
type() const54 UniqueTypeName type() const override {
55 static UniqueTypeName::Factory kFactory("fake");
56 return kFactory.Create();
57 }
58
59 private:
CompareImpl(const grpc_tls_certificate_provider * other) const60 int CompareImpl(const grpc_tls_certificate_provider* other) const override {
61 // TODO(yashykt): Maybe do something better here.
62 return QsortCompare(static_cast<const grpc_tls_certificate_provider*>(this),
63 other);
64 }
65 };
66
67 class FakeCertificateProviderFactory1 : public CertificateProviderFactory {
68 public:
69 class Config : public CertificateProviderFactory::Config {
70 public:
name() const71 absl::string_view name() const override { return "fake1"; }
72
ToString() const73 std::string ToString() const override { return "{}"; }
74 };
75
name() const76 absl::string_view name() const override { return "fake1"; }
77
78 RefCountedPtr<CertificateProviderFactory::Config>
CreateCertificateProviderConfig(const Json &,const JsonArgs &,ValidationErrors *)79 CreateCertificateProviderConfig(const Json& /*config_json*/,
80 const JsonArgs& /*args*/,
81 ValidationErrors* /*errors*/) override {
82 return MakeRefCounted<Config>();
83 }
84
CreateCertificateProvider(RefCountedPtr<CertificateProviderFactory::Config>)85 RefCountedPtr<grpc_tls_certificate_provider> CreateCertificateProvider(
86 RefCountedPtr<CertificateProviderFactory::Config> /*config*/) override {
87 return MakeRefCounted<FakeCertificateProvider>();
88 }
89 };
90
91 class FakeCertificateProviderFactory2 : public CertificateProviderFactory {
92 public:
93 class Config : public CertificateProviderFactory::Config {
94 public:
name() const95 absl::string_view name() const override { return "fake2"; }
96
ToString() const97 std::string ToString() const override { return "{}"; }
98 };
99
name() const100 absl::string_view name() const override { return "fake2"; }
101
102 RefCountedPtr<CertificateProviderFactory::Config>
CreateCertificateProviderConfig(const Json &,const JsonArgs &,ValidationErrors *)103 CreateCertificateProviderConfig(const Json& /*config_json*/,
104 const JsonArgs& /*args*/,
105 ValidationErrors* /*errors*/) override {
106 return MakeRefCounted<Config>();
107 }
108
CreateCertificateProvider(RefCountedPtr<CertificateProviderFactory::Config>)109 RefCountedPtr<grpc_tls_certificate_provider> CreateCertificateProvider(
110 RefCountedPtr<CertificateProviderFactory::Config> /*config*/) override {
111 return MakeRefCounted<FakeCertificateProvider>();
112 }
113 };
114
// Checks CertificateProviderStore lookups when the plugin map names two
// factories but only one of them ("fake1") is registered.
TEST_F(CertificateProviderStoreTest, Basic) {
  // The registry takes ownership of this factory through the unique_ptr
  // built in the configuration callback below.
  auto* registered_factory = new FakeCertificateProviderFactory1;
  CoreConfiguration::RunWithSpecialConfiguration(
      [=](CoreConfiguration::Builder* builder) {
        builder->certificate_provider_registry()
            ->RegisterCertificateProviderFactory(
                std::unique_ptr<CertificateProviderFactory>(
                    registered_factory));
      },
      [=] {
        // This factory is deliberately never registered.
        auto unregistered_factory =
            std::make_unique<FakeCertificateProviderFactory2>();
        // Builds a config from an empty JSON object using `factory`.
        auto empty_config = [](CertificateProviderFactory* factory) {
          return factory->CreateCertificateProviderConfig(
              Json::FromObject({}), JsonArgs(), nullptr);
        };
        CertificateProviderStore::PluginDefinitionMap plugins = {
            {"fake_plugin_1", {"fake1", empty_config(registered_factory)}},
            {"fake_plugin_2",
             {"fake2", empty_config(unregistered_factory.get())}},
            {"fake_plugin_3", {"fake1", empty_config(registered_factory)}},
        };
        auto store =
            MakeOrphanable<CertificateProviderStore>(std::move(plugins));
        // Known plugins backed by a registered factory yield a provider.
        auto first = store->CreateOrGetCertificateProvider("fake_plugin_1");
        ASSERT_NE(first, nullptr);
        auto third = store->CreateOrGetCertificateProvider("fake_plugin_3");
        ASSERT_NE(third, nullptr);
        // A known plugin whose factory is not registered yields nothing.
        ASSERT_EQ(store->CreateOrGetCertificateProvider("fake_plugin_2"),
                  nullptr);
        // A plugin name missing from the map yields nothing either.
        ASSERT_EQ(store->CreateOrGetCertificateProvider("unknown"), nullptr);
        // Asking again returns the providers created above.
        ASSERT_EQ(store->CreateOrGetCertificateProvider("fake_plugin_1"),
                  first);
        ASSERT_EQ(store->CreateOrGetCertificateProvider("fake_plugin_3"),
                  third);
        // Drop our references before `store` goes out of scope, so the store
        // outlasts the providers it handed out.
        first.reset();
        third.reset();
      });
}
166
// Calls CreateOrGetCertificateProvider() for the same key from many threads
// at once. Every call must produce a non-null provider.
TEST_F(CertificateProviderStoreTest, Multithreaded) {
  // The registry takes ownership of this factory through the unique_ptr
  // built in the configuration callback below.
  auto* fake_factory_1 = new FakeCertificateProviderFactory1;
  CoreConfiguration::RunWithSpecialConfiguration(
      [=](CoreConfiguration::Builder* builder) {
        builder->certificate_provider_registry()
            ->RegisterCertificateProviderFactory(
                std::unique_ptr<CertificateProviderFactory>(fake_factory_1));
      },
      [=] {
        constexpr int kNumThreads = 1000;
        CertificateProviderStore::PluginDefinitionMap map = {
            {"fake_plugin_1",
             {"fake1", fake_factory_1->CreateCertificateProviderConfig(
                           Json::FromObject({}), JsonArgs(), nullptr)}}};
        auto store = MakeOrphanable<CertificateProviderStore>(std::move(map));
        // Test concurrent `CreateOrGetCertificateProvider()` with the same key.
        std::vector<std::thread> threads;
        threads.reserve(kNumThreads);
        // Distinct loop-variable names avoid the original shadowing of `i`
        // by the per-thread loop.
        for (int thread_index = 0; thread_index < kNumThreads;
             ++thread_index) {
          threads.emplace_back([&store]() {
            constexpr int kCallsPerThread = 10;
            for (int call = 0; call < kCallsPerThread; ++call) {
              ASSERT_NE(store->CreateOrGetCertificateProvider("fake_plugin_1"),
                        nullptr);
            }
          });
        }
        // `store` is captured by reference, so every worker must finish
        // before this lambda returns and destroys it.
        for (auto& thread : threads) {
          thread.join();
        }
      });
}
197
198 } // namespace
199 } // namespace testing
200 } // namespace grpc_core
201
main(int argc,char ** argv)202 int main(int argc, char** argv) {
203 ::testing::InitGoogleTest(&argc, argv);
204 grpc::testing::TestEnvironment env(&argc, argv);
205 auto result = RUN_ALL_TESTS();
206 return result;
207 }
208