xref: /aosp_15_r20/external/grpc-grpc-java/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java (revision e07d83d3ffcef9ecfc9f7f475418ec639ff0e5fe)
1 /*
2  * Copyright 2015 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.netty;
18 
19 import static com.google.common.base.Charsets.UTF_8;
20 import static com.google.common.base.Preconditions.checkNotNull;
21 import static com.google.common.base.Preconditions.checkState;
22 import static com.google.common.truth.Truth.assertThat;
23 import static org.junit.Assert.assertEquals;
24 import static org.junit.Assert.assertFalse;
25 import static org.junit.Assert.assertNotNull;
26 import static org.junit.Assert.assertNull;
27 import static org.junit.Assert.assertTrue;
28 import static org.mockito.ArgumentMatchers.any;
29 import static org.mockito.Mockito.mock;
30 import static org.mockito.Mockito.timeout;
31 import static org.mockito.Mockito.times;
32 import static org.mockito.Mockito.verify;
33 
34 import io.grpc.Attributes;
35 import io.grpc.CallCredentials;
36 import io.grpc.ChannelCredentials;
37 import io.grpc.ChannelLogger;
38 import io.grpc.ChoiceChannelCredentials;
39 import io.grpc.ChoiceServerCredentials;
40 import io.grpc.CompositeChannelCredentials;
41 import io.grpc.Grpc;
42 import io.grpc.InsecureChannelCredentials;
43 import io.grpc.InsecureServerCredentials;
44 import io.grpc.InternalChannelz;
45 import io.grpc.InternalChannelz.Security;
46 import io.grpc.Metadata;
47 import io.grpc.SecurityLevel;
48 import io.grpc.ServerCredentials;
49 import io.grpc.ServerStreamTracer;
50 import io.grpc.Status;
51 import io.grpc.StatusException;
52 import io.grpc.StatusRuntimeException;
53 import io.grpc.TlsChannelCredentials;
54 import io.grpc.TlsServerCredentials;
55 import io.grpc.internal.ClientTransportFactory;
56 import io.grpc.internal.GrpcAttributes;
57 import io.grpc.internal.InternalServer;
58 import io.grpc.internal.ManagedClientTransport;
59 import io.grpc.internal.ServerListener;
60 import io.grpc.internal.ServerStream;
61 import io.grpc.internal.ServerTransport;
62 import io.grpc.internal.ServerTransportListener;
63 import io.grpc.internal.TestUtils.NoopChannelLogger;
64 import io.grpc.internal.testing.TestUtils;
65 import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler;
66 import io.grpc.netty.ProtocolNegotiators.ClientTlsProtocolNegotiator;
67 import io.grpc.netty.ProtocolNegotiators.HostPort;
68 import io.grpc.netty.ProtocolNegotiators.ServerTlsHandler;
69 import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler;
70 import io.grpc.testing.TlsTesting;
71 import io.netty.bootstrap.Bootstrap;
72 import io.netty.bootstrap.ServerBootstrap;
73 import io.netty.buffer.ByteBuf;
74 import io.netty.buffer.ByteBufUtil;
75 import io.netty.channel.Channel;
76 import io.netty.channel.ChannelDuplexHandler;
77 import io.netty.channel.ChannelFuture;
78 import io.netty.channel.ChannelHandler;
79 import io.netty.channel.ChannelHandlerAdapter;
80 import io.netty.channel.ChannelHandlerContext;
81 import io.netty.channel.ChannelInboundHandler;
82 import io.netty.channel.ChannelInboundHandlerAdapter;
83 import io.netty.channel.ChannelInitializer;
84 import io.netty.channel.ChannelOutboundHandlerAdapter;
85 import io.netty.channel.ChannelPipeline;
86 import io.netty.channel.ChannelPromise;
87 import io.netty.channel.DefaultEventLoop;
88 import io.netty.channel.DefaultEventLoopGroup;
89 import io.netty.channel.EventLoopGroup;
90 import io.netty.channel.embedded.EmbeddedChannel;
91 import io.netty.channel.local.LocalAddress;
92 import io.netty.channel.local.LocalChannel;
93 import io.netty.channel.local.LocalServerChannel;
94 import io.netty.handler.codec.http.HttpServerCodec;
95 import io.netty.handler.codec.http.HttpServerUpgradeHandler;
96 import io.netty.handler.codec.http.HttpServerUpgradeHandler.UpgradeCodec;
97 import io.netty.handler.codec.http.HttpServerUpgradeHandler.UpgradeCodecFactory;
98 import io.netty.handler.codec.http2.DefaultHttp2Connection;
99 import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder;
100 import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder;
101 import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
102 import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
103 import io.netty.handler.codec.http2.Http2ConnectionDecoder;
104 import io.netty.handler.codec.http2.Http2ConnectionEncoder;
105 import io.netty.handler.codec.http2.Http2ServerUpgradeCodec;
106 import io.netty.handler.codec.http2.Http2Settings;
107 import io.netty.handler.proxy.ProxyConnectException;
108 import io.netty.handler.ssl.ApplicationProtocolConfig;
109 import io.netty.handler.ssl.SslContext;
110 import io.netty.handler.ssl.SslContextBuilder;
111 import io.netty.handler.ssl.SslHandler;
112 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
113 import io.netty.handler.ssl.util.SelfSignedCertificate;
114 import java.io.File;
115 import java.io.InputStream;
116 import java.net.InetSocketAddress;
117 import java.net.SocketAddress;
118 import java.security.KeyStore;
119 import java.security.cert.Certificate;
120 import java.security.cert.X509Certificate;
121 import java.util.ArrayDeque;
122 import java.util.Arrays;
123 import java.util.Collections;
124 import java.util.List;
125 import java.util.Queue;
126 import java.util.concurrent.CountDownLatch;
127 import java.util.concurrent.TimeUnit;
128 import java.util.concurrent.atomic.AtomicReference;
129 import java.util.logging.Filter;
130 import java.util.logging.Level;
131 import java.util.logging.LogRecord;
132 import java.util.logging.Logger;
133 import javax.net.ssl.KeyManagerFactory;
134 import javax.net.ssl.SSLContext;
135 import javax.net.ssl.SSLEngine;
136 import javax.net.ssl.SSLException;
137 import javax.net.ssl.SSLHandshakeException;
138 import javax.net.ssl.SSLSession;
139 import javax.net.ssl.TrustManagerFactory;
140 import org.junit.After;
141 import org.junit.Before;
142 import org.junit.BeforeClass;
143 import org.junit.Rule;
144 import org.junit.Test;
145 import org.junit.rules.DisableOnDebug;
146 import org.junit.rules.ExpectedException;
147 import org.junit.rules.TestRule;
148 import org.junit.rules.Timeout;
149 import org.junit.runner.RunWith;
150 import org.junit.runners.JUnit4;
151 import org.mockito.ArgumentCaptor;
152 import org.mockito.ArgumentMatchers;
153 import org.mockito.Mockito;
154 
155 @RunWith(JUnit4.class)
156 public class ProtocolNegotiatorsTest {
157   private static final Runnable NOOP_RUNNABLE = new Runnable() {
158     @Override public void run() {}
159   };
160 
161   private static File server1Cert;
162   private static File server1Key;
163   private static File caCert;
164 
165   @BeforeClass
loadCerts()166   public static void loadCerts() throws Exception {
167     server1Cert = TestUtils.loadCert("server1.pem");
168     server1Key = TestUtils.loadCert("server1.key");
169     caCert = TestUtils.loadCert("ca.pem");
170   }
171 
172   private static final int TIMEOUT_SECONDS = 60;
173   @Rule public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(TIMEOUT_SECONDS));
174   @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467
175   @Rule public final ExpectedException thrown = ExpectedException.none();
176 
177   private final EventLoopGroup group = new DefaultEventLoop();
178   private Channel chan;
179   private Channel server;
180 
181   private final GrpcHttp2ConnectionHandler grpcHandler =
182       FakeGrpcHttp2ConnectionHandler.newHandler();
183 
184   private EmbeddedChannel channel = new EmbeddedChannel();
185   private ChannelPipeline pipeline = channel.pipeline();
186   private SslContext sslContext;
187   private SSLEngine engine;
188   private ChannelHandlerContext channelHandlerCtx;
189   private static ChannelLogger noopLogger = new NoopChannelLogger();
190 
191   @Before
setUp()192   public void setUp() throws Exception {
193     InputStream serverCert = TlsTesting.loadCert("server1.pem");
194     InputStream key = TlsTesting.loadCert("server1.key");
195     sslContext = GrpcSslContexts.forServer(serverCert, key).build();
196     engine = SSLContext.getDefault().createSSLEngine();
197     engine.setUseClientMode(true);
198   }
199 
200   @After
tearDown()201   public void tearDown() {
202     if (server != null) {
203       server.close();
204     }
205     if (chan != null) {
206       chan.close();
207     }
208     group.shutdownGracefully();
209   }
210 
211   @Test
fromClient_unknown()212   public void fromClient_unknown() {
213     ProtocolNegotiators.FromChannelCredentialsResult result =
214         ProtocolNegotiators.from(new ChannelCredentials() {
215           @Override
216           public ChannelCredentials withoutBearerTokens() {
217             throw new UnsupportedOperationException();
218           }
219         });
220     assertThat(result.error).isNotNull();
221     assertThat(result.callCredentials).isNull();
222     assertThat(result.negotiator).isNull();
223   }
224 
225   @Test
fromClient_tls()226   public void fromClient_tls() {
227     ProtocolNegotiators.FromChannelCredentialsResult result =
228         ProtocolNegotiators.from(TlsChannelCredentials.create());
229     assertThat(result.error).isNull();
230     assertThat(result.callCredentials).isNull();
231     assertThat(result.negotiator)
232         .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class);
233   }
234 
235   @Test
fromClient_unsupportedTls()236   public void fromClient_unsupportedTls() {
237     ProtocolNegotiators.FromChannelCredentialsResult result =
238         ProtocolNegotiators.from(TlsChannelCredentials.newBuilder().requireFakeFeature().build());
239     assertThat(result.error).contains("FAKE");
240     assertThat(result.callCredentials).isNull();
241     assertThat(result.negotiator).isNull();
242   }
243 
244   @Test
fromClient_insecure()245   public void fromClient_insecure() {
246     ProtocolNegotiators.FromChannelCredentialsResult result =
247         ProtocolNegotiators.from(InsecureChannelCredentials.create());
248     assertThat(result.error).isNull();
249     assertThat(result.callCredentials).isNull();
250     assertThat(result.negotiator)
251         .isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorClientFactory.class);
252   }
253 
254   @Test
fromClient_composite()255   public void fromClient_composite() {
256     CallCredentials callCredentials = mock(CallCredentials.class);
257     ProtocolNegotiators.FromChannelCredentialsResult result =
258         ProtocolNegotiators.from(CompositeChannelCredentials.create(
259           TlsChannelCredentials.create(), callCredentials));
260     assertThat(result.error).isNull();
261     assertThat(result.callCredentials).isSameInstanceAs(callCredentials);
262     assertThat(result.negotiator)
263         .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class);
264 
265     result = ProtocolNegotiators.from(CompositeChannelCredentials.create(
266           InsecureChannelCredentials.create(), callCredentials));
267     assertThat(result.error).isNull();
268     assertThat(result.callCredentials).isSameInstanceAs(callCredentials);
269     assertThat(result.negotiator)
270         .isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorClientFactory.class);
271   }
272 
273   @Test
fromClient_netty()274   public void fromClient_netty() {
275     ProtocolNegotiator.ClientFactory factory = mock(ProtocolNegotiator.ClientFactory.class);
276     ProtocolNegotiators.FromChannelCredentialsResult result =
277         ProtocolNegotiators.from(NettyChannelCredentials.create(factory));
278     assertThat(result.error).isNull();
279     assertThat(result.callCredentials).isNull();
280     assertThat(result.negotiator).isSameInstanceAs(factory);
281   }
282 
283   @Test
fromClient_choice()284   public void fromClient_choice() {
285     ProtocolNegotiators.FromChannelCredentialsResult result =
286         ProtocolNegotiators.from(ChoiceChannelCredentials.create(
287           new ChannelCredentials() {
288             @Override
289             public ChannelCredentials withoutBearerTokens() {
290               throw new UnsupportedOperationException();
291             }
292           },
293           TlsChannelCredentials.create(),
294           InsecureChannelCredentials.create()));
295     assertThat(result.error).isNull();
296     assertThat(result.callCredentials).isNull();
297     assertThat(result.negotiator)
298         .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class);
299 
300     result = ProtocolNegotiators.from(ChoiceChannelCredentials.create(
301           InsecureChannelCredentials.create(),
302           new ChannelCredentials() {
303             @Override
304             public ChannelCredentials withoutBearerTokens() {
305               throw new UnsupportedOperationException();
306             }
307           },
308           TlsChannelCredentials.create()));
309     assertThat(result.error).isNull();
310     assertThat(result.callCredentials).isNull();
311     assertThat(result.negotiator)
312         .isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorClientFactory.class);
313   }
314 
315   @Test
fromClient_choice_unknown()316   public void fromClient_choice_unknown() {
317     ProtocolNegotiators.FromChannelCredentialsResult result =
318         ProtocolNegotiators.from(ChoiceChannelCredentials.create(
319           new ChannelCredentials() {
320             @Override
321             public ChannelCredentials withoutBearerTokens() {
322               throw new UnsupportedOperationException();
323             }
324           }));
325     assertThat(result.error).isNotNull();
326     assertThat(result.callCredentials).isNull();
327     assertThat(result.negotiator).isNull();
328   }
329 
expectSuccessfulHandshake( ChannelCredentials channelCreds, ServerCredentials serverCreds)330   private InternalChannelz.Tls expectSuccessfulHandshake(
331       ChannelCredentials channelCreds, ServerCredentials serverCreds) throws Exception {
332     return (InternalChannelz.Tls) expectHandshake(channelCreds, serverCreds, true);
333   }
334 
expectFailedHandshake( ChannelCredentials channelCreds, ServerCredentials serverCreds)335   private Status expectFailedHandshake(
336       ChannelCredentials channelCreds, ServerCredentials serverCreds) throws Exception {
337     return (Status) expectHandshake(channelCreds, serverCreds, false);
338   }
339 
expectHandshake( ChannelCredentials channelCreds, ServerCredentials serverCreds, boolean expectSuccess)340   private Object expectHandshake(
341       ChannelCredentials channelCreds, ServerCredentials serverCreds, boolean expectSuccess)
342       throws Exception {
343     MockServerListener serverListener = new MockServerListener();
344     ClientTransportFactory clientFactory = NettyChannelBuilder
345         // Although specified here, address is ignored because we never call build.
346         .forAddress("localhost", 0, channelCreds)
347         .buildTransportFactory();
348     InternalServer server = NettyServerBuilder
349         .forPort(0, serverCreds)
350         .buildTransportServers(Collections.<ServerStreamTracer.Factory>emptyList());
351     server.start(serverListener);
352 
353     ManagedClientTransport.Listener clientTransportListener =
354         mock(ManagedClientTransport.Listener.class);
355     ManagedClientTransport client = clientFactory.newClientTransport(
356         server.getListenSocketAddress(),
357         new ClientTransportFactory.ClientTransportOptions()
358           .setAuthority(TestUtils.TEST_SERVER_HOST),
359         mock(ChannelLogger.class));
360     callMeMaybe(client.start(clientTransportListener));
361     Object result;
362     if (expectSuccess) {
363       verify(clientTransportListener, timeout(TIMEOUT_SECONDS * 1000)).transportReady();
364       InternalChannelz.SocketStats stats = serverListener.transports.poll().getStats().get();
365       assertThat(stats.security).isNotNull();
366       assertThat(stats.security.tls).isNotNull();
367       result = stats.security.tls;
368     } else {
369       ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class);
370       verify(clientTransportListener, timeout(TIMEOUT_SECONDS * 1000))
371           .transportShutdown(captor.capture());
372       result = captor.getValue();
373     }
374 
375     client.shutdownNow(Status.UNAVAILABLE.withDescription("trash it"));
376     server.shutdown();
377     assertTrue(
378         serverListener.waitForShutdown(TIMEOUT_SECONDS * 1000, TimeUnit.MILLISECONDS));
379     verify(clientTransportListener, timeout(TIMEOUT_SECONDS * 1000)).transportTerminated();
380     clientFactory.close();
381     return result;
382   }
383 
384   @Test
from_tls_clientAuthNone_noClientCert()385   public void from_tls_clientAuthNone_noClientCert() throws Exception {
386     // Use convenience API to better match most user's usage
387     ServerCredentials serverCreds = TlsServerCredentials.create(server1Cert, server1Key);
388     ChannelCredentials channelCreds = TlsChannelCredentials.newBuilder()
389         .trustManager(caCert)
390         .build();
391     InternalChannelz.Tls tls = expectSuccessfulHandshake(channelCreds, serverCreds);
392     assertThat(tls.remoteCert).isNull();
393   }
394 
395   @Test
from_tls_clientAuthNone_clientCert()396   public void from_tls_clientAuthNone_clientCert() throws Exception {
397     ServerCredentials serverCreds = TlsServerCredentials.newBuilder()
398         .keyManager(server1Cert, server1Key)
399         .trustManager(caCert)
400         .build();
401     ChannelCredentials channelCreds = TlsChannelCredentials.newBuilder()
402         .keyManager(server1Cert, server1Key)
403         .trustManager(caCert)
404         .build();
405     InternalChannelz.Tls tls = expectSuccessfulHandshake(channelCreds, serverCreds);
406     assertThat(tls.remoteCert).isNull();
407   }
408 
409   @Test
from_tls_clientAuthRequire_noClientCert()410   public void from_tls_clientAuthRequire_noClientCert() throws Exception {
411     ServerCredentials serverCreds = TlsServerCredentials.newBuilder()
412         .keyManager(server1Cert, server1Key)
413         .trustManager(caCert)
414         .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE)
415         .build();
416     ChannelCredentials channelCreds = TlsChannelCredentials.newBuilder()
417         .trustManager(caCert)
418         .build();
419     Status status = expectFailedHandshake(channelCreds, serverCreds);
420     assertEquals(Status.Code.UNAVAILABLE, status.getCode());
421     StatusException sre = status.asException();
422     // because of netty/netty#11604 we need to check for both TLSv1.2 and v1.3 behaviors
423     if (sre.getCause() instanceof SSLHandshakeException) {
424       assertThat(sre).hasCauseThat().isInstanceOf(SSLHandshakeException.class);
425       assertThat(sre).hasCauseThat().hasMessageThat().contains("SSLV3_ALERT_HANDSHAKE_FAILURE");
426     } else {
427       // Client cert verification is after handshake in TLSv1.3
428       assertThat(sre).hasCauseThat().hasCauseThat().isInstanceOf(SSLException.class);
429       assertThat(sre).hasCauseThat().hasMessageThat().contains("CERTIFICATE_REQUIRED");
430     }
431   }
432 
433   @Test
from_tls_clientAuthRequire_clientCert()434   public void from_tls_clientAuthRequire_clientCert() throws Exception {
435     ServerCredentials serverCreds = TlsServerCredentials.newBuilder()
436         .keyManager(server1Cert, server1Key)
437         .trustManager(caCert)
438         .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE)
439         .build();
440     ChannelCredentials channelCreds = TlsChannelCredentials.newBuilder()
441         .keyManager(server1Cert, server1Key)
442         .trustManager(caCert)
443         .build();
444     InternalChannelz.Tls tls = expectSuccessfulHandshake(channelCreds, serverCreds);
445     assertThat(((X509Certificate) tls.remoteCert).getSubjectX500Principal().getName())
446         .contains("CN=*.test.google.com");
447   }
448 
449   @Test
from_tls_clientAuthOptional_noClientCert()450   public void from_tls_clientAuthOptional_noClientCert() throws Exception {
451     ServerCredentials serverCreds = TlsServerCredentials.newBuilder()
452         .keyManager(server1Cert, server1Key)
453         .trustManager(caCert)
454         .clientAuth(TlsServerCredentials.ClientAuth.OPTIONAL)
455         .build();
456     ChannelCredentials channelCreds = TlsChannelCredentials.newBuilder()
457         .trustManager(caCert)
458         .build();
459     InternalChannelz.Tls tls = expectSuccessfulHandshake(channelCreds, serverCreds);
460     assertThat(tls.remoteCert).isNull();
461   }
462 
463   @Test
from_tls_clientAuthOptional_clientCert()464   public void from_tls_clientAuthOptional_clientCert() throws Exception {
465     ServerCredentials serverCreds = TlsServerCredentials.newBuilder()
466         .keyManager(server1Cert, server1Key)
467         .trustManager(caCert)
468         .clientAuth(TlsServerCredentials.ClientAuth.OPTIONAL)
469         .build();
470     ChannelCredentials channelCreds = TlsChannelCredentials.newBuilder()
471         .keyManager(server1Cert, server1Key)
472         .trustManager(caCert)
473         .build();
474     InternalChannelz.Tls tls = expectSuccessfulHandshake(channelCreds, serverCreds);
475     assertThat(((X509Certificate) tls.remoteCert).getSubjectX500Principal().getName())
476         .contains("CN=*.test.google.com");
477   }
478 
479   @Test
from_tls_managers()480   public void from_tls_managers() throws Exception {
481     SelfSignedCertificate cert = new SelfSignedCertificate(TestUtils.TEST_SERVER_HOST);
482     KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
483     keyStore.load(null);
484     keyStore.setKeyEntry("mykey", cert.key(), new char[0], new Certificate[] {cert.cert()});
485     KeyManagerFactory keyManagerFactory =
486         KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
487     keyManagerFactory.init(keyStore, new char[0]);
488 
489     KeyStore certStore = KeyStore.getInstance(KeyStore.getDefaultType());
490     certStore.load(null);
491     certStore.setCertificateEntry("mycert", cert.cert());
492     TrustManagerFactory trustManagerFactory =
493         TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
494     trustManagerFactory.init(certStore);
495 
496     ServerCredentials serverCreds = TlsServerCredentials.newBuilder()
497         .keyManager(keyManagerFactory.getKeyManagers())
498         .trustManager(trustManagerFactory.getTrustManagers())
499         .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE)
500         .build();
501     ChannelCredentials channelCreds = TlsChannelCredentials.newBuilder()
502         .keyManager(keyManagerFactory.getKeyManagers())
503         .trustManager(trustManagerFactory.getTrustManagers())
504         .build();
505     InternalChannelz.Tls tls = expectSuccessfulHandshake(channelCreds, serverCreds);
506     assertThat(((X509Certificate) tls.remoteCert).getSubjectX500Principal().getName())
507         .isEqualTo("CN=" + TestUtils.TEST_SERVER_HOST);
508     cert.delete();
509   }
510 
511   @Test
fromServer_unknown()512   public void fromServer_unknown() {
513     ProtocolNegotiators.FromServerCredentialsResult result =
514         ProtocolNegotiators.from(new ServerCredentials() {});
515     assertThat(result.error).isNotNull();
516     assertThat(result.negotiator).isNull();
517   }
518 
519   @Test
fromServer_tls()520   public void fromServer_tls() throws Exception {
521     ProtocolNegotiators.FromServerCredentialsResult result =
522         ProtocolNegotiators.from(TlsServerCredentials.create(server1Cert, server1Key));
523     assertThat(result.error).isNull();
524     assertThat(result.negotiator)
525         .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorServerFactory.class);
526   }
527 
528   @Test
fromServer_unsupportedTls()529   public void fromServer_unsupportedTls() throws Exception {
530     ProtocolNegotiators.FromServerCredentialsResult result = ProtocolNegotiators.from(
531         TlsServerCredentials.newBuilder()
532           .keyManager(server1Cert, server1Key)
533           .requireFakeFeature()
534           .build());
535     assertThat(result.error).contains("FAKE");
536     assertThat(result.negotiator).isNull();
537   }
538 
539   @Test
fromServer_insecure()540   public void fromServer_insecure() {
541     ProtocolNegotiators.FromServerCredentialsResult result =
542         ProtocolNegotiators.from(InsecureServerCredentials.create());
543     assertThat(result.error).isNull();
544     assertThat(result.negotiator)
545         .isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorServerFactory.class);
546   }
547 
548   @Test
fromServer_netty()549   public void fromServer_netty() {
550     ProtocolNegotiator.ServerFactory factory = mock(ProtocolNegotiator.ServerFactory.class);
551     ProtocolNegotiators.FromServerCredentialsResult result =
552         ProtocolNegotiators.from(NettyServerCredentials.create(factory));
553     assertThat(result.error).isNull();
554     assertThat(result.negotiator).isSameInstanceAs(factory);
555   }
556 
557   @Test
fromServer_choice()558   public void fromServer_choice() throws Exception {
559     ProtocolNegotiators.FromServerCredentialsResult result =
560         ProtocolNegotiators.from(ChoiceServerCredentials.create(
561           new ServerCredentials() {},
562           TlsServerCredentials.create(server1Cert, server1Key),
563           InsecureServerCredentials.create()));
564     assertThat(result.error).isNull();
565     assertThat(result.negotiator)
566         .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorServerFactory.class);
567 
568     result = ProtocolNegotiators.from(ChoiceServerCredentials.create(
569           InsecureServerCredentials.create(),
570           new ServerCredentials() {},
571           TlsServerCredentials.create(server1Cert, server1Key)));
572     assertThat(result.error).isNull();
573     assertThat(result.negotiator)
574         .isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorServerFactory.class);
575   }
576 
577   @Test
fromServer_choice_unknown()578   public void fromServer_choice_unknown() {
579     ProtocolNegotiators.FromServerCredentialsResult result =
580         ProtocolNegotiators.from(ChoiceServerCredentials.create(
581           new ServerCredentials() {}));
582     assertThat(result.error).isNotNull();
583     assertThat(result.negotiator).isNull();
584   }
585 
586 
587   @Test
waitUntilActiveHandler_handlerAdded()588   public void waitUntilActiveHandler_handlerAdded() throws Exception {
589     final CountDownLatch latch = new CountDownLatch(1);
590 
591     final WaitUntilActiveHandler handler =
592         new WaitUntilActiveHandler(new ChannelHandlerAdapter() {
593           @Override
594           public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
595             assertTrue(ctx.channel().isActive());
596             latch.countDown();
597             super.handlerAdded(ctx);
598           }
599         }, noopLogger);
600 
601     ChannelHandler lateAddingHandler = new ChannelInboundHandlerAdapter() {
602       @Override
603       public void channelActive(ChannelHandlerContext ctx) throws Exception {
604         ctx.pipeline().addLast(handler);
605         ctx.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
606         // do not propagate channelActive().
607       }
608     };
609 
610     LocalAddress addr = new LocalAddress("local");
611     ChannelFuture cf = new Bootstrap()
612         .channel(LocalChannel.class)
613         .handler(lateAddingHandler)
614         .group(group)
615         .register();
616     chan = cf.channel();
617     ChannelFuture sf = new ServerBootstrap()
618         .channel(LocalServerChannel.class)
619         .childHandler(new ChannelHandlerAdapter() {})
620         .group(group)
621         .bind(addr);
622     server = sf.channel();
623     sf.sync();
624 
625     assertEquals(1, latch.getCount());
626 
627     chan.connect(addr).sync();
628     assertTrue(latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS));
629     assertNull(chan.pipeline().context(WaitUntilActiveHandler.class));
630   }
631 
632   @Test
waitUntilActiveHandler_channelActive()633   public void waitUntilActiveHandler_channelActive() throws Exception {
634     final CountDownLatch latch = new CountDownLatch(1);
635     WaitUntilActiveHandler handler =
636         new WaitUntilActiveHandler(new ChannelHandlerAdapter() {
637           @Override
638           public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
639             assertTrue(ctx.channel().isActive());
640             latch.countDown();
641             super.handlerAdded(ctx);
642           }
643         }, noopLogger);
644 
645     LocalAddress addr = new LocalAddress("local");
646     ChannelFuture cf = new Bootstrap()
647         .channel(LocalChannel.class)
648         .handler(handler)
649         .group(group)
650         .register();
651     chan = cf.channel();
652     ChannelFuture sf = new ServerBootstrap()
653         .channel(LocalServerChannel.class)
654         .childHandler(new ChannelHandlerAdapter() {})
655         .group(group)
656         .bind(addr);
657     server = sf.channel();
658     sf.sync();
659 
660     assertEquals(1, latch.getCount());
661 
662     chan.connect(addr).sync();
663     chan.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
664     assertTrue(latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS));
665     assertNull(chan.pipeline().context(WaitUntilActiveHandler.class));
666   }
667 
668   @Test
tlsHandler_failsOnNullEngine()669   public void tlsHandler_failsOnNullEngine() throws Exception {
670     thrown.expect(NullPointerException.class);
671     thrown.expectMessage("ssl");
672 
673     Object unused = ProtocolNegotiators.serverTls(null);
674   }
675 
676 
677   @Test
tlsHandler_handlerAddedAddsSslHandler()678   public void tlsHandler_handlerAddedAddsSslHandler() throws Exception {
679     ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null);
680 
681     pipeline.addLast(handler);
682 
683     assertTrue(pipeline.first() instanceof SslHandler);
684   }
685 
686   @Test
tlsHandler_userEventTriggeredNonSslEvent()687   public void tlsHandler_userEventTriggeredNonSslEvent() throws Exception {
688     ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null);
689     pipeline.addLast(handler);
690     channelHandlerCtx = pipeline.context(handler);
691     Object nonSslEvent = new Object();
692 
693     pipeline.fireUserEventTriggered(nonSslEvent);
694 
695     // A non ssl event should not cause the grpcHandler to be in the pipeline yet.
696     ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
697     assertNull(grpcHandlerCtx);
698   }
699 
700   @Test
tlsHandler_userEventTriggeredSslEvent_unsupportedProtocol()701   public void tlsHandler_userEventTriggeredSslEvent_unsupportedProtocol() throws Exception {
702     SslHandler badSslHandler = new SslHandler(engine, false) {
703       @Override
704       public String applicationProtocol() {
705         return "badprotocol";
706       }
707     };
708 
709     ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null);
710     pipeline.addLast(handler);
711 
712     final AtomicReference<Throwable> error = new AtomicReference<>();
713     ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() {
714       @Override
715       public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
716         error.set(cause);
717       }
718     };
719 
720     pipeline.addLast(errorCapture);
721 
722     pipeline.replace(SslHandler.class, null, badSslHandler);
723     channelHandlerCtx = pipeline.context(handler);
724     Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
725 
726     pipeline.fireUserEventTriggered(sslEvent);
727 
728     // No h2 protocol was specified, so there should be an error, (normally handled by WBAEH)
729     assertThat(error.get()).hasMessageThat().contains("Unable to find compatible protocol");
730     ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
731     assertNull(grpcHandlerCtx);
732   }
733 
734   @Test
tlsHandler_userEventTriggeredSslEvent_handshakeFailure()735   public void tlsHandler_userEventTriggeredSslEvent_handshakeFailure() throws Exception {
736     ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null);
737     pipeline.addLast(handler);
738     channelHandlerCtx = pipeline.context(handler);
739     Object sslEvent = new SslHandshakeCompletionEvent(new RuntimeException("bad"));
740 
741     final AtomicReference<Throwable> error = new AtomicReference<>();
742     ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() {
743       @Override
744       public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
745         error.set(cause);
746       }
747     };
748 
749     pipeline.addLast(errorCapture);
750 
751     pipeline.fireUserEventTriggered(sslEvent);
752 
753     // No h2 protocol was specified, so there should be an error, (normally handled by WBAEH)
754     assertThat(error.get()).hasMessageThat().contains("bad");
755     ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
756     assertNull(grpcHandlerCtx);
757   }
758 
759   @Test
tlsHandler_userEventTriggeredSslEvent_supportedProtocolH2()760   public void tlsHandler_userEventTriggeredSslEvent_supportedProtocolH2() throws Exception {
761     SslHandler goodSslHandler = new SslHandler(engine, false) {
762       @Override
763       public String applicationProtocol() {
764         return "h2";
765       }
766     };
767 
768     ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null);
769     pipeline.addLast(handler);
770 
771     pipeline.replace(SslHandler.class, null, goodSslHandler);
772     channelHandlerCtx = pipeline.context(handler);
773     Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
774 
775     pipeline.fireUserEventTriggered(sslEvent);
776 
777     assertTrue(channel.isOpen());
778     ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
779     assertNotNull(grpcHandlerCtx);
780   }
781 
782   @Test
serverTlsHandler_userEventTriggeredSslEvent_supportedProtocolCustom()783   public void serverTlsHandler_userEventTriggeredSslEvent_supportedProtocolCustom()
784       throws Exception {
785     SslHandler goodSslHandler = new SslHandler(engine, false) {
786       @Override
787       public String applicationProtocol() {
788         return "managed_mtls";
789       }
790     };
791 
792     InputStream serverCert = TlsTesting.loadCert("server1.pem");
793     InputStream key = TlsTesting.loadCert("server1.key");
794     List<String> alpnList = Arrays.asList("managed_mtls", "h2");
795     ApplicationProtocolConfig apn = new ApplicationProtocolConfig(
796         ApplicationProtocolConfig.Protocol.ALPN,
797         ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
798         ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
799         alpnList);
800 
801     sslContext = GrpcSslContexts.forServer(serverCert, key)
802         .applicationProtocolConfig(apn).build();
803 
804     ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null);
805     pipeline.addLast(handler);
806 
807     pipeline.replace(SslHandler.class, null, goodSslHandler);
808     channelHandlerCtx = pipeline.context(handler);
809     Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
810 
811     pipeline.fireUserEventTriggered(sslEvent);
812 
813     assertTrue(channel.isOpen());
814     ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
815     assertNotNull(grpcHandlerCtx);
816   }
817 
818   @Test
serverTlsHandler_userEventTriggeredSslEvent_unsupportedProtocolCustom()819   public void serverTlsHandler_userEventTriggeredSslEvent_unsupportedProtocolCustom()
820       throws Exception {
821     SslHandler badSslHandler = new SslHandler(engine, false) {
822       @Override
823       public String applicationProtocol() {
824         return "badprotocol";
825       }
826     };
827 
828     InputStream serverCert = TlsTesting.loadCert("server1.pem");
829     InputStream key = TlsTesting.loadCert("server1.key");
830     List<String> alpnList = Arrays.asList("managed_mtls", "h2");
831     ApplicationProtocolConfig apn = new ApplicationProtocolConfig(
832         ApplicationProtocolConfig.Protocol.ALPN,
833         ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
834         ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
835         alpnList);
836 
837     sslContext = GrpcSslContexts.forServer(serverCert, key)
838         .applicationProtocolConfig(apn).build();
839     ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null);
840     pipeline.addLast(handler);
841 
842     final AtomicReference<Throwable> error = new AtomicReference<>();
843     ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() {
844       @Override
845       public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
846         error.set(cause);
847       }
848     };
849 
850     pipeline.addLast(errorCapture);
851 
852     pipeline.replace(SslHandler.class, null, badSslHandler);
853     channelHandlerCtx = pipeline.context(handler);
854     Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
855 
856     pipeline.fireUserEventTriggered(sslEvent);
857 
858     // No h2 protocol was specified, so there should be an error, (normally handled by WBAEH)
859     assertThat(error.get()).hasMessageThat().contains("Unable to find compatible protocol");
860     ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
861     assertNull(grpcHandlerCtx);
862   }
863 
864   @Test
clientTlsHandler_userEventTriggeredSslEvent_supportedProtocolH2()865   public void clientTlsHandler_userEventTriggeredSslEvent_supportedProtocolH2() throws Exception {
866     SslHandler goodSslHandler = new SslHandler(engine, false) {
867       @Override
868       public String applicationProtocol() {
869         return "h2";
870       }
871     };
872     DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
873 
874     ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
875         "authority", elg, noopLogger);
876     pipeline.addLast(handler);
877     pipeline.replace(SslHandler.class, null, goodSslHandler);
878     pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
879     channelHandlerCtx = pipeline.context(handler);
880     Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
881 
882     pipeline.fireUserEventTriggered(sslEvent);
883 
884     ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
885     assertNotNull(grpcHandlerCtx);
886   }
887 
888   @Test
clientTlsHandler_userEventTriggeredSslEvent_supportedProtocolCustom()889   public void clientTlsHandler_userEventTriggeredSslEvent_supportedProtocolCustom()
890       throws Exception {
891     SslHandler goodSslHandler = new SslHandler(engine, false) {
892       @Override
893       public String applicationProtocol() {
894         return "managed_mtls";
895       }
896     };
897     DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
898 
899     InputStream clientCert = TlsTesting.loadCert("client.pem");
900     InputStream key = TlsTesting.loadCert("client.key");
901     List<String> alpnList = Arrays.asList("managed_mtls", "h2");
902     ApplicationProtocolConfig apn = new ApplicationProtocolConfig(
903         ApplicationProtocolConfig.Protocol.ALPN,
904         ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
905         ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
906         alpnList);
907 
908     sslContext = GrpcSslContexts.forClient()
909         .keyManager(clientCert, key)
910         .applicationProtocolConfig(apn).build();
911 
912     ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
913         "authority", elg, noopLogger);
914     pipeline.addLast(handler);
915     pipeline.replace(SslHandler.class, null, goodSslHandler);
916     pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
917     channelHandlerCtx = pipeline.context(handler);
918     Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
919 
920     pipeline.fireUserEventTriggered(sslEvent);
921 
922     ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
923     assertNotNull(grpcHandlerCtx);
924   }
925 
926   @Test
clientTlsHandler_userEventTriggeredSslEvent_unsupportedProtocol()927   public void clientTlsHandler_userEventTriggeredSslEvent_unsupportedProtocol() throws Exception {
928     SslHandler goodSslHandler = new SslHandler(engine, false) {
929       @Override
930       public String applicationProtocol() {
931         return "badproto";
932       }
933     };
934     DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
935 
936     ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
937         "authority", elg, noopLogger);
938     pipeline.addLast(handler);
939 
940     final AtomicReference<Throwable> error = new AtomicReference<>();
941     ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() {
942       @Override
943       public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
944         error.set(cause);
945       }
946     };
947 
948     pipeline.addLast(errorCapture);
949     pipeline.replace(SslHandler.class, null, goodSslHandler);
950     pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
951     channelHandlerCtx = pipeline.context(handler);
952     Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
953 
954     pipeline.fireUserEventTriggered(sslEvent);
955 
956     // Bad protocol was specified, so there should be an error, (normally handled by WBAEH)
957     assertThat(error.get()).hasMessageThat().contains("Unable to find compatible protocol");
958     ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler);
959     assertNull(grpcHandlerCtx);
960   }
961 
962   @Test
clientTlsHandler_closeDuringNegotiation()963   public void clientTlsHandler_closeDuringNegotiation() throws Exception {
964     ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
965         "authority", null, noopLogger);
966     pipeline.addLast(new WriteBufferingAndExceptionHandler(handler));
967     ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);
968 
969     // SslHandler fires userEventTriggered() before channelInactive()
970     pipeline.fireChannelInactive();
971 
972     assertThat(pendingWrite.cause()).isInstanceOf(StatusRuntimeException.class);
973     assertThat(Status.fromThrowable(pendingWrite.cause()).getCode())
974         .isEqualTo(Status.Code.UNAVAILABLE);
975   }
976 
977   @Test
engineLog()978   public void engineLog() {
979     ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null);
980     pipeline.addLast(handler);
981     channelHandlerCtx = pipeline.context(handler);
982 
983     Logger logger = Logger.getLogger(ProtocolNegotiators.class.getName());
984     Filter oldFilter = logger.getFilter();
985     try {
986       logger.setFilter(new Filter() {
987         @Override
988         public boolean isLoggable(LogRecord record) {
989           // We still want to the log method to be exercised, just not printed to stderr.
990           return false;
991         }
992       });
993 
994       ProtocolNegotiators.logSslEngineDetails(
995           Level.INFO, channelHandlerCtx, "message", new Exception("bad"));
996     } finally {
997       logger.setFilter(oldFilter);
998     }
999   }
1000 
1001   @Test
tls_failsOnNullSslContext()1002   public void tls_failsOnNullSslContext() {
1003     thrown.expect(NullPointerException.class);
1004 
1005     Object unused = ProtocolNegotiators.tls(null);
1006   }
1007 
1008   @Test
tls_hostAndPort()1009   public void tls_hostAndPort() {
1010     HostPort hostPort = ProtocolNegotiators.parseAuthority("authority:1234");
1011 
1012     assertEquals("authority", hostPort.host);
1013     assertEquals(1234, hostPort.port);
1014   }
1015 
1016   @Test
tls_host()1017   public void tls_host() {
1018     HostPort hostPort = ProtocolNegotiators.parseAuthority("[::1]");
1019 
1020     assertEquals("[::1]", hostPort.host);
1021     assertEquals(-1, hostPort.port);
1022   }
1023 
1024   @Test
tls_invalidHost()1025   public void tls_invalidHost() throws SSLException {
1026     HostPort hostPort = ProtocolNegotiators.parseAuthority("bad_host:1234");
1027 
1028     // Even though it looks like a port, we treat it as part of the authority, since the host is
1029     // invalid.
1030     assertEquals("bad_host:1234", hostPort.host);
1031     assertEquals(-1, hostPort.port);
1032   }
1033 
1034   @Test
httpProxy_nullAddressNpe()1035   public void httpProxy_nullAddressNpe() throws Exception {
1036     thrown.expect(NullPointerException.class);
1037     Object unused =
1038         ProtocolNegotiators.httpProxy(null, "user", "pass", ProtocolNegotiators.plaintext());
1039   }
1040 
1041   @Test
httpProxy_nullNegotiatorNpe()1042   public void httpProxy_nullNegotiatorNpe() throws Exception {
1043     thrown.expect(NullPointerException.class);
1044     Object unused = ProtocolNegotiators.httpProxy(
1045         InetSocketAddress.createUnresolved("localhost", 80), "user", "pass", null);
1046   }
1047 
1048   @Test
httpProxy_nullUserPassNoException()1049   public void httpProxy_nullUserPassNoException() throws Exception {
1050     assertNotNull(ProtocolNegotiators.httpProxy(
1051         InetSocketAddress.createUnresolved("localhost", 80), null, null,
1052         ProtocolNegotiators.plaintext()));
1053   }
1054 
1055   @Test
httpProxy_completes()1056   public void httpProxy_completes() throws Exception {
1057     DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
1058     // ProxyHandler is incompatible with EmbeddedChannel because when channelRegistered() is called
1059     // the channel is already active.
1060     LocalAddress proxy = new LocalAddress("httpProxy_completes");
1061     SocketAddress host = InetSocketAddress.createUnresolved("specialHost", 314);
1062 
1063     ChannelInboundHandler mockHandler = mock(ChannelInboundHandler.class);
1064     Channel serverChannel = new ServerBootstrap().group(elg).channel(LocalServerChannel.class)
1065         .childHandler(mockHandler)
1066         .bind(proxy).sync().channel();
1067 
1068     ProtocolNegotiator nego =
1069         ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext());
1070     // normally NettyClientTransport will add WBAEH which kick start the ProtocolNegotiation,
1071     // mocking the behavior using KickStartHandler.
1072     ChannelHandler handler =
1073         new KickStartHandler(nego.newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler()));
1074     Channel channel = new Bootstrap().group(elg).channel(LocalChannel.class).handler(handler)
1075         .register().sync().channel();
1076     pipeline = channel.pipeline();
1077     // Wait for initialization to complete
1078     channel.eventLoop().submit(NOOP_RUNNABLE).sync();
1079     channel.connect(host).sync();
1080     serverChannel.close();
1081     ArgumentCaptor<ChannelHandlerContext> contextCaptor =
1082         ArgumentCaptor.forClass(ChannelHandlerContext.class);
1083     Mockito.verify(mockHandler).channelActive(contextCaptor.capture());
1084     ChannelHandlerContext serverContext = contextCaptor.getValue();
1085 
1086     final String golden = "isThisThingOn?";
1087     ChannelFuture negotiationFuture = channel.writeAndFlush(bb(golden, channel));
1088 
1089     // Wait for sending initial request to complete
1090     channel.eventLoop().submit(NOOP_RUNNABLE).sync();
1091     ArgumentCaptor<Object> objectCaptor = ArgumentCaptor.forClass(Object.class);
1092     Mockito.verify(mockHandler)
1093         .channelRead(ArgumentMatchers.<ChannelHandlerContext>any(), objectCaptor.capture());
1094     ByteBuf b = (ByteBuf) objectCaptor.getValue();
1095     String request = b.toString(UTF_8);
1096     b.release();
1097     assertTrue("No trailing newline: " + request, request.endsWith("\r\n\r\n"));
1098     assertTrue("No CONNECT: " + request, request.startsWith("CONNECT specialHost:314 "));
1099     assertTrue("No host header: " + request, request.contains("host: specialHost:314"));
1100 
1101     assertFalse(negotiationFuture.isDone());
1102     serverContext.writeAndFlush(bb("HTTP/1.1 200 OK\r\n\r\n", serverContext.channel())).sync();
1103     negotiationFuture.sync();
1104 
1105     channel.eventLoop().submit(NOOP_RUNNABLE).sync();
1106     objectCaptor = ArgumentCaptor.forClass(Object.class);
1107     Mockito.verify(mockHandler, times(2))
1108         .channelRead(ArgumentMatchers.<ChannelHandlerContext>any(), objectCaptor.capture());
1109     b = (ByteBuf) objectCaptor.getAllValues().get(1);
1110     // If we were using the real grpcHandler, this would have been the HTTP/2 preface
1111     String preface = b.toString(UTF_8);
1112     b.release();
1113     assertEquals(golden, preface);
1114 
1115     channel.close();
1116   }
1117 
1118   @Test
httpProxy_500()1119   public void httpProxy_500() throws Exception {
1120     DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
1121     // ProxyHandler is incompatible with EmbeddedChannel because when channelRegistered() is called
1122     // the channel is already active.
1123     LocalAddress proxy = new LocalAddress("httpProxy_500");
1124     SocketAddress host = InetSocketAddress.createUnresolved("specialHost", 314);
1125 
1126     ChannelInboundHandler mockHandler = mock(ChannelInboundHandler.class);
1127     Channel serverChannel = new ServerBootstrap().group(elg).channel(LocalServerChannel.class)
1128         .childHandler(mockHandler)
1129         .bind(proxy).sync().channel();
1130 
1131     ProtocolNegotiator nego =
1132         ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext());
1133     // normally NettyClientTransport will add WBAEH which kick start the ProtocolNegotiation,
1134     // mocking the behavior using KickStartHandler.
1135     ChannelHandler handler =
1136         new KickStartHandler(nego.newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler()));
1137     Channel channel = new Bootstrap().group(elg).channel(LocalChannel.class).handler(handler)
1138         .register().sync().channel();
1139     pipeline = channel.pipeline();
1140     // Wait for initialization to complete
1141     channel.eventLoop().submit(NOOP_RUNNABLE).sync();
1142     channel.connect(host).sync();
1143     serverChannel.close();
1144     ArgumentCaptor<ChannelHandlerContext> contextCaptor =
1145         ArgumentCaptor.forClass(ChannelHandlerContext.class);
1146     Mockito.verify(mockHandler).channelActive(contextCaptor.capture());
1147     ChannelHandlerContext serverContext = contextCaptor.getValue();
1148 
1149     final String golden = "isThisThingOn?";
1150     ChannelFuture negotiationFuture = channel.writeAndFlush(bb(golden, channel));
1151 
1152     // Wait for sending initial request to complete
1153     channel.eventLoop().submit(NOOP_RUNNABLE).sync();
1154     ArgumentCaptor<Object> objectCaptor = ArgumentCaptor.forClass(Object.class);
1155     Mockito.verify(mockHandler)
1156         .channelRead(any(ChannelHandlerContext.class), objectCaptor.capture());
1157     ByteBuf request = (ByteBuf) objectCaptor.getValue();
1158     request.release();
1159 
1160     assertFalse(negotiationFuture.isDone());
1161     String response = "HTTP/1.1 500 OMG\r\nContent-Length: 4\r\n\r\noops";
1162     serverContext.writeAndFlush(bb(response, serverContext.channel())).sync();
1163     thrown.expect(ProxyConnectException.class);
1164     try {
1165       negotiationFuture.sync();
1166     } finally {
1167       channel.close();
1168     }
1169   }
1170 
1171   @Test
waitUntilActiveHandler_firesNegotiation()1172   public void waitUntilActiveHandler_firesNegotiation() throws Exception {
1173     EventLoopGroup elg = new DefaultEventLoopGroup(1);
1174     SocketAddress addr = new LocalAddress("addr");
1175     final AtomicReference<Object> event = new AtomicReference<>();
1176     ChannelHandler next = new ChannelInboundHandlerAdapter() {
1177       @Override
1178       public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
1179         event.set(evt);
1180         ctx.close();
1181       }
1182     };
1183     Channel s = new ServerBootstrap()
1184         .childHandler(new ChannelInboundHandlerAdapter())
1185         .group(elg)
1186         .channel(LocalServerChannel.class)
1187         .bind(addr)
1188         .sync()
1189         .channel();
1190     Channel c = new Bootstrap()
1191         .handler(new WaitUntilActiveHandler(next, noopLogger))
1192         .channel(LocalChannel.class).group(group)
1193         .connect(addr)
1194         .sync()
1195         .channel();
1196     c.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
1197     SocketAddress localAddr = c.localAddress();
1198     ProtocolNegotiationEvent expectedEvent = ProtocolNegotiationEvent.DEFAULT
1199         .withAttributes(
1200             Attributes.newBuilder()
1201                 .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddr)
1202                 .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, addr)
1203                 .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.NONE)
1204                 .build());
1205 
1206     c.closeFuture().sync();
1207     assertThat(event.get()).isInstanceOf(ProtocolNegotiationEvent.class);
1208     ProtocolNegotiationEvent actual = (ProtocolNegotiationEvent) event.get();
1209     assertThat(actual).isEqualTo(expectedEvent);
1210 
1211     s.close();
1212     elg.shutdownGracefully();
1213   }
1214 
1215   @Test
clientTlsHandler_firesNegotiation()1216   public void clientTlsHandler_firesNegotiation() throws Exception {
1217     SelfSignedCertificate cert = new SelfSignedCertificate("authority");
1218     SslContext clientSslContext =
1219         GrpcSslContexts.configure(SslContextBuilder.forClient().trustManager(cert.cert())).build();
1220     SslContext serverSslContext =
1221         GrpcSslContexts.configure(SslContextBuilder.forServer(cert.key(), cert.cert())).build();
1222     FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();
1223     ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, null);
1224     WriteBufferingAndExceptionHandler clientWbaeh =
1225         new WriteBufferingAndExceptionHandler(pn.newHandler(gh));
1226 
1227     SocketAddress addr = new LocalAddress("addr");
1228 
1229     ChannelHandler sh =
1230         ProtocolNegotiators.serverTls(serverSslContext)
1231             .newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler());
1232     WriteBufferingAndExceptionHandler serverWbaeh = new WriteBufferingAndExceptionHandler(sh);
1233     Channel s = new ServerBootstrap()
1234         .childHandler(serverWbaeh)
1235         .group(group)
1236         .channel(LocalServerChannel.class)
1237         .bind(addr)
1238         .sync()
1239         .channel();
1240     Channel c = new Bootstrap()
1241         .handler(clientWbaeh)
1242         .channel(LocalChannel.class)
1243         .group(group)
1244         .register()
1245         .sync()
1246         .channel();
1247     ChannelFuture write = c.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);
1248     c.connect(addr).sync();
1249     write.sync();
1250 
1251     boolean completed = gh.negotiated.await(TIMEOUT_SECONDS, TimeUnit.SECONDS);
1252     if (!completed) {
1253       assertTrue("failed to negotiated", write.await(TIMEOUT_SECONDS, TimeUnit.SECONDS));
1254       // sync should fail if we are in this block.
1255       write.sync();
1256       throw new AssertionError("neither wrote nor negotiated");
1257     }
1258     c.close();
1259     s.close();
1260     pn.close();
1261 
1262     assertThat(gh.securityInfo).isNotNull();
1263     assertThat(gh.securityInfo.tls).isNotNull();
1264     assertThat(gh.attrs.get(GrpcAttributes.ATTR_SECURITY_LEVEL))
1265         .isEqualTo(SecurityLevel.PRIVACY_AND_INTEGRITY);
1266     assertThat(gh.attrs.get(Grpc.TRANSPORT_ATTR_SSL_SESSION)).isInstanceOf(SSLSession.class);
1267     // This is not part of the ClientTls negotiation, but shows that the negotiation event happens
1268     // in the right order.
1269     assertThat(gh.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)).isEqualTo(addr);
1270   }
1271 
1272   @Test
plaintextUpgradeNegotiator()1273   public void plaintextUpgradeNegotiator() throws Exception {
1274     LocalAddress addr = new LocalAddress("plaintextUpgradeNegotiator");
1275     UpgradeCodecFactory ucf = new UpgradeCodecFactory() {
1276 
1277       @Override
1278       public UpgradeCodec newUpgradeCodec(CharSequence protocol) {
1279         return new Http2ServerUpgradeCodec(FakeGrpcHttp2ConnectionHandler.newHandler());
1280       }
1281     };
1282     final HttpServerCodec serverCodec = new HttpServerCodec();
1283     final HttpServerUpgradeHandler serverUpgradeHandler =
1284         new HttpServerUpgradeHandler(serverCodec, ucf);
1285     Channel serverChannel = new ServerBootstrap()
1286         .group(group)
1287         .channel(LocalServerChannel.class)
1288         .childHandler(new ChannelInitializer<Channel>() {
1289 
1290           @Override
1291           protected void initChannel(Channel ch) throws Exception {
1292             ch.pipeline().addLast(serverCodec, serverUpgradeHandler);
1293           }
1294         })
1295         .bind(addr)
1296         .sync()
1297         .channel();
1298 
1299     FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();
1300     ProtocolNegotiator nego = ProtocolNegotiators.plaintextUpgrade();
1301     ChannelHandler ch = nego.newHandler(gh);
1302     WriteBufferingAndExceptionHandler wbaeh = new WriteBufferingAndExceptionHandler(ch);
1303 
1304     Channel channel = new Bootstrap()
1305         .group(group)
1306         .channel(LocalChannel.class)
1307         .handler(wbaeh)
1308         .register()
1309         .sync()
1310         .channel();
1311 
1312     ChannelFuture write = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);
1313     channel.connect(serverChannel.localAddress());
1314 
1315     boolean completed = gh.negotiated.await(TIMEOUT_SECONDS, TimeUnit.SECONDS);
1316     if (!completed) {
1317       assertTrue("failed to negotiated", write.await(TIMEOUT_SECONDS, TimeUnit.SECONDS));
1318       // sync should fail if we are in this block.
1319       write.sync();
1320       throw new AssertionError("neither wrote nor negotiated");
1321     }
1322 
1323     channel.close().sync();
1324     serverChannel.close();
1325 
1326     assertThat(gh.securityInfo).isNull();
1327     assertThat(gh.attrs.get(GrpcAttributes.ATTR_SECURITY_LEVEL)).isEqualTo(SecurityLevel.NONE);
1328     assertThat(gh.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)).isEqualTo(addr);
1329   }
1330 
callMeMaybe(Runnable runnable)1331   private static void callMeMaybe(Runnable runnable) {
1332     if (runnable != null) {
1333       runnable.run();
1334     }
1335   }
1336 
1337   private static class FakeGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {
1338 
noopHandler()1339     static FakeGrpcHttp2ConnectionHandler noopHandler() {
1340       return newHandler(true);
1341     }
1342 
newHandler()1343     static FakeGrpcHttp2ConnectionHandler newHandler() {
1344       return newHandler(false);
1345     }
1346 
newHandler(boolean noop)1347     private static FakeGrpcHttp2ConnectionHandler newHandler(boolean noop) {
1348       DefaultHttp2Connection conn = new DefaultHttp2Connection(/*server=*/ false);
1349       DefaultHttp2ConnectionEncoder encoder =
1350           new DefaultHttp2ConnectionEncoder(conn, new DefaultHttp2FrameWriter());
1351       DefaultHttp2ConnectionDecoder decoder =
1352           new DefaultHttp2ConnectionDecoder(conn, encoder, new DefaultHttp2FrameReader());
1353       Http2Settings settings = new Http2Settings();
1354       return new FakeGrpcHttp2ConnectionHandler(
1355           /*channelUnused=*/ null, decoder, encoder, settings, noop, noopLogger);
1356     }
1357 
1358     private final boolean noop;
1359     private Attributes attrs;
1360     private Security securityInfo;
1361     private final CountDownLatch negotiated = new CountDownLatch(1);
1362     private ChannelHandlerContext ctx;
1363 
FakeGrpcHttp2ConnectionHandler(ChannelPromise channelUnused, Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder, Http2Settings initialSettings, boolean noop, ChannelLogger negotiationLogger)1364     FakeGrpcHttp2ConnectionHandler(ChannelPromise channelUnused,
1365         Http2ConnectionDecoder decoder,
1366         Http2ConnectionEncoder encoder,
1367         Http2Settings initialSettings,
1368         boolean noop,
1369         ChannelLogger negotiationLogger) {
1370       super(channelUnused, decoder, encoder, initialSettings, negotiationLogger);
1371       this.noop = noop;
1372     }
1373 
1374     @Override
handleProtocolNegotiationCompleted(Attributes attrs, Security securityInfo)1375     public void handleProtocolNegotiationCompleted(Attributes attrs, Security securityInfo) {
1376       checkNotNull(ctx, "handleProtocolNegotiationCompleted cannot be called before handlerAdded");
1377       super.handleProtocolNegotiationCompleted(attrs, securityInfo);
1378       this.attrs = attrs;
1379       this.securityInfo = securityInfo;
1380       // Add a temp handler that verifies first message is a NOOP_MESSAGE
1381       ctx.pipeline().addBefore(ctx.name(), null, new ChannelOutboundHandlerAdapter() {
1382         @Override
1383         public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
1384             throws Exception {
1385           checkState(
1386               msg == NettyClientHandler.NOOP_MESSAGE, "First message should be NOOP_MESSAGE");
1387           promise.trySuccess();
1388           ctx.pipeline().remove(this);
1389         }
1390       });
1391       NettyClientHandler.writeBufferingAndRemove(ctx.channel());
1392       negotiated.countDown();
1393     }
1394 
1395     @Override
handlerAdded(ChannelHandlerContext ctx)1396     public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
1397       if (noop) {
1398         ctx.pipeline().remove(ctx.name());
1399       } else {
1400         super.handlerAdded(ctx);
1401       }
1402       this.ctx = ctx;
1403     }
1404 
1405     @Override
getAuthority()1406     public String getAuthority() {
1407       return "authority";
1408     }
1409   }
1410 
bb(String s, Channel c)1411   private static ByteBuf bb(String s, Channel c) {
1412     return ByteBufUtil.writeUtf8(c.alloc(), s);
1413   }
1414 
1415   private static final class KickStartHandler extends ChannelDuplexHandler {
1416 
1417     private final ChannelHandler next;
1418 
KickStartHandler(ChannelHandler next)1419     public KickStartHandler(ChannelHandler next) {
1420       this.next = checkNotNull(next, "next");
1421     }
1422 
1423     @Override
handlerAdded(ChannelHandlerContext ctx)1424     public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
1425       ctx.pipeline().replace(ctx.name(), null, next);
1426       ctx.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
1427     }
1428   }
1429 
1430   private static class MockServerListener implements ServerListener {
1431     private final CountDownLatch latch = new CountDownLatch(1);
1432     public Queue<ServerTransport> transports = new ArrayDeque<>();
1433 
1434     @Override
transportCreated(ServerTransport transport)1435     public ServerTransportListener transportCreated(ServerTransport transport) {
1436       transports.add(transport);
1437       return new MockServerTransportListener();
1438     }
1439 
1440     @Override
serverShutdown()1441     public void serverShutdown() {
1442       latch.countDown();
1443     }
1444 
waitForShutdown(long timeout, TimeUnit unit)1445     public boolean waitForShutdown(long timeout, TimeUnit unit) throws InterruptedException {
1446       return latch.await(timeout, unit);
1447     }
1448   }
1449 
1450   private static class MockServerTransportListener implements ServerTransportListener {
1451     @Override
streamCreated(ServerStream stream, String method, Metadata headers)1452     public void streamCreated(ServerStream stream, String method, Metadata headers) {}
1453 
1454     @Override
transportReady(Attributes attributes)1455     public Attributes transportReady(Attributes attributes) {
1456       return attributes;
1457     }
1458 
1459     @Override
transportTerminated()1460     public void transportTerminated() {}
1461   }
1462 }
1463