1 /* 2 * Copyright 2019 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.xds.internal.security; 18 19 import static com.google.common.truth.Truth.assertThat; 20 import static org.mockito.ArgumentMatchers.any; 21 import static org.mockito.ArgumentMatchers.eq; 22 import static org.mockito.Mockito.mock; 23 import static org.mockito.Mockito.when; 24 25 import com.google.common.collect.ImmutableSet; 26 import io.envoyproxy.envoy.config.core.v3.DataSource; 27 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; 28 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; 29 import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate; 30 import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; 31 import io.grpc.xds.Bootstrapper; 32 import io.grpc.xds.CommonBootstrapperTestUtils; 33 import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; 34 import io.grpc.xds.XdsInitializationException; 35 import io.grpc.xds.internal.security.certprovider.CertProviderClientSslContextProviderFactory; 36 import io.grpc.xds.internal.security.certprovider.CertificateProvider; 37 import io.grpc.xds.internal.security.certprovider.CertificateProviderProvider; 38 import io.grpc.xds.internal.security.certprovider.CertificateProviderRegistry; 39 import io.grpc.xds.internal.security.certprovider.CertificateProviderStore; 40 import io.grpc.xds.internal.security.certprovider.TestCertificateProvider; 41 import java.io.IOException; 42 import org.junit.Assert; 43 import org.junit.Before; 44 import org.junit.Test; 45 import org.junit.runner.RunWith; 46 import org.junit.runners.JUnit4; 47 import org.mockito.invocation.InvocationOnMock; 48 import org.mockito.stubbing.Answer; 49 50 /** Unit tests for {@link ClientSslContextProviderFactory}. */ 51 @RunWith(JUnit4.class) 52 public class ClientSslContextProviderFactoryTest { 53 54 CertificateProviderRegistry certificateProviderRegistry; 55 CertificateProviderStore certificateProviderStore; 56 CertProviderClientSslContextProviderFactory certProviderClientSslContextProviderFactory; 57 ClientSslContextProviderFactory clientSslContextProviderFactory; 58 59 @Before setUp()60 public void setUp() { 61 certificateProviderRegistry = new CertificateProviderRegistry(); 62 certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry); 63 certProviderClientSslContextProviderFactory = 64 new CertProviderClientSslContextProviderFactory(certificateProviderStore); 65 } 66 67 @Test createCertProviderClientSslContextProvider()68 public void createCertProviderClientSslContextProvider() throws XdsInitializationException { 69 final CertificateProvider.DistributorWatcher[] watcherCaptor = 70 new CertificateProvider.DistributorWatcher[1]; 71 createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); 72 UpstreamTlsContext upstreamTlsContext = 73 CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( 74 "gcp_id", 75 "cert-default", 76 "gcp_id", 77 "root-default", 78 /* alpnProtocols= */ null, 79 /* staticCertValidationContext= */ null); 80 81 Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); 82 clientSslContextProviderFactory = 83 new ClientSslContextProviderFactory( 84 bootstrapInfo, certProviderClientSslContextProviderFactory); 85 SslContextProvider sslContextProvider = 86 clientSslContextProviderFactory.create(upstreamTlsContext); 87 assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( 88 "CertProviderClientSslContextProvider"); 89 verifyWatcher(sslContextProvider, watcherCaptor[0]); 90 // verify that bootstrapInfo is cached... 91 sslContextProvider = 92 clientSslContextProviderFactory.create(upstreamTlsContext); 93 assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( 94 "CertProviderClientSslContextProvider"); 95 } 96 97 @Test bothPresent_expectCertProviderClientSslContextProvider()98 public void bothPresent_expectCertProviderClientSslContextProvider() 99 throws XdsInitializationException { 100 final CertificateProvider.DistributorWatcher[] watcherCaptor = 101 new CertificateProvider.DistributorWatcher[1]; 102 createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); 103 UpstreamTlsContext upstreamTlsContext = 104 CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( 105 "gcp_id", 106 "cert-default", 107 "gcp_id", 108 "root-default", 109 /* alpnProtocols= */ null, 110 /* staticCertValidationContext= */ null); 111 112 CommonTlsContext.Builder builder = upstreamTlsContext.getCommonTlsContext().toBuilder(); 113 builder = addFilenames(builder, "foo.pem", "foo.key", "root.pem"); 114 upstreamTlsContext = new UpstreamTlsContext(builder.build()); 115 116 Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); 117 clientSslContextProviderFactory = 118 new ClientSslContextProviderFactory( 119 bootstrapInfo, certProviderClientSslContextProviderFactory); 120 SslContextProvider sslContextProvider = 121 clientSslContextProviderFactory.create(upstreamTlsContext); 122 assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( 123 "CertProviderClientSslContextProvider"); 124 verifyWatcher(sslContextProvider, watcherCaptor[0]); 125 } 126 127 @Test createCertProviderClientSslContextProvider_onlyRootCert()128 public void createCertProviderClientSslContextProvider_onlyRootCert() 129 throws XdsInitializationException { 130 final CertificateProvider.DistributorWatcher[] watcherCaptor = 131 new CertificateProvider.DistributorWatcher[1]; 132 createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); 133 UpstreamTlsContext upstreamTlsContext = 134 CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( 135 /* certInstanceName= */ null, 136 /* certName= */ null, 137 "gcp_id", 138 "root-default", 139 /* alpnProtocols= */ null, 140 /* staticCertValidationContext= */ null); 141 142 Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); 143 clientSslContextProviderFactory = 144 new ClientSslContextProviderFactory( 145 bootstrapInfo, certProviderClientSslContextProviderFactory); 146 SslContextProvider sslContextProvider = 147 clientSslContextProviderFactory.create(upstreamTlsContext); 148 assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( 149 "CertProviderClientSslContextProvider"); 150 verifyWatcher(sslContextProvider, watcherCaptor[0]); 151 } 152 153 @Test createCertProviderClientSslContextProvider_withStaticContext()154 public void createCertProviderClientSslContextProvider_withStaticContext() 155 throws XdsInitializationException { 156 final CertificateProvider.DistributorWatcher[] watcherCaptor = 157 new CertificateProvider.DistributorWatcher[1]; 158 createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); 159 @SuppressWarnings("deprecation") 160 CertificateValidationContext staticCertValidationContext = 161 CertificateValidationContext.newBuilder() 162 .addAllMatchSubjectAltNames( 163 ImmutableSet.of( 164 StringMatcher.newBuilder().setExact("foo").build(), 165 StringMatcher.newBuilder().setExact("bar").build())) 166 .build(); 167 UpstreamTlsContext upstreamTlsContext = 168 CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( 169 /* certInstanceName= */ null, 170 /* certName= */ null, 171 "gcp_id", 172 "root-default", 173 /* alpnProtocols= */ null, 174 staticCertValidationContext); 175 176 Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); 177 clientSslContextProviderFactory = 178 new ClientSslContextProviderFactory(bootstrapInfo, 179 certProviderClientSslContextProviderFactory); 180 SslContextProvider sslContextProvider = 181 clientSslContextProviderFactory.create(upstreamTlsContext); 182 assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( 183 "CertProviderClientSslContextProvider"); 184 verifyWatcher(sslContextProvider, watcherCaptor[0]); 185 } 186 187 @Test createCertProviderClientSslContextProvider_2providers()188 public void createCertProviderClientSslContextProvider_2providers() 189 throws XdsInitializationException { 190 final CertificateProvider.DistributorWatcher[] watcherCaptor = 191 new CertificateProvider.DistributorWatcher[2]; 192 createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); 193 194 createAndRegisterProviderProvider( 195 certificateProviderRegistry, watcherCaptor, "file_watcher", 1); 196 197 UpstreamTlsContext upstreamTlsContext = 198 CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( 199 "gcp_id", 200 "cert-default", 201 "file_provider", 202 "root-default", 203 /* alpnProtocols= */ null, 204 /* staticCertValidationContext= */ null); 205 206 Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); 207 clientSslContextProviderFactory = 208 new ClientSslContextProviderFactory( 209 bootstrapInfo, certProviderClientSslContextProviderFactory); 210 SslContextProvider sslContextProvider = 211 clientSslContextProviderFactory.create(upstreamTlsContext); 212 assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( 213 "CertProviderClientSslContextProvider"); 214 verifyWatcher(sslContextProvider, watcherCaptor[0]); 215 verifyWatcher(sslContextProvider, watcherCaptor[1]); 216 } 217 218 @Test createNewCertProviderClientSslContextProvider_withSans()219 public void createNewCertProviderClientSslContextProvider_withSans() { 220 final CertificateProvider.DistributorWatcher[] watcherCaptor = 221 new CertificateProvider.DistributorWatcher[2]; 222 createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); 223 createAndRegisterProviderProvider( 224 certificateProviderRegistry, watcherCaptor, "file_watcher", 1); 225 226 @SuppressWarnings("deprecation") 227 CertificateValidationContext staticCertValidationContext = 228 CertificateValidationContext.newBuilder() 229 .addAllMatchSubjectAltNames( 230 ImmutableSet.of( 231 StringMatcher.newBuilder().setExact("foo").build(), 232 StringMatcher.newBuilder().setExact("bar").build())) 233 .build(); 234 UpstreamTlsContext upstreamTlsContext = 235 CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( 236 "gcp_id", 237 "cert-default", 238 "file_provider", 239 "root-default", 240 /* alpnProtocols= */ null, 241 staticCertValidationContext); 242 243 Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); 244 clientSslContextProviderFactory = 245 new ClientSslContextProviderFactory( 246 bootstrapInfo, certProviderClientSslContextProviderFactory); 247 SslContextProvider sslContextProvider = 248 clientSslContextProviderFactory.create(upstreamTlsContext); 249 assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( 250 "CertProviderClientSslContextProvider"); 251 verifyWatcher(sslContextProvider, watcherCaptor[0]); 252 verifyWatcher(sslContextProvider, watcherCaptor[1]); 253 } 254 255 @Test createNewCertProviderClientSslContextProvider_onlyRootCert()256 public void createNewCertProviderClientSslContextProvider_onlyRootCert() { 257 final CertificateProvider.DistributorWatcher[] watcherCaptor = 258 new CertificateProvider.DistributorWatcher[1]; 259 createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); 260 @SuppressWarnings("deprecation") 261 CertificateValidationContext staticCertValidationContext = 262 CertificateValidationContext.newBuilder() 263 .addAllMatchSubjectAltNames( 264 ImmutableSet.of( 265 StringMatcher.newBuilder().setExact("foo").build(), 266 StringMatcher.newBuilder().setExact("bar").build())) 267 .build(); 268 UpstreamTlsContext upstreamTlsContext = 269 CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( 270 /* certInstanceName= */ null, 271 /* certName= */ null, 272 "gcp_id", 273 "root-default", 274 /* alpnProtocols= */ null, 275 staticCertValidationContext); 276 277 Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); 278 clientSslContextProviderFactory = 279 new ClientSslContextProviderFactory( 280 bootstrapInfo, certProviderClientSslContextProviderFactory); 281 SslContextProvider sslContextProvider = 282 clientSslContextProviderFactory.create(upstreamTlsContext); 283 assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( 284 "CertProviderClientSslContextProvider"); 285 verifyWatcher(sslContextProvider, watcherCaptor[0]); 286 } 287 288 @Test createNullCommonTlsContext_exception()289 public void createNullCommonTlsContext_exception() throws IOException { 290 clientSslContextProviderFactory = 291 new ClientSslContextProviderFactory( 292 null, certProviderClientSslContextProviderFactory); 293 UpstreamTlsContext upstreamTlsContext = new UpstreamTlsContext(null); 294 try { 295 clientSslContextProviderFactory.create(upstreamTlsContext); 296 Assert.fail("no exception thrown"); 297 } catch (NullPointerException expected) { 298 assertThat(expected) 299 .hasMessageThat() 300 .isEqualTo("upstreamTlsContext should have CommonTlsContext"); 301 } 302 } 303 createAndRegisterProviderProvider( CertificateProviderRegistry certificateProviderRegistry, final CertificateProvider.DistributorWatcher[] watcherCaptor, String testca, final int i)304 static void createAndRegisterProviderProvider( 305 CertificateProviderRegistry certificateProviderRegistry, 306 final CertificateProvider.DistributorWatcher[] watcherCaptor, 307 String testca, 308 final int i) { 309 final CertificateProviderProvider mockProviderProviderTestCa = 310 mock(CertificateProviderProvider.class); 311 when(mockProviderProviderTestCa.getName()).thenReturn(testca); 312 313 when(mockProviderProviderTestCa.createCertificateProvider( 314 any(Object.class), any(CertificateProvider.DistributorWatcher.class), eq(true))) 315 .thenAnswer( 316 new Answer<CertificateProvider>() { 317 @Override 318 public CertificateProvider answer(InvocationOnMock invocation) throws Throwable { 319 Object[] args = invocation.getArguments(); 320 CertificateProvider.DistributorWatcher watcher = 321 (CertificateProvider.DistributorWatcher) args[1]; 322 watcherCaptor[i] = watcher; 323 return new TestCertificateProvider( 324 watcher, true, args[0], mockProviderProviderTestCa, false); 325 } 326 }); 327 certificateProviderRegistry.register(mockProviderProviderTestCa); 328 } 329 verifyWatcher( SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor)330 static void verifyWatcher( 331 SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor) { 332 assertThat(watcherCaptor).isNotNull(); 333 assertThat(watcherCaptor.getDownstreamWatchers()).hasSize(1); 334 assertThat(watcherCaptor.getDownstreamWatchers().iterator().next()) 335 .isSameInstanceAs(sslContextProvider); 336 } 337 338 @SuppressWarnings("deprecation") addFilenames( CommonTlsContext.Builder builder, String certChain, String privateKey, String trustCa)339 static CommonTlsContext.Builder addFilenames( 340 CommonTlsContext.Builder builder, String certChain, String privateKey, String trustCa) { 341 TlsCertificate tlsCert = 342 TlsCertificate.newBuilder() 343 .setCertificateChain(DataSource.newBuilder().setFilename(certChain)) 344 .setPrivateKey(DataSource.newBuilder().setFilename(privateKey)) 345 .build(); 346 CertificateValidationContext certContext = 347 CertificateValidationContext.newBuilder() 348 .setTrustedCa(DataSource.newBuilder().setFilename(trustCa)) 349 .build(); 350 CommonTlsContext.CertificateProviderInstance certificateProviderInstance = 351 builder.getValidationContextCertificateProviderInstance(); 352 CommonTlsContext.CombinedCertificateValidationContext.Builder combinedBuilder = 353 CommonTlsContext.CombinedCertificateValidationContext.newBuilder(); 354 combinedBuilder 355 .setDefaultValidationContext(certContext) 356 .setValidationContextCertificateProviderInstance(certificateProviderInstance); 357 return builder 358 .addTlsCertificates(tlsCert) 359 .setCombinedValidationContext(combinedBuilder.build()); 360 } 361 } 362