1 /* 2 * Copyright 2016 Google LLC 3 * 4 * Redistribution and use in source and binary forms, with or without 5 * modification, are permitted provided that the following conditions are 6 * met: 7 * 8 * * Redistributions of source code must retain the above copyright 9 * notice, this list of conditions and the following disclaimer. 10 * * Redistributions in binary form must reproduce the above 11 * copyright notice, this list of conditions and the following disclaimer 12 * in the documentation and/or other materials provided with the 13 * distribution. 14 * * Neither the name of Google LLC nor the names of its 15 * contributors may be used to endorse or promote products derived from 16 * this software without specific prior written permission. 17 * 18 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 */ 30 package com.google.api.gax.grpc; 31 32 import static com.google.common.base.Preconditions.checkArgument; 33 import static com.google.common.truth.Truth.assertThat; 34 import static org.junit.Assert.assertEquals; 35 36 import com.google.api.core.ApiFunction; 37 import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder; 38 import com.google.api.gax.rpc.HeaderProvider; 39 import com.google.api.gax.rpc.TransportChannelProvider; 40 import com.google.api.gax.rpc.mtls.AbstractMtlsTransportChannelTest; 41 import com.google.api.gax.rpc.mtls.MtlsProvider; 42 import com.google.auth.oauth2.CloudShellCredentials; 43 import com.google.auth.oauth2.ComputeEngineCredentials; 44 import com.google.common.collect.ImmutableList; 45 import com.google.common.collect.ImmutableMap; 46 import io.grpc.ManagedChannel; 47 import io.grpc.ManagedChannelBuilder; 48 import io.grpc.alts.ComputeEngineChannelBuilder; 49 import java.io.IOException; 50 import java.security.GeneralSecurityException; 51 import java.util.ArrayList; 52 import java.util.Collections; 53 import java.util.HashMap; 54 import java.util.List; 55 import java.util.Map; 56 import java.util.concurrent.Executor; 57 import java.util.concurrent.ScheduledExecutorService; 58 import java.util.concurrent.ScheduledThreadPoolExecutor; 59 import javax.annotation.Nullable; 60 import org.junit.Test; 61 import org.junit.runner.RunWith; 62 import org.junit.runners.JUnit4; 63 import org.mockito.ArgumentCaptor; 64 import org.mockito.Mockito; 65 import org.threeten.bp.Duration; 66 67 @RunWith(JUnit4.class) 68 public class InstantiatingGrpcChannelProviderTest extends AbstractMtlsTransportChannelTest { 69 70 @Test testEndpoint()71 public void testEndpoint() { 72 String endpoint = "localhost:8080"; 73 InstantiatingGrpcChannelProvider.Builder builder = 74 InstantiatingGrpcChannelProvider.newBuilder(); 75 builder.setEndpoint(endpoint); 76 assertEquals(builder.getEndpoint(), endpoint); 77 78 InstantiatingGrpcChannelProvider provider = builder.build(); 79 assertEquals(provider.getEndpoint(), endpoint); 80 } 81 82 @Test(expected = IllegalArgumentException.class) testEndpointNoPort()83 public void testEndpointNoPort() { 84 InstantiatingGrpcChannelProvider.newBuilder().setEndpoint("localhost"); 85 } 86 87 @Test(expected = IllegalArgumentException.class) testEndpointBadPort()88 public void testEndpointBadPort() { 89 InstantiatingGrpcChannelProvider.newBuilder().setEndpoint("localhost:abcd"); 90 } 91 92 @Test testKeepAlive()93 public void testKeepAlive() { 94 Duration keepaliveTime = Duration.ofSeconds(1); 95 Duration keepaliveTimeout = Duration.ofSeconds(2); 96 boolean keepaliveWithoutCalls = true; 97 98 InstantiatingGrpcChannelProvider provider = 99 InstantiatingGrpcChannelProvider.newBuilder() 100 .setKeepAliveTime(keepaliveTime) 101 .setKeepAliveTimeout(keepaliveTimeout) 102 .setKeepAliveWithoutCalls(keepaliveWithoutCalls) 103 .build(); 104 105 assertEquals(provider.getKeepAliveTime(), keepaliveTime); 106 assertEquals(provider.getKeepAliveTimeout(), keepaliveTimeout); 107 assertEquals(provider.getKeepAliveWithoutCalls(), keepaliveWithoutCalls); 108 } 109 110 @Test testMaxInboundMetadataSize()111 public void testMaxInboundMetadataSize() { 112 InstantiatingGrpcChannelProvider provider = 113 InstantiatingGrpcChannelProvider.newBuilder().setMaxInboundMetadataSize(4096).build(); 114 assertThat(provider.getMaxInboundMetadataSize()).isEqualTo(4096); 115 } 116 117 @Test testCpuPoolSize()118 public void testCpuPoolSize() { 119 // happy path 120 Builder builder = InstantiatingGrpcChannelProvider.newBuilder().setProcessorCount(2); 121 builder.setChannelsPerCpu(2.5); 122 assertEquals(5, builder.getPoolSize()); 123 124 // User specified max 125 builder = builder.setProcessorCount(50); 126 builder.setChannelsPerCpu(100, 10); 127 assertEquals(10, builder.getPoolSize()); 128 129 // Sane default maximum 130 builder.setChannelsPerCpu(200); 131 assertEquals(100, builder.getPoolSize()); 132 } 133 134 @Test testWithPoolSize()135 public void testWithPoolSize() throws IOException { 136 ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1); 137 executor.shutdown(); 138 139 TransportChannelProvider provider = 140 InstantiatingGrpcChannelProvider.newBuilder() 141 .build() 142 .withExecutor((Executor) executor) 143 .withHeaders(Collections.<String, String>emptyMap()) 144 .withEndpoint("localhost:8080"); 145 assertThat(provider.acceptsPoolSize()).isTrue(); 146 147 // Make sure we can create channels OK. 148 provider.getTransportChannel().shutdownNow(); 149 150 provider = provider.withPoolSize(2); 151 provider.getTransportChannel().shutdownNow(); 152 } 153 154 @Test testToBuilder()155 public void testToBuilder() { 156 Duration keepaliveTime = Duration.ofSeconds(1); 157 Duration keepaliveTimeout = Duration.ofSeconds(2); 158 ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator = 159 builder -> { 160 throw new UnsupportedOperationException(); 161 }; 162 Map<String, ?> directPathServiceConfig = ImmutableMap.of("loadbalancingConfig", "grpclb"); 163 164 InstantiatingGrpcChannelProvider provider = 165 InstantiatingGrpcChannelProvider.newBuilder() 166 .setProcessorCount(2) 167 .setEndpoint("fake.endpoint:443") 168 .setMaxInboundMessageSize(12345678) 169 .setMaxInboundMetadataSize(4096) 170 .setKeepAliveTime(keepaliveTime) 171 .setKeepAliveTimeout(keepaliveTimeout) 172 .setKeepAliveWithoutCalls(true) 173 .setChannelConfigurator(channelConfigurator) 174 .setChannelsPerCpu(2.5) 175 .setDirectPathServiceConfig(directPathServiceConfig) 176 .build(); 177 178 InstantiatingGrpcChannelProvider.Builder builder = provider.toBuilder(); 179 180 assertThat(builder.getEndpoint()).isEqualTo("fake.endpoint:443"); 181 assertThat(builder.getMaxInboundMessageSize()).isEqualTo(12345678); 182 assertThat(builder.getMaxInboundMetadataSize()).isEqualTo(4096); 183 assertThat(builder.getKeepAliveTime()).isEqualTo(keepaliveTime); 184 assertThat(builder.getKeepAliveTimeout()).isEqualTo(keepaliveTimeout); 185 assertThat(builder.getChannelConfigurator()).isEqualTo(channelConfigurator); 186 assertThat(builder.getPoolSize()).isEqualTo(5); 187 assertThat(builder.build().directPathServiceConfig).isEqualTo(directPathServiceConfig); 188 } 189 190 @Test testWithInterceptors()191 public void testWithInterceptors() throws Exception { 192 testWithInterceptors(1); 193 } 194 195 @Test testWithInterceptorsAndMultipleChannels()196 public void testWithInterceptorsAndMultipleChannels() throws Exception { 197 testWithInterceptors(5); 198 } 199 testWithInterceptors(int numChannels)200 private void testWithInterceptors(int numChannels) throws Exception { 201 final GrpcInterceptorProvider interceptorProvider = Mockito.mock(GrpcInterceptorProvider.class); 202 203 InstantiatingGrpcChannelProvider channelProvider = 204 InstantiatingGrpcChannelProvider.newBuilder() 205 .setEndpoint("localhost:8080") 206 .setPoolSize(numChannels) 207 .setHeaderProvider(Mockito.mock(HeaderProvider.class)) 208 .setExecutor(Mockito.mock(Executor.class)) 209 .setInterceptorProvider(interceptorProvider) 210 .build(); 211 212 Mockito.verify(interceptorProvider, Mockito.never()).getInterceptors(); 213 channelProvider.getTransportChannel().shutdownNow(); 214 Mockito.verify(interceptorProvider, Mockito.times(numChannels)).getInterceptors(); 215 } 216 217 @Test testChannelConfigurator()218 public void testChannelConfigurator() throws IOException { 219 final int numChannels = 5; 220 221 // Create a mock configurator that will insert mock channels 222 @SuppressWarnings("unchecked") 223 ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator = 224 Mockito.mock(ApiFunction.class); 225 226 ArgumentCaptor<ManagedChannelBuilder<?>> channelBuilderCaptor = 227 ArgumentCaptor.forClass(ManagedChannelBuilder.class); 228 229 ManagedChannelBuilder<?> swappedBuilder = Mockito.mock(ManagedChannelBuilder.class); 230 ManagedChannel fakeChannel = Mockito.mock(ManagedChannel.class); 231 Mockito.when(swappedBuilder.build()).thenReturn(fakeChannel); 232 233 Mockito.when(channelConfigurator.apply(channelBuilderCaptor.capture())) 234 .thenReturn(swappedBuilder); 235 236 // Invoke the provider 237 InstantiatingGrpcChannelProvider.newBuilder() 238 .setEndpoint("localhost:8080") 239 .setHeaderProvider(Mockito.mock(HeaderProvider.class)) 240 .setExecutor(Mockito.mock(Executor.class)) 241 .setChannelConfigurator(channelConfigurator) 242 .setPoolSize(numChannels) 243 .build() 244 .getTransportChannel(); 245 246 // Make sure that the provider passed in a configured channel 247 assertThat(channelBuilderCaptor.getValue()).isNotNull(); 248 // And that it was replaced with the mock 249 Mockito.verify(swappedBuilder, Mockito.times(numChannels)).build(); 250 } 251 252 @Test testWithGCECredentials()253 public void testWithGCECredentials() throws IOException { 254 ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1); 255 executor.shutdown(); 256 257 TransportChannelProvider provider = 258 InstantiatingGrpcChannelProvider.newBuilder() 259 .setAttemptDirectPath(true) 260 .build() 261 .withExecutor((Executor) executor) 262 .withHeaders(Collections.<String, String>emptyMap()) 263 .withEndpoint("localhost:8080"); 264 265 assertThat(provider.needsCredentials()).isTrue(); 266 if (InstantiatingGrpcChannelProvider.isOnComputeEngine()) { 267 provider = provider.withCredentials(ComputeEngineCredentials.create()); 268 } else { 269 provider = provider.withCredentials(CloudShellCredentials.create(3000)); 270 } 271 assertThat(provider.needsCredentials()).isFalse(); 272 273 provider.getTransportChannel().shutdownNow(); 274 } 275 276 @Test testDirectPathXdsDisableByDefault()277 public void testDirectPathXdsDisableByDefault() throws IOException { 278 InstantiatingGrpcChannelProvider provider = 279 InstantiatingGrpcChannelProvider.newBuilder().setAttemptDirectPath(true).build(); 280 281 assertThat(provider.isDirectPathXdsEnabled()).isFalse(); 282 } 283 284 @Test testDirectPathXdsEnabled()285 public void testDirectPathXdsEnabled() throws IOException { 286 InstantiatingGrpcChannelProvider provider = 287 InstantiatingGrpcChannelProvider.newBuilder() 288 .setAttemptDirectPath(true) 289 .setAttemptDirectPathXds() 290 .build(); 291 292 assertThat(provider.isDirectPathXdsEnabled()).isTrue(); 293 } 294 295 @Test testWithNonGCECredentials()296 public void testWithNonGCECredentials() throws IOException { 297 ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1); 298 executor.shutdown(); 299 300 ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator = 301 channelBuilder -> { 302 // Clients with non-GCE credentials will not attempt DirectPath. 303 assertThat(channelBuilder instanceof ComputeEngineChannelBuilder).isFalse(); 304 return channelBuilder; 305 }; 306 307 TransportChannelProvider provider = 308 InstantiatingGrpcChannelProvider.newBuilder() 309 .setAttemptDirectPath(true) 310 .setChannelConfigurator(channelConfigurator) 311 .build() 312 .withExecutor((Executor) executor) 313 .withHeaders(Collections.<String, String>emptyMap()) 314 .withEndpoint("localhost:8080"); 315 316 assertThat(provider.needsCredentials()).isTrue(); 317 provider = provider.withCredentials(CloudShellCredentials.create(3000)); 318 assertThat(provider.needsCredentials()).isFalse(); 319 320 provider.getTransportChannel().shutdownNow(); 321 } 322 323 @Test testWithDirectPathDisabled()324 public void testWithDirectPathDisabled() throws IOException { 325 ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1); 326 executor.shutdown(); 327 328 ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator = 329 channelBuilder -> { 330 // Clients without setting attemptDirectPath flag will not attempt DirectPath 331 assertThat(channelBuilder instanceof ComputeEngineChannelBuilder).isFalse(); 332 return channelBuilder; 333 }; 334 335 TransportChannelProvider provider = 336 InstantiatingGrpcChannelProvider.newBuilder() 337 .setAttemptDirectPath(false) 338 .setChannelConfigurator(channelConfigurator) 339 .build() 340 .withExecutor((Executor) executor) 341 .withHeaders(Collections.<String, String>emptyMap()) 342 .withEndpoint("localhost:8080"); 343 344 assertThat(provider.needsCredentials()).isTrue(); 345 provider = provider.withCredentials(ComputeEngineCredentials.create()); 346 assertThat(provider.needsCredentials()).isFalse(); 347 348 provider.getTransportChannel().shutdownNow(); 349 } 350 351 @Test testWithNoDirectPathFlagSet()352 public void testWithNoDirectPathFlagSet() throws IOException { 353 ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1); 354 executor.shutdown(); 355 356 ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator = 357 channelBuilder -> { 358 // Clients without setting attemptDirectPath flag will not attempt DirectPath 359 assertThat(channelBuilder instanceof ComputeEngineChannelBuilder).isFalse(); 360 return channelBuilder; 361 }; 362 363 TransportChannelProvider provider = 364 InstantiatingGrpcChannelProvider.newBuilder() 365 .setChannelConfigurator(channelConfigurator) 366 .build() 367 .withExecutor((Executor) executor) 368 .withHeaders(Collections.<String, String>emptyMap()) 369 .withEndpoint("localhost:8080"); 370 371 assertThat(provider.needsCredentials()).isTrue(); 372 provider = provider.withCredentials(ComputeEngineCredentials.create()); 373 assertThat(provider.needsCredentials()).isFalse(); 374 375 provider.getTransportChannel().shutdownNow(); 376 } 377 378 @Test testWithIPv6Address()379 public void testWithIPv6Address() throws IOException { 380 ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1); 381 executor.shutdown(); 382 383 TransportChannelProvider provider = 384 InstantiatingGrpcChannelProvider.newBuilder() 385 .build() 386 .withExecutor((Executor) executor) 387 .withHeaders(Collections.<String, String>emptyMap()) 388 .withEndpoint("[::1]:8080"); 389 assertThat(provider.needsEndpoint()).isFalse(); 390 391 // Make sure we can create channels OK. 392 provider.getTransportChannel().shutdownNow(); 393 } 394 395 // Test that if ChannelPrimer is provided, it is called during creation 396 @Test testWithPrimeChannel()397 public void testWithPrimeChannel() throws IOException { 398 // create channelProvider with different pool sizes to verify ChannelPrimer is called the 399 // correct number of times 400 for (int poolSize = 1; poolSize < 5; poolSize++) { 401 final ChannelPrimer mockChannelPrimer = Mockito.mock(ChannelPrimer.class); 402 403 InstantiatingGrpcChannelProvider provider = 404 InstantiatingGrpcChannelProvider.newBuilder() 405 .setEndpoint("localhost:8080") 406 .setPoolSize(poolSize) 407 .setHeaderProvider(Mockito.mock(HeaderProvider.class)) 408 .setExecutor(Mockito.mock(Executor.class)) 409 .setChannelPrimer(mockChannelPrimer) 410 .build(); 411 412 provider.getTransportChannel().shutdownNow(); 413 414 // every channel in the pool should call primeChannel during creation. 415 Mockito.verify(mockChannelPrimer, Mockito.times(poolSize)) 416 .primeChannel(Mockito.any(ManagedChannel.class)); 417 } 418 } 419 420 @Test testWithDefaultDirectPathServiceConfig()421 public void testWithDefaultDirectPathServiceConfig() { 422 InstantiatingGrpcChannelProvider provider = 423 InstantiatingGrpcChannelProvider.newBuilder().build(); 424 425 ImmutableMap<String, ?> defaultServiceConfig = provider.directPathServiceConfig; 426 427 List<Map<String, ?>> lbConfigs = getAsObjectList(defaultServiceConfig, "loadBalancingConfig"); 428 assertThat(lbConfigs).hasSize(1); 429 Map<String, ?> lbConfig = lbConfigs.get(0); 430 Map<String, ?> grpclb = getAsObject(lbConfig, "grpclb"); 431 List<Map<String, ?>> childPolicies = getAsObjectList(grpclb, "childPolicy"); 432 assertThat(childPolicies).hasSize(1); 433 Map<String, ?> childPolicy = childPolicies.get(0); 434 assertThat(childPolicy.keySet()).containsExactly("pick_first"); 435 } 436 437 @Nullable getAsObject(Map<String, ?> json, String key)438 private static Map<String, ?> getAsObject(Map<String, ?> json, String key) { 439 Object mapObject = json.get(key); 440 if (mapObject == null) { 441 return null; 442 } 443 return checkObject(mapObject); 444 } 445 446 @SuppressWarnings("unchecked") checkObject(Object json)447 private static Map<String, ?> checkObject(Object json) { 448 checkArgument(json instanceof Map, "Invalid json object representation: %s", json); 449 for (Map.Entry<Object, Object> entry : ((Map<Object, Object>) json).entrySet()) { 450 checkArgument(entry.getKey() instanceof String, "Key is not string"); 451 } 452 return (Map<String, ?>) json; 453 } 454 getAsObjectList(Map<String, ?> json, String key)455 private static List<Map<String, ?>> getAsObjectList(Map<String, ?> json, String key) { 456 Object listObject = json.get(key); 457 if (listObject == null) { 458 return null; 459 } 460 return checkListOfObjects(listObject); 461 } 462 463 @SuppressWarnings("unchecked") checkListOfObjects(Object listObject)464 private static List<Map<String, ?>> checkListOfObjects(Object listObject) { 465 checkArgument(listObject instanceof List, "Passed object is not a list"); 466 List<Map<String, ?>> list = new ArrayList<>(); 467 for (Object object : ((List<Object>) listObject)) { 468 list.add(checkObject(object)); 469 } 470 return list; 471 } 472 473 @Test testWithCustomDirectPathServiceConfig()474 public void testWithCustomDirectPathServiceConfig() { 475 ImmutableMap<String, Object> pickFirstStrategy = 476 ImmutableMap.<String, Object>of("round_robin", ImmutableMap.of()); 477 ImmutableMap<String, Object> childPolicy = 478 ImmutableMap.<String, Object>of( 479 "childPolicy", ImmutableList.of(pickFirstStrategy), "foo", "bar"); 480 ImmutableMap<String, Object> grpcLbPolicy = 481 ImmutableMap.<String, Object>of("grpclb", childPolicy); 482 Map<String, Object> passedServiceConfig = new HashMap<>(); 483 passedServiceConfig.put("loadBalancingConfig", ImmutableList.of(grpcLbPolicy)); 484 485 InstantiatingGrpcChannelProvider provider = 486 InstantiatingGrpcChannelProvider.newBuilder() 487 .setDirectPathServiceConfig(passedServiceConfig) 488 .build(); 489 490 ImmutableMap<String, ?> defaultServiceConfig = provider.directPathServiceConfig; 491 assertThat(defaultServiceConfig).isEqualTo(passedServiceConfig); 492 } 493 494 @Override getMtlsObjectFromTransportChannel(MtlsProvider provider)495 protected Object getMtlsObjectFromTransportChannel(MtlsProvider provider) 496 throws IOException, GeneralSecurityException { 497 InstantiatingGrpcChannelProvider channelProvider = 498 InstantiatingGrpcChannelProvider.newBuilder() 499 .setEndpoint("localhost:8080") 500 .setMtlsProvider(provider) 501 .setHeaderProvider(Mockito.mock(HeaderProvider.class)) 502 .setExecutor(Mockito.mock(Executor.class)) 503 .build(); 504 return channelProvider.createMtlsChannelCredentials(); 505 } 506 } 507