1 /*
2  * Copyright 2022 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.gcp.observability;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static org.junit.Assert.fail;
21 import static org.mockito.AdditionalAnswers.delegatesTo;
22 import static org.mockito.Mockito.doReturn;
23 import static org.mockito.Mockito.mock;
24 import static org.mockito.Mockito.reset;
25 import static org.mockito.Mockito.verify;
26 import static org.mockito.Mockito.verifyNoInteractions;
27 import static org.mockito.Mockito.when;
28 
29 import io.grpc.CallOptions;
30 import io.grpc.Channel;
31 import io.grpc.ClientCall;
32 import io.grpc.ClientInterceptor;
33 import io.grpc.InternalGlobalInterceptors;
34 import io.grpc.Metadata;
35 import io.grpc.MethodDescriptor;
36 import io.grpc.ServerCall;
37 import io.grpc.ServerCallHandler;
38 import io.grpc.ServerInterceptor;
39 import io.grpc.StaticTestingClassLoader;
40 import io.grpc.gcp.observability.interceptors.ConditionalClientInterceptor;
41 import io.grpc.gcp.observability.interceptors.InternalLoggingChannelInterceptor;
42 import io.grpc.gcp.observability.interceptors.InternalLoggingServerInterceptor;
43 import io.grpc.gcp.observability.logging.Sink;
44 import io.opencensus.trace.samplers.Samplers;
45 import java.io.IOException;
46 import java.util.List;
47 import java.util.regex.Pattern;
48 import org.junit.Test;
49 import org.junit.runner.RunWith;
50 import org.junit.runners.JUnit4;
51 
52 @RunWith(JUnit4.class)
53 public class GcpObservabilityTest {
54 
55   private final StaticTestingClassLoader classLoader =
56       new StaticTestingClassLoader(
57           getClass().getClassLoader(),
58           Pattern.compile(
59               "io\\.grpc\\.InternalGlobalInterceptors|io\\.grpc\\.GlobalInterceptors|"
60                   + "io\\.grpc\\.gcp\\.observability\\.[^.]+|"
61                   + "io\\.grpc\\.gcp\\.observability\\.interceptors\\.[^.]+|"
62                   + "io\\.grpc\\.gcp\\.observability\\.GcpObservabilityTest\\$.*"));
63 
64   @Test
initFinish()65   public void initFinish() throws Exception {
66     Class<?> runnable =
67         classLoader.loadClass(StaticTestingClassInitFinish.class.getName());
68     ((Runnable) runnable.getDeclaredConstructor().newInstance()).run();
69   }
70 
71   @Test
enableObservability()72   public void enableObservability() throws Exception {
73     Class<?> runnable =
74         classLoader.loadClass(StaticTestingClassEnableObservability.class.getName());
75     ((Runnable) runnable.getDeclaredConstructor().newInstance()).run();
76   }
77 
78   @Test
disableObservability()79   public void disableObservability() throws Exception {
80     Class<?> runnable =
81         classLoader.loadClass(StaticTestingClassDisableObservability.class.getName());
82     ((Runnable) runnable.getDeclaredConstructor().newInstance()).run();
83   }
84 
85   @Test
86   @SuppressWarnings("unchecked")
conditionalInterceptor()87   public void conditionalInterceptor() {
88     ClientInterceptor delegate = mock(ClientInterceptor.class);
89     Channel channel = mock(Channel.class);
90     ClientCall<?, ?> returnedCall = mock(ClientCall.class);
91 
92     ConditionalClientInterceptor conditionalClientInterceptor
93         = GcpObservability.getConditionalInterceptor(
94         delegate);
95     MethodDescriptor<?, ?> method = MethodDescriptor.newBuilder()
96         .setType(MethodDescriptor.MethodType.UNARY)
97         .setFullMethodName("google.logging.v2.LoggingServiceV2/method")
98         .setRequestMarshaller(mock(MethodDescriptor.Marshaller.class))
99         .setResponseMarshaller(mock(MethodDescriptor.Marshaller.class))
100         .build();
101     doReturn(returnedCall).when(channel).newCall(method, CallOptions.DEFAULT);
102     ClientCall<?, ?> clientCall = conditionalClientInterceptor.interceptCall(method,
103         CallOptions.DEFAULT, channel);
104     verifyNoInteractions(delegate);
105     assertThat(clientCall).isSameInstanceAs(returnedCall);
106 
107     method = MethodDescriptor.newBuilder().setType(MethodDescriptor.MethodType.UNARY)
108         .setFullMethodName("google.monitoring.v3.MetricService/method2")
109         .setRequestMarshaller(mock(MethodDescriptor.Marshaller.class))
110         .setResponseMarshaller(mock(MethodDescriptor.Marshaller.class))
111         .build();
112     doReturn(returnedCall).when(channel).newCall(method, CallOptions.DEFAULT);
113     clientCall = conditionalClientInterceptor.interceptCall(method, CallOptions.DEFAULT, channel);
114     verifyNoInteractions(delegate);
115     assertThat(clientCall).isSameInstanceAs(returnedCall);
116 
117     method = MethodDescriptor.newBuilder().setType(MethodDescriptor.MethodType.UNARY)
118         .setFullMethodName("google.devtools.cloudtrace.v2.TraceService/method3")
119         .setRequestMarshaller(mock(MethodDescriptor.Marshaller.class))
120         .setResponseMarshaller(mock(MethodDescriptor.Marshaller.class))
121         .build();
122     doReturn(returnedCall).when(channel).newCall(method, CallOptions.DEFAULT);
123     clientCall = conditionalClientInterceptor.interceptCall(method, CallOptions.DEFAULT, channel);
124     verifyNoInteractions(delegate);
125     assertThat(clientCall).isSameInstanceAs(returnedCall);
126 
127     reset(channel);
128     ClientCall<?, ?> interceptedCall = mock(ClientCall.class);
129     method = MethodDescriptor.newBuilder().setType(MethodDescriptor.MethodType.UNARY)
130         .setFullMethodName("some.other.random.service/method4")
131         .setRequestMarshaller(mock(MethodDescriptor.Marshaller.class))
132         .setResponseMarshaller(mock(MethodDescriptor.Marshaller.class))
133         .build();
134     doReturn(interceptedCall).when(delegate).interceptCall(method, CallOptions.DEFAULT, channel);
135     clientCall = conditionalClientInterceptor.interceptCall(method, CallOptions.DEFAULT, channel);
136     verifyNoInteractions(channel);
137     assertThat(clientCall).isSameInstanceAs(interceptedCall);
138   }
139 
140   // UsedReflectively
141   public static final class StaticTestingClassInitFinish implements Runnable {
142 
143     @Override
run()144     public void run() {
145       Sink sink = mock(Sink.class);
146       ObservabilityConfig config = mock(ObservabilityConfig.class);
147       InternalLoggingChannelInterceptor.Factory channelInterceptorFactory =
148           mock(InternalLoggingChannelInterceptor.Factory.class);
149       InternalLoggingServerInterceptor.Factory serverInterceptorFactory =
150           mock(InternalLoggingServerInterceptor.Factory.class);
151       GcpObservability observability1;
152       try {
153         GcpObservability observability =
154             GcpObservability.grpcInit(
155               sink, config, channelInterceptorFactory, serverInterceptorFactory);
156         observability1 =
157             GcpObservability.grpcInit(
158                 sink, config, channelInterceptorFactory, serverInterceptorFactory);
159         assertThat(observability1).isSameInstanceAs(observability);
160         observability.close();
161         verify(sink).close();
162         try {
163           observability1.close();
164           fail("should have failed for calling close() second time");
165         } catch (IllegalStateException e) {
166           assertThat(e).hasMessageThat().contains("GcpObservability already closed!");
167         }
168       } catch (IOException e) {
169         fail("Encountered exception: " + e);
170       }
171     }
172   }
173 
174   public static final class StaticTestingClassEnableObservability implements Runnable {
175 
176     @Override
run()177     public void run() {
178       Sink sink = mock(Sink.class);
179       ObservabilityConfig config = mock(ObservabilityConfig.class);
180       when(config.isEnableCloudLogging()).thenReturn(true);
181       when(config.isEnableCloudMonitoring()).thenReturn(true);
182       when(config.isEnableCloudTracing()).thenReturn(true);
183       when(config.getSampler()).thenReturn(Samplers.neverSample());
184 
185       ClientInterceptor clientInterceptor =
186           mock(ClientInterceptor.class, delegatesTo(new NoopClientInterceptor()));
187       InternalLoggingChannelInterceptor.Factory channelInterceptorFactory =
188           mock(InternalLoggingChannelInterceptor.Factory.class);
189       when(channelInterceptorFactory.create()).thenReturn(clientInterceptor);
190 
191       ServerInterceptor serverInterceptor =
192           mock(ServerInterceptor.class, delegatesTo(new NoopServerInterceptor()));
193       InternalLoggingServerInterceptor.Factory serverInterceptorFactory =
194           mock(InternalLoggingServerInterceptor.Factory.class);
195       when(serverInterceptorFactory.create()).thenReturn(serverInterceptor);
196 
197       try (GcpObservability unused =
198           GcpObservability.grpcInit(
199               sink, config, channelInterceptorFactory, serverInterceptorFactory)) {
200         List<ClientInterceptor> list = InternalGlobalInterceptors.getClientInterceptors();
201         assertThat(list).hasSize(3);
202         assertThat(list.get(1)).isInstanceOf(ConditionalClientInterceptor.class);
203         assertThat(list.get(2)).isInstanceOf(ConditionalClientInterceptor.class);
204         assertThat(InternalGlobalInterceptors.getServerInterceptors()).hasSize(1);
205         assertThat(InternalGlobalInterceptors.getServerStreamTracerFactories()).hasSize(2);
206       } catch (Exception e) {
207         fail("Encountered exception: " + e);
208       }
209     }
210   }
211 
212   public static final class StaticTestingClassDisableObservability implements Runnable {
213 
214     @Override
run()215     public void run() {
216       Sink sink = mock(Sink.class);
217       ObservabilityConfig config = mock(ObservabilityConfig.class);
218       when(config.isEnableCloudLogging()).thenReturn(false);
219       when(config.isEnableCloudMonitoring()).thenReturn(false);
220       when(config.isEnableCloudTracing()).thenReturn(false);
221       when(config.getSampler()).thenReturn(Samplers.neverSample());
222 
223       InternalLoggingChannelInterceptor.Factory channelInterceptorFactory =
224           mock(InternalLoggingChannelInterceptor.Factory.class);
225       InternalLoggingServerInterceptor.Factory serverInterceptorFactory =
226           mock(InternalLoggingServerInterceptor.Factory.class);
227 
228       try (GcpObservability unused =
229           GcpObservability.grpcInit(
230               sink, config, channelInterceptorFactory, serverInterceptorFactory)) {
231         assertThat(InternalGlobalInterceptors.getClientInterceptors()).isEmpty();
232         assertThat(InternalGlobalInterceptors.getServerInterceptors()).isEmpty();
233         assertThat(InternalGlobalInterceptors.getServerStreamTracerFactories()).isEmpty();
234       } catch (Exception e) {
235         fail("Encountered exception: " + e);
236       }
237       verify(sink).close();
238     }
239   }
240 
241   private static class NoopClientInterceptor implements ClientInterceptor {
242     @Override
interceptCall( MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next)243     public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
244         MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
245       return next.newCall(method, callOptions);
246     }
247   }
248 
249   private static class NoopServerInterceptor implements ServerInterceptor {
250     @Override
interceptCall( ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next)251     public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
252         ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
253       return next.startCall(call, headers);
254     }
255   }
256 }
257