1 /*
2  * Copyright 2016 Google LLC
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions are
6  * met:
7  *
8  *     * Redistributions of source code must retain the above copyright
9  * notice, this list of conditions and the following disclaimer.
10  *     * Redistributions in binary form must reproduce the above
11  * copyright notice, this list of conditions and the following disclaimer
12  * in the documentation and/or other materials provided with the
13  * distribution.
14  *     * Neither the name of Google LLC nor the names of its
15  * contributors may be used to endorse or promote products derived from
16  * this software without specific prior written permission.
17  *
18  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29  */
30 package com.google.api.gax.grpc;
31 
32 import static com.google.common.base.Preconditions.checkArgument;
33 import static com.google.common.truth.Truth.assertThat;
34 import static org.junit.Assert.assertEquals;
35 
36 import com.google.api.core.ApiFunction;
37 import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder;
38 import com.google.api.gax.rpc.HeaderProvider;
39 import com.google.api.gax.rpc.TransportChannelProvider;
40 import com.google.api.gax.rpc.mtls.AbstractMtlsTransportChannelTest;
41 import com.google.api.gax.rpc.mtls.MtlsProvider;
42 import com.google.auth.oauth2.CloudShellCredentials;
43 import com.google.auth.oauth2.ComputeEngineCredentials;
44 import com.google.common.collect.ImmutableList;
45 import com.google.common.collect.ImmutableMap;
46 import io.grpc.ManagedChannel;
47 import io.grpc.ManagedChannelBuilder;
48 import io.grpc.alts.ComputeEngineChannelBuilder;
49 import java.io.IOException;
50 import java.security.GeneralSecurityException;
51 import java.util.ArrayList;
52 import java.util.Collections;
53 import java.util.HashMap;
54 import java.util.List;
55 import java.util.Map;
56 import java.util.concurrent.Executor;
57 import java.util.concurrent.ScheduledExecutorService;
58 import java.util.concurrent.ScheduledThreadPoolExecutor;
59 import javax.annotation.Nullable;
60 import org.junit.Test;
61 import org.junit.runner.RunWith;
62 import org.junit.runners.JUnit4;
63 import org.mockito.ArgumentCaptor;
64 import org.mockito.Mockito;
65 import org.threeten.bp.Duration;
66 
67 @RunWith(JUnit4.class)
68 public class InstantiatingGrpcChannelProviderTest extends AbstractMtlsTransportChannelTest {
69 
70   @Test
testEndpoint()71   public void testEndpoint() {
72     String endpoint = "localhost:8080";
73     InstantiatingGrpcChannelProvider.Builder builder =
74         InstantiatingGrpcChannelProvider.newBuilder();
75     builder.setEndpoint(endpoint);
76     assertEquals(builder.getEndpoint(), endpoint);
77 
78     InstantiatingGrpcChannelProvider provider = builder.build();
79     assertEquals(provider.getEndpoint(), endpoint);
80   }
81 
82   @Test(expected = IllegalArgumentException.class)
testEndpointNoPort()83   public void testEndpointNoPort() {
84     InstantiatingGrpcChannelProvider.newBuilder().setEndpoint("localhost");
85   }
86 
87   @Test(expected = IllegalArgumentException.class)
testEndpointBadPort()88   public void testEndpointBadPort() {
89     InstantiatingGrpcChannelProvider.newBuilder().setEndpoint("localhost:abcd");
90   }
91 
92   @Test
testKeepAlive()93   public void testKeepAlive() {
94     Duration keepaliveTime = Duration.ofSeconds(1);
95     Duration keepaliveTimeout = Duration.ofSeconds(2);
96     boolean keepaliveWithoutCalls = true;
97 
98     InstantiatingGrpcChannelProvider provider =
99         InstantiatingGrpcChannelProvider.newBuilder()
100             .setKeepAliveTime(keepaliveTime)
101             .setKeepAliveTimeout(keepaliveTimeout)
102             .setKeepAliveWithoutCalls(keepaliveWithoutCalls)
103             .build();
104 
105     assertEquals(provider.getKeepAliveTime(), keepaliveTime);
106     assertEquals(provider.getKeepAliveTimeout(), keepaliveTimeout);
107     assertEquals(provider.getKeepAliveWithoutCalls(), keepaliveWithoutCalls);
108   }
109 
110   @Test
testMaxInboundMetadataSize()111   public void testMaxInboundMetadataSize() {
112     InstantiatingGrpcChannelProvider provider =
113         InstantiatingGrpcChannelProvider.newBuilder().setMaxInboundMetadataSize(4096).build();
114     assertThat(provider.getMaxInboundMetadataSize()).isEqualTo(4096);
115   }
116 
117   @Test
testCpuPoolSize()118   public void testCpuPoolSize() {
119     // happy path
120     Builder builder = InstantiatingGrpcChannelProvider.newBuilder().setProcessorCount(2);
121     builder.setChannelsPerCpu(2.5);
122     assertEquals(5, builder.getPoolSize());
123 
124     // User specified max
125     builder = builder.setProcessorCount(50);
126     builder.setChannelsPerCpu(100, 10);
127     assertEquals(10, builder.getPoolSize());
128 
129     // Sane default maximum
130     builder.setChannelsPerCpu(200);
131     assertEquals(100, builder.getPoolSize());
132   }
133 
134   @Test
testWithPoolSize()135   public void testWithPoolSize() throws IOException {
136     ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1);
137     executor.shutdown();
138 
139     TransportChannelProvider provider =
140         InstantiatingGrpcChannelProvider.newBuilder()
141             .build()
142             .withExecutor((Executor) executor)
143             .withHeaders(Collections.<String, String>emptyMap())
144             .withEndpoint("localhost:8080");
145     assertThat(provider.acceptsPoolSize()).isTrue();
146 
147     // Make sure we can create channels OK.
148     provider.getTransportChannel().shutdownNow();
149 
150     provider = provider.withPoolSize(2);
151     provider.getTransportChannel().shutdownNow();
152   }
153 
154   @Test
testToBuilder()155   public void testToBuilder() {
156     Duration keepaliveTime = Duration.ofSeconds(1);
157     Duration keepaliveTimeout = Duration.ofSeconds(2);
158     ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator =
159         builder -> {
160           throw new UnsupportedOperationException();
161         };
162     Map<String, ?> directPathServiceConfig = ImmutableMap.of("loadbalancingConfig", "grpclb");
163 
164     InstantiatingGrpcChannelProvider provider =
165         InstantiatingGrpcChannelProvider.newBuilder()
166             .setProcessorCount(2)
167             .setEndpoint("fake.endpoint:443")
168             .setMaxInboundMessageSize(12345678)
169             .setMaxInboundMetadataSize(4096)
170             .setKeepAliveTime(keepaliveTime)
171             .setKeepAliveTimeout(keepaliveTimeout)
172             .setKeepAliveWithoutCalls(true)
173             .setChannelConfigurator(channelConfigurator)
174             .setChannelsPerCpu(2.5)
175             .setDirectPathServiceConfig(directPathServiceConfig)
176             .build();
177 
178     InstantiatingGrpcChannelProvider.Builder builder = provider.toBuilder();
179 
180     assertThat(builder.getEndpoint()).isEqualTo("fake.endpoint:443");
181     assertThat(builder.getMaxInboundMessageSize()).isEqualTo(12345678);
182     assertThat(builder.getMaxInboundMetadataSize()).isEqualTo(4096);
183     assertThat(builder.getKeepAliveTime()).isEqualTo(keepaliveTime);
184     assertThat(builder.getKeepAliveTimeout()).isEqualTo(keepaliveTimeout);
185     assertThat(builder.getChannelConfigurator()).isEqualTo(channelConfigurator);
186     assertThat(builder.getPoolSize()).isEqualTo(5);
187     assertThat(builder.build().directPathServiceConfig).isEqualTo(directPathServiceConfig);
188   }
189 
190   @Test
testWithInterceptors()191   public void testWithInterceptors() throws Exception {
192     testWithInterceptors(1);
193   }
194 
195   @Test
testWithInterceptorsAndMultipleChannels()196   public void testWithInterceptorsAndMultipleChannels() throws Exception {
197     testWithInterceptors(5);
198   }
199 
testWithInterceptors(int numChannels)200   private void testWithInterceptors(int numChannels) throws Exception {
201     final GrpcInterceptorProvider interceptorProvider = Mockito.mock(GrpcInterceptorProvider.class);
202 
203     InstantiatingGrpcChannelProvider channelProvider =
204         InstantiatingGrpcChannelProvider.newBuilder()
205             .setEndpoint("localhost:8080")
206             .setPoolSize(numChannels)
207             .setHeaderProvider(Mockito.mock(HeaderProvider.class))
208             .setExecutor(Mockito.mock(Executor.class))
209             .setInterceptorProvider(interceptorProvider)
210             .build();
211 
212     Mockito.verify(interceptorProvider, Mockito.never()).getInterceptors();
213     channelProvider.getTransportChannel().shutdownNow();
214     Mockito.verify(interceptorProvider, Mockito.times(numChannels)).getInterceptors();
215   }
216 
217   @Test
testChannelConfigurator()218   public void testChannelConfigurator() throws IOException {
219     final int numChannels = 5;
220 
221     // Create a mock configurator that will insert mock channels
222     @SuppressWarnings("unchecked")
223     ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator =
224         Mockito.mock(ApiFunction.class);
225 
226     ArgumentCaptor<ManagedChannelBuilder<?>> channelBuilderCaptor =
227         ArgumentCaptor.forClass(ManagedChannelBuilder.class);
228 
229     ManagedChannelBuilder<?> swappedBuilder = Mockito.mock(ManagedChannelBuilder.class);
230     ManagedChannel fakeChannel = Mockito.mock(ManagedChannel.class);
231     Mockito.when(swappedBuilder.build()).thenReturn(fakeChannel);
232 
233     Mockito.when(channelConfigurator.apply(channelBuilderCaptor.capture()))
234         .thenReturn(swappedBuilder);
235 
236     // Invoke the provider
237     InstantiatingGrpcChannelProvider.newBuilder()
238         .setEndpoint("localhost:8080")
239         .setHeaderProvider(Mockito.mock(HeaderProvider.class))
240         .setExecutor(Mockito.mock(Executor.class))
241         .setChannelConfigurator(channelConfigurator)
242         .setPoolSize(numChannels)
243         .build()
244         .getTransportChannel();
245 
246     // Make sure that the provider passed in a configured channel
247     assertThat(channelBuilderCaptor.getValue()).isNotNull();
248     // And that it was replaced with the mock
249     Mockito.verify(swappedBuilder, Mockito.times(numChannels)).build();
250   }
251 
252   @Test
testWithGCECredentials()253   public void testWithGCECredentials() throws IOException {
254     ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1);
255     executor.shutdown();
256 
257     TransportChannelProvider provider =
258         InstantiatingGrpcChannelProvider.newBuilder()
259             .setAttemptDirectPath(true)
260             .build()
261             .withExecutor((Executor) executor)
262             .withHeaders(Collections.<String, String>emptyMap())
263             .withEndpoint("localhost:8080");
264 
265     assertThat(provider.needsCredentials()).isTrue();
266     if (InstantiatingGrpcChannelProvider.isOnComputeEngine()) {
267       provider = provider.withCredentials(ComputeEngineCredentials.create());
268     } else {
269       provider = provider.withCredentials(CloudShellCredentials.create(3000));
270     }
271     assertThat(provider.needsCredentials()).isFalse();
272 
273     provider.getTransportChannel().shutdownNow();
274   }
275 
276   @Test
testDirectPathXdsDisableByDefault()277   public void testDirectPathXdsDisableByDefault() throws IOException {
278     InstantiatingGrpcChannelProvider provider =
279         InstantiatingGrpcChannelProvider.newBuilder().setAttemptDirectPath(true).build();
280 
281     assertThat(provider.isDirectPathXdsEnabled()).isFalse();
282   }
283 
284   @Test
testDirectPathXdsEnabled()285   public void testDirectPathXdsEnabled() throws IOException {
286     InstantiatingGrpcChannelProvider provider =
287         InstantiatingGrpcChannelProvider.newBuilder()
288             .setAttemptDirectPath(true)
289             .setAttemptDirectPathXds()
290             .build();
291 
292     assertThat(provider.isDirectPathXdsEnabled()).isTrue();
293   }
294 
295   @Test
testWithNonGCECredentials()296   public void testWithNonGCECredentials() throws IOException {
297     ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1);
298     executor.shutdown();
299 
300     ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator =
301         channelBuilder -> {
302           // Clients with non-GCE credentials will not attempt DirectPath.
303           assertThat(channelBuilder instanceof ComputeEngineChannelBuilder).isFalse();
304           return channelBuilder;
305         };
306 
307     TransportChannelProvider provider =
308         InstantiatingGrpcChannelProvider.newBuilder()
309             .setAttemptDirectPath(true)
310             .setChannelConfigurator(channelConfigurator)
311             .build()
312             .withExecutor((Executor) executor)
313             .withHeaders(Collections.<String, String>emptyMap())
314             .withEndpoint("localhost:8080");
315 
316     assertThat(provider.needsCredentials()).isTrue();
317     provider = provider.withCredentials(CloudShellCredentials.create(3000));
318     assertThat(provider.needsCredentials()).isFalse();
319 
320     provider.getTransportChannel().shutdownNow();
321   }
322 
323   @Test
testWithDirectPathDisabled()324   public void testWithDirectPathDisabled() throws IOException {
325     ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1);
326     executor.shutdown();
327 
328     ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator =
329         channelBuilder -> {
330           // Clients without setting attemptDirectPath flag will not attempt DirectPath
331           assertThat(channelBuilder instanceof ComputeEngineChannelBuilder).isFalse();
332           return channelBuilder;
333         };
334 
335     TransportChannelProvider provider =
336         InstantiatingGrpcChannelProvider.newBuilder()
337             .setAttemptDirectPath(false)
338             .setChannelConfigurator(channelConfigurator)
339             .build()
340             .withExecutor((Executor) executor)
341             .withHeaders(Collections.<String, String>emptyMap())
342             .withEndpoint("localhost:8080");
343 
344     assertThat(provider.needsCredentials()).isTrue();
345     provider = provider.withCredentials(ComputeEngineCredentials.create());
346     assertThat(provider.needsCredentials()).isFalse();
347 
348     provider.getTransportChannel().shutdownNow();
349   }
350 
351   @Test
testWithNoDirectPathFlagSet()352   public void testWithNoDirectPathFlagSet() throws IOException {
353     ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1);
354     executor.shutdown();
355 
356     ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator =
357         channelBuilder -> {
358           // Clients without setting attemptDirectPath flag will not attempt DirectPath
359           assertThat(channelBuilder instanceof ComputeEngineChannelBuilder).isFalse();
360           return channelBuilder;
361         };
362 
363     TransportChannelProvider provider =
364         InstantiatingGrpcChannelProvider.newBuilder()
365             .setChannelConfigurator(channelConfigurator)
366             .build()
367             .withExecutor((Executor) executor)
368             .withHeaders(Collections.<String, String>emptyMap())
369             .withEndpoint("localhost:8080");
370 
371     assertThat(provider.needsCredentials()).isTrue();
372     provider = provider.withCredentials(ComputeEngineCredentials.create());
373     assertThat(provider.needsCredentials()).isFalse();
374 
375     provider.getTransportChannel().shutdownNow();
376   }
377 
378   @Test
testWithIPv6Address()379   public void testWithIPv6Address() throws IOException {
380     ScheduledExecutorService executor = new ScheduledThreadPoolExecutor(1);
381     executor.shutdown();
382 
383     TransportChannelProvider provider =
384         InstantiatingGrpcChannelProvider.newBuilder()
385             .build()
386             .withExecutor((Executor) executor)
387             .withHeaders(Collections.<String, String>emptyMap())
388             .withEndpoint("[::1]:8080");
389     assertThat(provider.needsEndpoint()).isFalse();
390 
391     // Make sure we can create channels OK.
392     provider.getTransportChannel().shutdownNow();
393   }
394 
395   // Test that if ChannelPrimer is provided, it is called during creation
396   @Test
testWithPrimeChannel()397   public void testWithPrimeChannel() throws IOException {
398     // create channelProvider with different pool sizes to verify ChannelPrimer is called the
399     // correct number of times
400     for (int poolSize = 1; poolSize < 5; poolSize++) {
401       final ChannelPrimer mockChannelPrimer = Mockito.mock(ChannelPrimer.class);
402 
403       InstantiatingGrpcChannelProvider provider =
404           InstantiatingGrpcChannelProvider.newBuilder()
405               .setEndpoint("localhost:8080")
406               .setPoolSize(poolSize)
407               .setHeaderProvider(Mockito.mock(HeaderProvider.class))
408               .setExecutor(Mockito.mock(Executor.class))
409               .setChannelPrimer(mockChannelPrimer)
410               .build();
411 
412       provider.getTransportChannel().shutdownNow();
413 
414       // every channel in the pool should call primeChannel during creation.
415       Mockito.verify(mockChannelPrimer, Mockito.times(poolSize))
416           .primeChannel(Mockito.any(ManagedChannel.class));
417     }
418   }
419 
420   @Test
testWithDefaultDirectPathServiceConfig()421   public void testWithDefaultDirectPathServiceConfig() {
422     InstantiatingGrpcChannelProvider provider =
423         InstantiatingGrpcChannelProvider.newBuilder().build();
424 
425     ImmutableMap<String, ?> defaultServiceConfig = provider.directPathServiceConfig;
426 
427     List<Map<String, ?>> lbConfigs = getAsObjectList(defaultServiceConfig, "loadBalancingConfig");
428     assertThat(lbConfigs).hasSize(1);
429     Map<String, ?> lbConfig = lbConfigs.get(0);
430     Map<String, ?> grpclb = getAsObject(lbConfig, "grpclb");
431     List<Map<String, ?>> childPolicies = getAsObjectList(grpclb, "childPolicy");
432     assertThat(childPolicies).hasSize(1);
433     Map<String, ?> childPolicy = childPolicies.get(0);
434     assertThat(childPolicy.keySet()).containsExactly("pick_first");
435   }
436 
437   @Nullable
getAsObject(Map<String, ?> json, String key)438   private static Map<String, ?> getAsObject(Map<String, ?> json, String key) {
439     Object mapObject = json.get(key);
440     if (mapObject == null) {
441       return null;
442     }
443     return checkObject(mapObject);
444   }
445 
446   @SuppressWarnings("unchecked")
checkObject(Object json)447   private static Map<String, ?> checkObject(Object json) {
448     checkArgument(json instanceof Map, "Invalid json object representation: %s", json);
449     for (Map.Entry<Object, Object> entry : ((Map<Object, Object>) json).entrySet()) {
450       checkArgument(entry.getKey() instanceof String, "Key is not string");
451     }
452     return (Map<String, ?>) json;
453   }
454 
getAsObjectList(Map<String, ?> json, String key)455   private static List<Map<String, ?>> getAsObjectList(Map<String, ?> json, String key) {
456     Object listObject = json.get(key);
457     if (listObject == null) {
458       return null;
459     }
460     return checkListOfObjects(listObject);
461   }
462 
463   @SuppressWarnings("unchecked")
checkListOfObjects(Object listObject)464   private static List<Map<String, ?>> checkListOfObjects(Object listObject) {
465     checkArgument(listObject instanceof List, "Passed object is not a list");
466     List<Map<String, ?>> list = new ArrayList<>();
467     for (Object object : ((List<Object>) listObject)) {
468       list.add(checkObject(object));
469     }
470     return list;
471   }
472 
473   @Test
testWithCustomDirectPathServiceConfig()474   public void testWithCustomDirectPathServiceConfig() {
475     ImmutableMap<String, Object> pickFirstStrategy =
476         ImmutableMap.<String, Object>of("round_robin", ImmutableMap.of());
477     ImmutableMap<String, Object> childPolicy =
478         ImmutableMap.<String, Object>of(
479             "childPolicy", ImmutableList.of(pickFirstStrategy), "foo", "bar");
480     ImmutableMap<String, Object> grpcLbPolicy =
481         ImmutableMap.<String, Object>of("grpclb", childPolicy);
482     Map<String, Object> passedServiceConfig = new HashMap<>();
483     passedServiceConfig.put("loadBalancingConfig", ImmutableList.of(grpcLbPolicy));
484 
485     InstantiatingGrpcChannelProvider provider =
486         InstantiatingGrpcChannelProvider.newBuilder()
487             .setDirectPathServiceConfig(passedServiceConfig)
488             .build();
489 
490     ImmutableMap<String, ?> defaultServiceConfig = provider.directPathServiceConfig;
491     assertThat(defaultServiceConfig).isEqualTo(passedServiceConfig);
492   }
493 
494   @Override
getMtlsObjectFromTransportChannel(MtlsProvider provider)495   protected Object getMtlsObjectFromTransportChannel(MtlsProvider provider)
496       throws IOException, GeneralSecurityException {
497     InstantiatingGrpcChannelProvider channelProvider =
498         InstantiatingGrpcChannelProvider.newBuilder()
499             .setEndpoint("localhost:8080")
500             .setMtlsProvider(provider)
501             .setHeaderProvider(Mockito.mock(HeaderProvider.class))
502             .setExecutor(Mockito.mock(Executor.class))
503             .build();
504     return channelProvider.createMtlsChannelCredentials();
505   }
506 }
507