1 /*
2  * Copyright 2019 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.xds.internal.security;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static org.mockito.ArgumentMatchers.any;
21 import static org.mockito.ArgumentMatchers.eq;
22 import static org.mockito.Mockito.mock;
23 import static org.mockito.Mockito.when;
24 
25 import com.google.common.collect.ImmutableSet;
26 import io.envoyproxy.envoy.config.core.v3.DataSource;
27 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
28 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
29 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate;
30 import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
31 import io.grpc.xds.Bootstrapper;
32 import io.grpc.xds.CommonBootstrapperTestUtils;
33 import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
34 import io.grpc.xds.XdsInitializationException;
35 import io.grpc.xds.internal.security.certprovider.CertProviderClientSslContextProviderFactory;
36 import io.grpc.xds.internal.security.certprovider.CertificateProvider;
37 import io.grpc.xds.internal.security.certprovider.CertificateProviderProvider;
38 import io.grpc.xds.internal.security.certprovider.CertificateProviderRegistry;
39 import io.grpc.xds.internal.security.certprovider.CertificateProviderStore;
40 import io.grpc.xds.internal.security.certprovider.TestCertificateProvider;
41 import java.io.IOException;
42 import org.junit.Assert;
43 import org.junit.Before;
44 import org.junit.Test;
45 import org.junit.runner.RunWith;
46 import org.junit.runners.JUnit4;
47 import org.mockito.invocation.InvocationOnMock;
48 import org.mockito.stubbing.Answer;
49 
50 /** Unit tests for {@link ClientSslContextProviderFactory}. */
51 @RunWith(JUnit4.class)
52 public class ClientSslContextProviderFactoryTest {
53 
54   CertificateProviderRegistry certificateProviderRegistry;
55   CertificateProviderStore certificateProviderStore;
56   CertProviderClientSslContextProviderFactory certProviderClientSslContextProviderFactory;
57   ClientSslContextProviderFactory clientSslContextProviderFactory;
58 
59   @Before
setUp()60   public void setUp() {
61     certificateProviderRegistry = new CertificateProviderRegistry();
62     certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry);
63     certProviderClientSslContextProviderFactory =
64         new CertProviderClientSslContextProviderFactory(certificateProviderStore);
65   }
66 
67   @Test
createCertProviderClientSslContextProvider()68   public void createCertProviderClientSslContextProvider() throws XdsInitializationException {
69     final CertificateProvider.DistributorWatcher[] watcherCaptor =
70         new CertificateProvider.DistributorWatcher[1];
71     createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
72     UpstreamTlsContext upstreamTlsContext =
73         CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
74             "gcp_id",
75             "cert-default",
76             "gcp_id",
77             "root-default",
78             /* alpnProtocols= */ null,
79             /* staticCertValidationContext= */ null);
80 
81     Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
82     clientSslContextProviderFactory =
83             new ClientSslContextProviderFactory(
84                     bootstrapInfo, certProviderClientSslContextProviderFactory);
85     SslContextProvider sslContextProvider =
86         clientSslContextProviderFactory.create(upstreamTlsContext);
87     assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo(
88         "CertProviderClientSslContextProvider");
89     verifyWatcher(sslContextProvider, watcherCaptor[0]);
90     // verify that bootstrapInfo is cached...
91     sslContextProvider =
92         clientSslContextProviderFactory.create(upstreamTlsContext);
93     assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo(
94         "CertProviderClientSslContextProvider");
95   }
96 
97   @Test
bothPresent_expectCertProviderClientSslContextProvider()98   public void bothPresent_expectCertProviderClientSslContextProvider()
99       throws XdsInitializationException {
100     final CertificateProvider.DistributorWatcher[] watcherCaptor =
101         new CertificateProvider.DistributorWatcher[1];
102     createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
103     UpstreamTlsContext upstreamTlsContext =
104         CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
105             "gcp_id",
106             "cert-default",
107             "gcp_id",
108             "root-default",
109             /* alpnProtocols= */ null,
110             /* staticCertValidationContext= */ null);
111 
112     CommonTlsContext.Builder builder = upstreamTlsContext.getCommonTlsContext().toBuilder();
113     builder = addFilenames(builder, "foo.pem", "foo.key", "root.pem");
114     upstreamTlsContext = new UpstreamTlsContext(builder.build());
115 
116     Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
117     clientSslContextProviderFactory =
118             new ClientSslContextProviderFactory(
119                     bootstrapInfo, certProviderClientSslContextProviderFactory);
120     SslContextProvider sslContextProvider =
121         clientSslContextProviderFactory.create(upstreamTlsContext);
122     assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo(
123         "CertProviderClientSslContextProvider");
124     verifyWatcher(sslContextProvider, watcherCaptor[0]);
125   }
126 
127   @Test
createCertProviderClientSslContextProvider_onlyRootCert()128   public void createCertProviderClientSslContextProvider_onlyRootCert()
129       throws XdsInitializationException {
130     final CertificateProvider.DistributorWatcher[] watcherCaptor =
131             new CertificateProvider.DistributorWatcher[1];
132     createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
133     UpstreamTlsContext upstreamTlsContext =
134             CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
135                     /* certInstanceName= */ null,
136                     /* certName= */ null,
137                     "gcp_id",
138                     "root-default",
139                     /* alpnProtocols= */ null,
140                     /* staticCertValidationContext= */ null);
141 
142     Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
143     clientSslContextProviderFactory =
144             new ClientSslContextProviderFactory(
145                     bootstrapInfo, certProviderClientSslContextProviderFactory);
146     SslContextProvider sslContextProvider =
147             clientSslContextProviderFactory.create(upstreamTlsContext);
148     assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo(
149         "CertProviderClientSslContextProvider");
150     verifyWatcher(sslContextProvider, watcherCaptor[0]);
151   }
152 
153   @Test
createCertProviderClientSslContextProvider_withStaticContext()154   public void createCertProviderClientSslContextProvider_withStaticContext()
155       throws XdsInitializationException {
156     final CertificateProvider.DistributorWatcher[] watcherCaptor =
157             new CertificateProvider.DistributorWatcher[1];
158     createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
159     @SuppressWarnings("deprecation")
160     CertificateValidationContext staticCertValidationContext =
161         CertificateValidationContext.newBuilder()
162             .addAllMatchSubjectAltNames(
163                 ImmutableSet.of(
164                     StringMatcher.newBuilder().setExact("foo").build(),
165                     StringMatcher.newBuilder().setExact("bar").build()))
166             .build();
167     UpstreamTlsContext upstreamTlsContext =
168             CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
169                     /* certInstanceName= */ null,
170                     /* certName= */ null,
171                     "gcp_id",
172                     "root-default",
173                     /* alpnProtocols= */ null,
174                     staticCertValidationContext);
175 
176     Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
177     clientSslContextProviderFactory =
178             new ClientSslContextProviderFactory(bootstrapInfo,
179                     certProviderClientSslContextProviderFactory);
180     SslContextProvider sslContextProvider =
181             clientSslContextProviderFactory.create(upstreamTlsContext);
182     assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo(
183         "CertProviderClientSslContextProvider");
184     verifyWatcher(sslContextProvider, watcherCaptor[0]);
185   }
186 
187   @Test
createCertProviderClientSslContextProvider_2providers()188   public void createCertProviderClientSslContextProvider_2providers()
189       throws XdsInitializationException {
190     final CertificateProvider.DistributorWatcher[] watcherCaptor =
191         new CertificateProvider.DistributorWatcher[2];
192     createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
193 
194     createAndRegisterProviderProvider(
195         certificateProviderRegistry, watcherCaptor, "file_watcher", 1);
196 
197     UpstreamTlsContext upstreamTlsContext =
198         CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
199             "gcp_id",
200             "cert-default",
201             "file_provider",
202             "root-default",
203             /* alpnProtocols= */ null,
204             /* staticCertValidationContext= */ null);
205 
206     Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
207     clientSslContextProviderFactory =
208             new ClientSslContextProviderFactory(
209                     bootstrapInfo, certProviderClientSslContextProviderFactory);
210     SslContextProvider sslContextProvider =
211         clientSslContextProviderFactory.create(upstreamTlsContext);
212     assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo(
213         "CertProviderClientSslContextProvider");
214     verifyWatcher(sslContextProvider, watcherCaptor[0]);
215     verifyWatcher(sslContextProvider, watcherCaptor[1]);
216   }
217 
218   @Test
createNewCertProviderClientSslContextProvider_withSans()219   public void createNewCertProviderClientSslContextProvider_withSans() {
220     final CertificateProvider.DistributorWatcher[] watcherCaptor =
221         new CertificateProvider.DistributorWatcher[2];
222     createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
223     createAndRegisterProviderProvider(
224         certificateProviderRegistry, watcherCaptor, "file_watcher", 1);
225 
226     @SuppressWarnings("deprecation")
227     CertificateValidationContext staticCertValidationContext =
228         CertificateValidationContext.newBuilder()
229             .addAllMatchSubjectAltNames(
230                 ImmutableSet.of(
231                     StringMatcher.newBuilder().setExact("foo").build(),
232                     StringMatcher.newBuilder().setExact("bar").build()))
233             .build();
234     UpstreamTlsContext upstreamTlsContext =
235         CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance(
236             "gcp_id",
237             "cert-default",
238             "file_provider",
239             "root-default",
240             /* alpnProtocols= */ null,
241             staticCertValidationContext);
242 
243     Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
244     clientSslContextProviderFactory =
245         new ClientSslContextProviderFactory(
246             bootstrapInfo, certProviderClientSslContextProviderFactory);
247     SslContextProvider sslContextProvider =
248         clientSslContextProviderFactory.create(upstreamTlsContext);
249     assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo(
250         "CertProviderClientSslContextProvider");
251     verifyWatcher(sslContextProvider, watcherCaptor[0]);
252     verifyWatcher(sslContextProvider, watcherCaptor[1]);
253   }
254 
255   @Test
createNewCertProviderClientSslContextProvider_onlyRootCert()256   public void createNewCertProviderClientSslContextProvider_onlyRootCert() {
257     final CertificateProvider.DistributorWatcher[] watcherCaptor =
258         new CertificateProvider.DistributorWatcher[1];
259     createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
260     @SuppressWarnings("deprecation")
261     CertificateValidationContext staticCertValidationContext =
262         CertificateValidationContext.newBuilder()
263             .addAllMatchSubjectAltNames(
264                 ImmutableSet.of(
265                     StringMatcher.newBuilder().setExact("foo").build(),
266                     StringMatcher.newBuilder().setExact("bar").build()))
267             .build();
268     UpstreamTlsContext upstreamTlsContext =
269         CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance(
270             /* certInstanceName= */ null,
271             /* certName= */ null,
272             "gcp_id",
273             "root-default",
274             /* alpnProtocols= */ null,
275             staticCertValidationContext);
276 
277     Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
278     clientSslContextProviderFactory =
279         new ClientSslContextProviderFactory(
280             bootstrapInfo, certProviderClientSslContextProviderFactory);
281     SslContextProvider sslContextProvider =
282         clientSslContextProviderFactory.create(upstreamTlsContext);
283     assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo(
284         "CertProviderClientSslContextProvider");
285     verifyWatcher(sslContextProvider, watcherCaptor[0]);
286   }
287 
288   @Test
createNullCommonTlsContext_exception()289   public void createNullCommonTlsContext_exception() throws IOException {
290     clientSslContextProviderFactory =
291             new ClientSslContextProviderFactory(
292                     null, certProviderClientSslContextProviderFactory);
293     UpstreamTlsContext upstreamTlsContext = new UpstreamTlsContext(null);
294     try {
295       clientSslContextProviderFactory.create(upstreamTlsContext);
296       Assert.fail("no exception thrown");
297     } catch (NullPointerException expected) {
298       assertThat(expected)
299               .hasMessageThat()
300               .isEqualTo("upstreamTlsContext should have CommonTlsContext");
301     }
302   }
303 
createAndRegisterProviderProvider( CertificateProviderRegistry certificateProviderRegistry, final CertificateProvider.DistributorWatcher[] watcherCaptor, String testca, final int i)304   static void createAndRegisterProviderProvider(
305       CertificateProviderRegistry certificateProviderRegistry,
306       final CertificateProvider.DistributorWatcher[] watcherCaptor,
307       String testca,
308       final int i) {
309     final CertificateProviderProvider mockProviderProviderTestCa =
310         mock(CertificateProviderProvider.class);
311     when(mockProviderProviderTestCa.getName()).thenReturn(testca);
312 
313     when(mockProviderProviderTestCa.createCertificateProvider(
314             any(Object.class), any(CertificateProvider.DistributorWatcher.class), eq(true)))
315         .thenAnswer(
316             new Answer<CertificateProvider>() {
317               @Override
318               public CertificateProvider answer(InvocationOnMock invocation) throws Throwable {
319                 Object[] args = invocation.getArguments();
320                 CertificateProvider.DistributorWatcher watcher =
321                     (CertificateProvider.DistributorWatcher) args[1];
322                 watcherCaptor[i] = watcher;
323                 return new TestCertificateProvider(
324                     watcher, true, args[0], mockProviderProviderTestCa, false);
325               }
326             });
327     certificateProviderRegistry.register(mockProviderProviderTestCa);
328   }
329 
verifyWatcher( SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor)330   static void verifyWatcher(
331       SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor) {
332     assertThat(watcherCaptor).isNotNull();
333     assertThat(watcherCaptor.getDownstreamWatchers()).hasSize(1);
334     assertThat(watcherCaptor.getDownstreamWatchers().iterator().next())
335         .isSameInstanceAs(sslContextProvider);
336   }
337 
338   @SuppressWarnings("deprecation")
addFilenames( CommonTlsContext.Builder builder, String certChain, String privateKey, String trustCa)339   static CommonTlsContext.Builder addFilenames(
340       CommonTlsContext.Builder builder, String certChain, String privateKey, String trustCa) {
341     TlsCertificate tlsCert =
342         TlsCertificate.newBuilder()
343             .setCertificateChain(DataSource.newBuilder().setFilename(certChain))
344             .setPrivateKey(DataSource.newBuilder().setFilename(privateKey))
345             .build();
346     CertificateValidationContext certContext =
347         CertificateValidationContext.newBuilder()
348             .setTrustedCa(DataSource.newBuilder().setFilename(trustCa))
349             .build();
350     CommonTlsContext.CertificateProviderInstance certificateProviderInstance =
351         builder.getValidationContextCertificateProviderInstance();
352     CommonTlsContext.CombinedCertificateValidationContext.Builder combinedBuilder =
353         CommonTlsContext.CombinedCertificateValidationContext.newBuilder();
354     combinedBuilder
355         .setDefaultValidationContext(certContext)
356         .setValidationContextCertificateProviderInstance(certificateProviderInstance);
357     return builder
358         .addTlsCertificates(tlsCert)
359         .setCombinedValidationContext(combinedBuilder.build());
360   }
361 }
362