1 /* 2 * Copyright 2015 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.netty; 18 19 import static com.google.common.base.Charsets.UTF_8; 20 import static com.google.common.base.Preconditions.checkNotNull; 21 import static com.google.common.base.Preconditions.checkState; 22 import static com.google.common.truth.Truth.assertThat; 23 import static org.junit.Assert.assertEquals; 24 import static org.junit.Assert.assertFalse; 25 import static org.junit.Assert.assertNotNull; 26 import static org.junit.Assert.assertNull; 27 import static org.junit.Assert.assertTrue; 28 import static org.mockito.ArgumentMatchers.any; 29 import static org.mockito.Mockito.mock; 30 import static org.mockito.Mockito.timeout; 31 import static org.mockito.Mockito.times; 32 import static org.mockito.Mockito.verify; 33 34 import io.grpc.Attributes; 35 import io.grpc.CallCredentials; 36 import io.grpc.ChannelCredentials; 37 import io.grpc.ChannelLogger; 38 import io.grpc.ChoiceChannelCredentials; 39 import io.grpc.ChoiceServerCredentials; 40 import io.grpc.CompositeChannelCredentials; 41 import io.grpc.Grpc; 42 import io.grpc.InsecureChannelCredentials; 43 import io.grpc.InsecureServerCredentials; 44 import io.grpc.InternalChannelz; 45 import io.grpc.InternalChannelz.Security; 46 import io.grpc.Metadata; 47 import io.grpc.SecurityLevel; 48 import io.grpc.ServerCredentials; 49 import io.grpc.ServerStreamTracer; 50 import io.grpc.Status; 51 import io.grpc.StatusException; 52 import io.grpc.StatusRuntimeException; 53 import io.grpc.TlsChannelCredentials; 54 import io.grpc.TlsServerCredentials; 55 import io.grpc.internal.ClientTransportFactory; 56 import io.grpc.internal.GrpcAttributes; 57 import io.grpc.internal.InternalServer; 58 import io.grpc.internal.ManagedClientTransport; 59 import io.grpc.internal.ServerListener; 60 import io.grpc.internal.ServerStream; 61 import io.grpc.internal.ServerTransport; 62 import io.grpc.internal.ServerTransportListener; 63 import io.grpc.internal.TestUtils.NoopChannelLogger; 64 import io.grpc.internal.testing.TestUtils; 65 import io.grpc.netty.ProtocolNegotiators.ClientTlsHandler; 66 import io.grpc.netty.ProtocolNegotiators.ClientTlsProtocolNegotiator; 67 import io.grpc.netty.ProtocolNegotiators.HostPort; 68 import io.grpc.netty.ProtocolNegotiators.ServerTlsHandler; 69 import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler; 70 import io.grpc.testing.TlsTesting; 71 import io.netty.bootstrap.Bootstrap; 72 import io.netty.bootstrap.ServerBootstrap; 73 import io.netty.buffer.ByteBuf; 74 import io.netty.buffer.ByteBufUtil; 75 import io.netty.channel.Channel; 76 import io.netty.channel.ChannelDuplexHandler; 77 import io.netty.channel.ChannelFuture; 78 import io.netty.channel.ChannelHandler; 79 import io.netty.channel.ChannelHandlerAdapter; 80 import io.netty.channel.ChannelHandlerContext; 81 import io.netty.channel.ChannelInboundHandler; 82 import io.netty.channel.ChannelInboundHandlerAdapter; 83 import io.netty.channel.ChannelInitializer; 84 import io.netty.channel.ChannelOutboundHandlerAdapter; 85 import io.netty.channel.ChannelPipeline; 86 import io.netty.channel.ChannelPromise; 87 import io.netty.channel.DefaultEventLoop; 88 import io.netty.channel.DefaultEventLoopGroup; 89 import io.netty.channel.EventLoopGroup; 90 import io.netty.channel.embedded.EmbeddedChannel; 91 import io.netty.channel.local.LocalAddress; 92 import io.netty.channel.local.LocalChannel; 93 import io.netty.channel.local.LocalServerChannel; 94 import io.netty.handler.codec.http.HttpServerCodec; 95 import io.netty.handler.codec.http.HttpServerUpgradeHandler; 96 import io.netty.handler.codec.http.HttpServerUpgradeHandler.UpgradeCodec; 97 import io.netty.handler.codec.http.HttpServerUpgradeHandler.UpgradeCodecFactory; 98 import io.netty.handler.codec.http2.DefaultHttp2Connection; 99 import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder; 100 import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder; 101 import io.netty.handler.codec.http2.DefaultHttp2FrameReader; 102 import io.netty.handler.codec.http2.DefaultHttp2FrameWriter; 103 import io.netty.handler.codec.http2.Http2ConnectionDecoder; 104 import io.netty.handler.codec.http2.Http2ConnectionEncoder; 105 import io.netty.handler.codec.http2.Http2ServerUpgradeCodec; 106 import io.netty.handler.codec.http2.Http2Settings; 107 import io.netty.handler.proxy.ProxyConnectException; 108 import io.netty.handler.ssl.ApplicationProtocolConfig; 109 import io.netty.handler.ssl.SslContext; 110 import io.netty.handler.ssl.SslContextBuilder; 111 import io.netty.handler.ssl.SslHandler; 112 import io.netty.handler.ssl.SslHandshakeCompletionEvent; 113 import io.netty.handler.ssl.util.SelfSignedCertificate; 114 import java.io.File; 115 import java.io.InputStream; 116 import java.net.InetSocketAddress; 117 import java.net.SocketAddress; 118 import java.security.KeyStore; 119 import java.security.cert.Certificate; 120 import java.security.cert.X509Certificate; 121 import java.util.ArrayDeque; 122 import java.util.Arrays; 123 import java.util.Collections; 124 import java.util.List; 125 import java.util.Queue; 126 import java.util.concurrent.CountDownLatch; 127 import java.util.concurrent.TimeUnit; 128 import java.util.concurrent.atomic.AtomicReference; 129 import java.util.logging.Filter; 130 import java.util.logging.Level; 131 import java.util.logging.LogRecord; 132 import java.util.logging.Logger; 133 import javax.net.ssl.KeyManagerFactory; 134 import javax.net.ssl.SSLContext; 135 import javax.net.ssl.SSLEngine; 136 import javax.net.ssl.SSLException; 137 import javax.net.ssl.SSLHandshakeException; 138 import javax.net.ssl.SSLSession; 139 import javax.net.ssl.TrustManagerFactory; 140 import org.junit.After; 141 import org.junit.Before; 142 import org.junit.BeforeClass; 143 import org.junit.Rule; 144 import org.junit.Test; 145 import org.junit.rules.DisableOnDebug; 146 import org.junit.rules.ExpectedException; 147 import org.junit.rules.TestRule; 148 import org.junit.rules.Timeout; 149 import org.junit.runner.RunWith; 150 import org.junit.runners.JUnit4; 151 import org.mockito.ArgumentCaptor; 152 import org.mockito.ArgumentMatchers; 153 import org.mockito.Mockito; 154 155 @RunWith(JUnit4.class) 156 public class ProtocolNegotiatorsTest { 157 private static final Runnable NOOP_RUNNABLE = new Runnable() { 158 @Override public void run() {} 159 }; 160 161 private static File server1Cert; 162 private static File server1Key; 163 private static File caCert; 164 165 @BeforeClass loadCerts()166 public static void loadCerts() throws Exception { 167 server1Cert = TestUtils.loadCert("server1.pem"); 168 server1Key = TestUtils.loadCert("server1.key"); 169 caCert = TestUtils.loadCert("ca.pem"); 170 } 171 172 private static final int TIMEOUT_SECONDS = 60; 173 @Rule public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(TIMEOUT_SECONDS)); 174 @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 175 @Rule public final ExpectedException thrown = ExpectedException.none(); 176 177 private final EventLoopGroup group = new DefaultEventLoop(); 178 private Channel chan; 179 private Channel server; 180 181 private final GrpcHttp2ConnectionHandler grpcHandler = 182 FakeGrpcHttp2ConnectionHandler.newHandler(); 183 184 private EmbeddedChannel channel = new EmbeddedChannel(); 185 private ChannelPipeline pipeline = channel.pipeline(); 186 private SslContext sslContext; 187 private SSLEngine engine; 188 private ChannelHandlerContext channelHandlerCtx; 189 private static ChannelLogger noopLogger = new NoopChannelLogger(); 190 191 @Before setUp()192 public void setUp() throws Exception { 193 InputStream serverCert = TlsTesting.loadCert("server1.pem"); 194 InputStream key = TlsTesting.loadCert("server1.key"); 195 sslContext = GrpcSslContexts.forServer(serverCert, key).build(); 196 engine = SSLContext.getDefault().createSSLEngine(); 197 engine.setUseClientMode(true); 198 } 199 200 @After tearDown()201 public void tearDown() { 202 if (server != null) { 203 server.close(); 204 } 205 if (chan != null) { 206 chan.close(); 207 } 208 group.shutdownGracefully(); 209 } 210 211 @Test fromClient_unknown()212 public void fromClient_unknown() { 213 ProtocolNegotiators.FromChannelCredentialsResult result = 214 ProtocolNegotiators.from(new ChannelCredentials() { 215 @Override 216 public ChannelCredentials withoutBearerTokens() { 217 throw new UnsupportedOperationException(); 218 } 219 }); 220 assertThat(result.error).isNotNull(); 221 assertThat(result.callCredentials).isNull(); 222 assertThat(result.negotiator).isNull(); 223 } 224 225 @Test fromClient_tls()226 public void fromClient_tls() { 227 ProtocolNegotiators.FromChannelCredentialsResult result = 228 ProtocolNegotiators.from(TlsChannelCredentials.create()); 229 assertThat(result.error).isNull(); 230 assertThat(result.callCredentials).isNull(); 231 assertThat(result.negotiator) 232 .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class); 233 } 234 235 @Test fromClient_unsupportedTls()236 public void fromClient_unsupportedTls() { 237 ProtocolNegotiators.FromChannelCredentialsResult result = 238 ProtocolNegotiators.from(TlsChannelCredentials.newBuilder().requireFakeFeature().build()); 239 assertThat(result.error).contains("FAKE"); 240 assertThat(result.callCredentials).isNull(); 241 assertThat(result.negotiator).isNull(); 242 } 243 244 @Test fromClient_insecure()245 public void fromClient_insecure() { 246 ProtocolNegotiators.FromChannelCredentialsResult result = 247 ProtocolNegotiators.from(InsecureChannelCredentials.create()); 248 assertThat(result.error).isNull(); 249 assertThat(result.callCredentials).isNull(); 250 assertThat(result.negotiator) 251 .isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorClientFactory.class); 252 } 253 254 @Test fromClient_composite()255 public void fromClient_composite() { 256 CallCredentials callCredentials = mock(CallCredentials.class); 257 ProtocolNegotiators.FromChannelCredentialsResult result = 258 ProtocolNegotiators.from(CompositeChannelCredentials.create( 259 TlsChannelCredentials.create(), callCredentials)); 260 assertThat(result.error).isNull(); 261 assertThat(result.callCredentials).isSameInstanceAs(callCredentials); 262 assertThat(result.negotiator) 263 .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class); 264 265 result = ProtocolNegotiators.from(CompositeChannelCredentials.create( 266 InsecureChannelCredentials.create(), callCredentials)); 267 assertThat(result.error).isNull(); 268 assertThat(result.callCredentials).isSameInstanceAs(callCredentials); 269 assertThat(result.negotiator) 270 .isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorClientFactory.class); 271 } 272 273 @Test fromClient_netty()274 public void fromClient_netty() { 275 ProtocolNegotiator.ClientFactory factory = mock(ProtocolNegotiator.ClientFactory.class); 276 ProtocolNegotiators.FromChannelCredentialsResult result = 277 ProtocolNegotiators.from(NettyChannelCredentials.create(factory)); 278 assertThat(result.error).isNull(); 279 assertThat(result.callCredentials).isNull(); 280 assertThat(result.negotiator).isSameInstanceAs(factory); 281 } 282 283 @Test fromClient_choice()284 public void fromClient_choice() { 285 ProtocolNegotiators.FromChannelCredentialsResult result = 286 ProtocolNegotiators.from(ChoiceChannelCredentials.create( 287 new ChannelCredentials() { 288 @Override 289 public ChannelCredentials withoutBearerTokens() { 290 throw new UnsupportedOperationException(); 291 } 292 }, 293 TlsChannelCredentials.create(), 294 InsecureChannelCredentials.create())); 295 assertThat(result.error).isNull(); 296 assertThat(result.callCredentials).isNull(); 297 assertThat(result.negotiator) 298 .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorClientFactory.class); 299 300 result = ProtocolNegotiators.from(ChoiceChannelCredentials.create( 301 InsecureChannelCredentials.create(), 302 new ChannelCredentials() { 303 @Override 304 public ChannelCredentials withoutBearerTokens() { 305 throw new UnsupportedOperationException(); 306 } 307 }, 308 TlsChannelCredentials.create())); 309 assertThat(result.error).isNull(); 310 assertThat(result.callCredentials).isNull(); 311 assertThat(result.negotiator) 312 .isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorClientFactory.class); 313 } 314 315 @Test fromClient_choice_unknown()316 public void fromClient_choice_unknown() { 317 ProtocolNegotiators.FromChannelCredentialsResult result = 318 ProtocolNegotiators.from(ChoiceChannelCredentials.create( 319 new ChannelCredentials() { 320 @Override 321 public ChannelCredentials withoutBearerTokens() { 322 throw new UnsupportedOperationException(); 323 } 324 })); 325 assertThat(result.error).isNotNull(); 326 assertThat(result.callCredentials).isNull(); 327 assertThat(result.negotiator).isNull(); 328 } 329 expectSuccessfulHandshake( ChannelCredentials channelCreds, ServerCredentials serverCreds)330 private InternalChannelz.Tls expectSuccessfulHandshake( 331 ChannelCredentials channelCreds, ServerCredentials serverCreds) throws Exception { 332 return (InternalChannelz.Tls) expectHandshake(channelCreds, serverCreds, true); 333 } 334 expectFailedHandshake( ChannelCredentials channelCreds, ServerCredentials serverCreds)335 private Status expectFailedHandshake( 336 ChannelCredentials channelCreds, ServerCredentials serverCreds) throws Exception { 337 return (Status) expectHandshake(channelCreds, serverCreds, false); 338 } 339 expectHandshake( ChannelCredentials channelCreds, ServerCredentials serverCreds, boolean expectSuccess)340 private Object expectHandshake( 341 ChannelCredentials channelCreds, ServerCredentials serverCreds, boolean expectSuccess) 342 throws Exception { 343 MockServerListener serverListener = new MockServerListener(); 344 ClientTransportFactory clientFactory = NettyChannelBuilder 345 // Although specified here, address is ignored because we never call build. 346 .forAddress("localhost", 0, channelCreds) 347 .buildTransportFactory(); 348 InternalServer server = NettyServerBuilder 349 .forPort(0, serverCreds) 350 .buildTransportServers(Collections.<ServerStreamTracer.Factory>emptyList()); 351 server.start(serverListener); 352 353 ManagedClientTransport.Listener clientTransportListener = 354 mock(ManagedClientTransport.Listener.class); 355 ManagedClientTransport client = clientFactory.newClientTransport( 356 server.getListenSocketAddress(), 357 new ClientTransportFactory.ClientTransportOptions() 358 .setAuthority(TestUtils.TEST_SERVER_HOST), 359 mock(ChannelLogger.class)); 360 callMeMaybe(client.start(clientTransportListener)); 361 Object result; 362 if (expectSuccess) { 363 verify(clientTransportListener, timeout(TIMEOUT_SECONDS * 1000)).transportReady(); 364 InternalChannelz.SocketStats stats = serverListener.transports.poll().getStats().get(); 365 assertThat(stats.security).isNotNull(); 366 assertThat(stats.security.tls).isNotNull(); 367 result = stats.security.tls; 368 } else { 369 ArgumentCaptor<Status> captor = ArgumentCaptor.forClass(Status.class); 370 verify(clientTransportListener, timeout(TIMEOUT_SECONDS * 1000)) 371 .transportShutdown(captor.capture()); 372 result = captor.getValue(); 373 } 374 375 client.shutdownNow(Status.UNAVAILABLE.withDescription("trash it")); 376 server.shutdown(); 377 assertTrue( 378 serverListener.waitForShutdown(TIMEOUT_SECONDS * 1000, TimeUnit.MILLISECONDS)); 379 verify(clientTransportListener, timeout(TIMEOUT_SECONDS * 1000)).transportTerminated(); 380 clientFactory.close(); 381 return result; 382 } 383 384 @Test from_tls_clientAuthNone_noClientCert()385 public void from_tls_clientAuthNone_noClientCert() throws Exception { 386 // Use convenience API to better match most user's usage 387 ServerCredentials serverCreds = TlsServerCredentials.create(server1Cert, server1Key); 388 ChannelCredentials channelCreds = TlsChannelCredentials.newBuilder() 389 .trustManager(caCert) 390 .build(); 391 InternalChannelz.Tls tls = expectSuccessfulHandshake(channelCreds, serverCreds); 392 assertThat(tls.remoteCert).isNull(); 393 } 394 395 @Test from_tls_clientAuthNone_clientCert()396 public void from_tls_clientAuthNone_clientCert() throws Exception { 397 ServerCredentials serverCreds = TlsServerCredentials.newBuilder() 398 .keyManager(server1Cert, server1Key) 399 .trustManager(caCert) 400 .build(); 401 ChannelCredentials channelCreds = TlsChannelCredentials.newBuilder() 402 .keyManager(server1Cert, server1Key) 403 .trustManager(caCert) 404 .build(); 405 InternalChannelz.Tls tls = expectSuccessfulHandshake(channelCreds, serverCreds); 406 assertThat(tls.remoteCert).isNull(); 407 } 408 409 @Test from_tls_clientAuthRequire_noClientCert()410 public void from_tls_clientAuthRequire_noClientCert() throws Exception { 411 ServerCredentials serverCreds = TlsServerCredentials.newBuilder() 412 .keyManager(server1Cert, server1Key) 413 .trustManager(caCert) 414 .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE) 415 .build(); 416 ChannelCredentials channelCreds = TlsChannelCredentials.newBuilder() 417 .trustManager(caCert) 418 .build(); 419 Status status = expectFailedHandshake(channelCreds, serverCreds); 420 assertEquals(Status.Code.UNAVAILABLE, status.getCode()); 421 StatusException sre = status.asException(); 422 // because of netty/netty#11604 we need to check for both TLSv1.2 and v1.3 behaviors 423 if (sre.getCause() instanceof SSLHandshakeException) { 424 assertThat(sre).hasCauseThat().isInstanceOf(SSLHandshakeException.class); 425 assertThat(sre).hasCauseThat().hasMessageThat().contains("SSLV3_ALERT_HANDSHAKE_FAILURE"); 426 } else { 427 // Client cert verification is after handshake in TLSv1.3 428 assertThat(sre).hasCauseThat().hasCauseThat().isInstanceOf(SSLException.class); 429 assertThat(sre).hasCauseThat().hasMessageThat().contains("CERTIFICATE_REQUIRED"); 430 } 431 } 432 433 @Test from_tls_clientAuthRequire_clientCert()434 public void from_tls_clientAuthRequire_clientCert() throws Exception { 435 ServerCredentials serverCreds = TlsServerCredentials.newBuilder() 436 .keyManager(server1Cert, server1Key) 437 .trustManager(caCert) 438 .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE) 439 .build(); 440 ChannelCredentials channelCreds = TlsChannelCredentials.newBuilder() 441 .keyManager(server1Cert, server1Key) 442 .trustManager(caCert) 443 .build(); 444 InternalChannelz.Tls tls = expectSuccessfulHandshake(channelCreds, serverCreds); 445 assertThat(((X509Certificate) tls.remoteCert).getSubjectX500Principal().getName()) 446 .contains("CN=*.test.google.com"); 447 } 448 449 @Test from_tls_clientAuthOptional_noClientCert()450 public void from_tls_clientAuthOptional_noClientCert() throws Exception { 451 ServerCredentials serverCreds = TlsServerCredentials.newBuilder() 452 .keyManager(server1Cert, server1Key) 453 .trustManager(caCert) 454 .clientAuth(TlsServerCredentials.ClientAuth.OPTIONAL) 455 .build(); 456 ChannelCredentials channelCreds = TlsChannelCredentials.newBuilder() 457 .trustManager(caCert) 458 .build(); 459 InternalChannelz.Tls tls = expectSuccessfulHandshake(channelCreds, serverCreds); 460 assertThat(tls.remoteCert).isNull(); 461 } 462 463 @Test from_tls_clientAuthOptional_clientCert()464 public void from_tls_clientAuthOptional_clientCert() throws Exception { 465 ServerCredentials serverCreds = TlsServerCredentials.newBuilder() 466 .keyManager(server1Cert, server1Key) 467 .trustManager(caCert) 468 .clientAuth(TlsServerCredentials.ClientAuth.OPTIONAL) 469 .build(); 470 ChannelCredentials channelCreds = TlsChannelCredentials.newBuilder() 471 .keyManager(server1Cert, server1Key) 472 .trustManager(caCert) 473 .build(); 474 InternalChannelz.Tls tls = expectSuccessfulHandshake(channelCreds, serverCreds); 475 assertThat(((X509Certificate) tls.remoteCert).getSubjectX500Principal().getName()) 476 .contains("CN=*.test.google.com"); 477 } 478 479 @Test from_tls_managers()480 public void from_tls_managers() throws Exception { 481 SelfSignedCertificate cert = new SelfSignedCertificate(TestUtils.TEST_SERVER_HOST); 482 KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); 483 keyStore.load(null); 484 keyStore.setKeyEntry("mykey", cert.key(), new char[0], new Certificate[] {cert.cert()}); 485 KeyManagerFactory keyManagerFactory = 486 KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); 487 keyManagerFactory.init(keyStore, new char[0]); 488 489 KeyStore certStore = KeyStore.getInstance(KeyStore.getDefaultType()); 490 certStore.load(null); 491 certStore.setCertificateEntry("mycert", cert.cert()); 492 TrustManagerFactory trustManagerFactory = 493 TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); 494 trustManagerFactory.init(certStore); 495 496 ServerCredentials serverCreds = TlsServerCredentials.newBuilder() 497 .keyManager(keyManagerFactory.getKeyManagers()) 498 .trustManager(trustManagerFactory.getTrustManagers()) 499 .clientAuth(TlsServerCredentials.ClientAuth.REQUIRE) 500 .build(); 501 ChannelCredentials channelCreds = TlsChannelCredentials.newBuilder() 502 .keyManager(keyManagerFactory.getKeyManagers()) 503 .trustManager(trustManagerFactory.getTrustManagers()) 504 .build(); 505 InternalChannelz.Tls tls = expectSuccessfulHandshake(channelCreds, serverCreds); 506 assertThat(((X509Certificate) tls.remoteCert).getSubjectX500Principal().getName()) 507 .isEqualTo("CN=" + TestUtils.TEST_SERVER_HOST); 508 cert.delete(); 509 } 510 511 @Test fromServer_unknown()512 public void fromServer_unknown() { 513 ProtocolNegotiators.FromServerCredentialsResult result = 514 ProtocolNegotiators.from(new ServerCredentials() {}); 515 assertThat(result.error).isNotNull(); 516 assertThat(result.negotiator).isNull(); 517 } 518 519 @Test fromServer_tls()520 public void fromServer_tls() throws Exception { 521 ProtocolNegotiators.FromServerCredentialsResult result = 522 ProtocolNegotiators.from(TlsServerCredentials.create(server1Cert, server1Key)); 523 assertThat(result.error).isNull(); 524 assertThat(result.negotiator) 525 .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorServerFactory.class); 526 } 527 528 @Test fromServer_unsupportedTls()529 public void fromServer_unsupportedTls() throws Exception { 530 ProtocolNegotiators.FromServerCredentialsResult result = ProtocolNegotiators.from( 531 TlsServerCredentials.newBuilder() 532 .keyManager(server1Cert, server1Key) 533 .requireFakeFeature() 534 .build()); 535 assertThat(result.error).contains("FAKE"); 536 assertThat(result.negotiator).isNull(); 537 } 538 539 @Test fromServer_insecure()540 public void fromServer_insecure() { 541 ProtocolNegotiators.FromServerCredentialsResult result = 542 ProtocolNegotiators.from(InsecureServerCredentials.create()); 543 assertThat(result.error).isNull(); 544 assertThat(result.negotiator) 545 .isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorServerFactory.class); 546 } 547 548 @Test fromServer_netty()549 public void fromServer_netty() { 550 ProtocolNegotiator.ServerFactory factory = mock(ProtocolNegotiator.ServerFactory.class); 551 ProtocolNegotiators.FromServerCredentialsResult result = 552 ProtocolNegotiators.from(NettyServerCredentials.create(factory)); 553 assertThat(result.error).isNull(); 554 assertThat(result.negotiator).isSameInstanceAs(factory); 555 } 556 557 @Test fromServer_choice()558 public void fromServer_choice() throws Exception { 559 ProtocolNegotiators.FromServerCredentialsResult result = 560 ProtocolNegotiators.from(ChoiceServerCredentials.create( 561 new ServerCredentials() {}, 562 TlsServerCredentials.create(server1Cert, server1Key), 563 InsecureServerCredentials.create())); 564 assertThat(result.error).isNull(); 565 assertThat(result.negotiator) 566 .isInstanceOf(ProtocolNegotiators.TlsProtocolNegotiatorServerFactory.class); 567 568 result = ProtocolNegotiators.from(ChoiceServerCredentials.create( 569 InsecureServerCredentials.create(), 570 new ServerCredentials() {}, 571 TlsServerCredentials.create(server1Cert, server1Key))); 572 assertThat(result.error).isNull(); 573 assertThat(result.negotiator) 574 .isInstanceOf(ProtocolNegotiators.PlaintextProtocolNegotiatorServerFactory.class); 575 } 576 577 @Test fromServer_choice_unknown()578 public void fromServer_choice_unknown() { 579 ProtocolNegotiators.FromServerCredentialsResult result = 580 ProtocolNegotiators.from(ChoiceServerCredentials.create( 581 new ServerCredentials() {})); 582 assertThat(result.error).isNotNull(); 583 assertThat(result.negotiator).isNull(); 584 } 585 586 587 @Test waitUntilActiveHandler_handlerAdded()588 public void waitUntilActiveHandler_handlerAdded() throws Exception { 589 final CountDownLatch latch = new CountDownLatch(1); 590 591 final WaitUntilActiveHandler handler = 592 new WaitUntilActiveHandler(new ChannelHandlerAdapter() { 593 @Override 594 public void handlerAdded(ChannelHandlerContext ctx) throws Exception { 595 assertTrue(ctx.channel().isActive()); 596 latch.countDown(); 597 super.handlerAdded(ctx); 598 } 599 }, noopLogger); 600 601 ChannelHandler lateAddingHandler = new ChannelInboundHandlerAdapter() { 602 @Override 603 public void channelActive(ChannelHandlerContext ctx) throws Exception { 604 ctx.pipeline().addLast(handler); 605 ctx.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); 606 // do not propagate channelActive(). 607 } 608 }; 609 610 LocalAddress addr = new LocalAddress("local"); 611 ChannelFuture cf = new Bootstrap() 612 .channel(LocalChannel.class) 613 .handler(lateAddingHandler) 614 .group(group) 615 .register(); 616 chan = cf.channel(); 617 ChannelFuture sf = new ServerBootstrap() 618 .channel(LocalServerChannel.class) 619 .childHandler(new ChannelHandlerAdapter() {}) 620 .group(group) 621 .bind(addr); 622 server = sf.channel(); 623 sf.sync(); 624 625 assertEquals(1, latch.getCount()); 626 627 chan.connect(addr).sync(); 628 assertTrue(latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS)); 629 assertNull(chan.pipeline().context(WaitUntilActiveHandler.class)); 630 } 631 632 @Test waitUntilActiveHandler_channelActive()633 public void waitUntilActiveHandler_channelActive() throws Exception { 634 final CountDownLatch latch = new CountDownLatch(1); 635 WaitUntilActiveHandler handler = 636 new WaitUntilActiveHandler(new ChannelHandlerAdapter() { 637 @Override 638 public void handlerAdded(ChannelHandlerContext ctx) throws Exception { 639 assertTrue(ctx.channel().isActive()); 640 latch.countDown(); 641 super.handlerAdded(ctx); 642 } 643 }, noopLogger); 644 645 LocalAddress addr = new LocalAddress("local"); 646 ChannelFuture cf = new Bootstrap() 647 .channel(LocalChannel.class) 648 .handler(handler) 649 .group(group) 650 .register(); 651 chan = cf.channel(); 652 ChannelFuture sf = new ServerBootstrap() 653 .channel(LocalServerChannel.class) 654 .childHandler(new ChannelHandlerAdapter() {}) 655 .group(group) 656 .bind(addr); 657 server = sf.channel(); 658 sf.sync(); 659 660 assertEquals(1, latch.getCount()); 661 662 chan.connect(addr).sync(); 663 chan.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); 664 assertTrue(latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS)); 665 assertNull(chan.pipeline().context(WaitUntilActiveHandler.class)); 666 } 667 668 @Test tlsHandler_failsOnNullEngine()669 public void tlsHandler_failsOnNullEngine() throws Exception { 670 thrown.expect(NullPointerException.class); 671 thrown.expectMessage("ssl"); 672 673 Object unused = ProtocolNegotiators.serverTls(null); 674 } 675 676 677 @Test tlsHandler_handlerAddedAddsSslHandler()678 public void tlsHandler_handlerAddedAddsSslHandler() throws Exception { 679 ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null); 680 681 pipeline.addLast(handler); 682 683 assertTrue(pipeline.first() instanceof SslHandler); 684 } 685 686 @Test tlsHandler_userEventTriggeredNonSslEvent()687 public void tlsHandler_userEventTriggeredNonSslEvent() throws Exception { 688 ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null); 689 pipeline.addLast(handler); 690 channelHandlerCtx = pipeline.context(handler); 691 Object nonSslEvent = new Object(); 692 693 pipeline.fireUserEventTriggered(nonSslEvent); 694 695 // A non ssl event should not cause the grpcHandler to be in the pipeline yet. 696 ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler); 697 assertNull(grpcHandlerCtx); 698 } 699 700 @Test tlsHandler_userEventTriggeredSslEvent_unsupportedProtocol()701 public void tlsHandler_userEventTriggeredSslEvent_unsupportedProtocol() throws Exception { 702 SslHandler badSslHandler = new SslHandler(engine, false) { 703 @Override 704 public String applicationProtocol() { 705 return "badprotocol"; 706 } 707 }; 708 709 ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null); 710 pipeline.addLast(handler); 711 712 final AtomicReference<Throwable> error = new AtomicReference<>(); 713 ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() { 714 @Override 715 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { 716 error.set(cause); 717 } 718 }; 719 720 pipeline.addLast(errorCapture); 721 722 pipeline.replace(SslHandler.class, null, badSslHandler); 723 channelHandlerCtx = pipeline.context(handler); 724 Object sslEvent = SslHandshakeCompletionEvent.SUCCESS; 725 726 pipeline.fireUserEventTriggered(sslEvent); 727 728 // No h2 protocol was specified, so there should be an error, (normally handled by WBAEH) 729 assertThat(error.get()).hasMessageThat().contains("Unable to find compatible protocol"); 730 ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler); 731 assertNull(grpcHandlerCtx); 732 } 733 734 @Test tlsHandler_userEventTriggeredSslEvent_handshakeFailure()735 public void tlsHandler_userEventTriggeredSslEvent_handshakeFailure() throws Exception { 736 ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null); 737 pipeline.addLast(handler); 738 channelHandlerCtx = pipeline.context(handler); 739 Object sslEvent = new SslHandshakeCompletionEvent(new RuntimeException("bad")); 740 741 final AtomicReference<Throwable> error = new AtomicReference<>(); 742 ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() { 743 @Override 744 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { 745 error.set(cause); 746 } 747 }; 748 749 pipeline.addLast(errorCapture); 750 751 pipeline.fireUserEventTriggered(sslEvent); 752 753 // No h2 protocol was specified, so there should be an error, (normally handled by WBAEH) 754 assertThat(error.get()).hasMessageThat().contains("bad"); 755 ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler); 756 assertNull(grpcHandlerCtx); 757 } 758 759 @Test tlsHandler_userEventTriggeredSslEvent_supportedProtocolH2()760 public void tlsHandler_userEventTriggeredSslEvent_supportedProtocolH2() throws Exception { 761 SslHandler goodSslHandler = new SslHandler(engine, false) { 762 @Override 763 public String applicationProtocol() { 764 return "h2"; 765 } 766 }; 767 768 ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null); 769 pipeline.addLast(handler); 770 771 pipeline.replace(SslHandler.class, null, goodSslHandler); 772 channelHandlerCtx = pipeline.context(handler); 773 Object sslEvent = SslHandshakeCompletionEvent.SUCCESS; 774 775 pipeline.fireUserEventTriggered(sslEvent); 776 777 assertTrue(channel.isOpen()); 778 ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler); 779 assertNotNull(grpcHandlerCtx); 780 } 781 782 @Test serverTlsHandler_userEventTriggeredSslEvent_supportedProtocolCustom()783 public void serverTlsHandler_userEventTriggeredSslEvent_supportedProtocolCustom() 784 throws Exception { 785 SslHandler goodSslHandler = new SslHandler(engine, false) { 786 @Override 787 public String applicationProtocol() { 788 return "managed_mtls"; 789 } 790 }; 791 792 InputStream serverCert = TlsTesting.loadCert("server1.pem"); 793 InputStream key = TlsTesting.loadCert("server1.key"); 794 List<String> alpnList = Arrays.asList("managed_mtls", "h2"); 795 ApplicationProtocolConfig apn = new ApplicationProtocolConfig( 796 ApplicationProtocolConfig.Protocol.ALPN, 797 ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, 798 ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, 799 alpnList); 800 801 sslContext = GrpcSslContexts.forServer(serverCert, key) 802 .applicationProtocolConfig(apn).build(); 803 804 ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null); 805 pipeline.addLast(handler); 806 807 pipeline.replace(SslHandler.class, null, goodSslHandler); 808 channelHandlerCtx = pipeline.context(handler); 809 Object sslEvent = SslHandshakeCompletionEvent.SUCCESS; 810 811 pipeline.fireUserEventTriggered(sslEvent); 812 813 assertTrue(channel.isOpen()); 814 ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler); 815 assertNotNull(grpcHandlerCtx); 816 } 817 818 @Test serverTlsHandler_userEventTriggeredSslEvent_unsupportedProtocolCustom()819 public void serverTlsHandler_userEventTriggeredSslEvent_unsupportedProtocolCustom() 820 throws Exception { 821 SslHandler badSslHandler = new SslHandler(engine, false) { 822 @Override 823 public String applicationProtocol() { 824 return "badprotocol"; 825 } 826 }; 827 828 InputStream serverCert = TlsTesting.loadCert("server1.pem"); 829 InputStream key = TlsTesting.loadCert("server1.key"); 830 List<String> alpnList = Arrays.asList("managed_mtls", "h2"); 831 ApplicationProtocolConfig apn = new ApplicationProtocolConfig( 832 ApplicationProtocolConfig.Protocol.ALPN, 833 ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, 834 ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, 835 alpnList); 836 837 sslContext = GrpcSslContexts.forServer(serverCert, key) 838 .applicationProtocolConfig(apn).build(); 839 ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null); 840 pipeline.addLast(handler); 841 842 final AtomicReference<Throwable> error = new AtomicReference<>(); 843 ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() { 844 @Override 845 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { 846 error.set(cause); 847 } 848 }; 849 850 pipeline.addLast(errorCapture); 851 852 pipeline.replace(SslHandler.class, null, badSslHandler); 853 channelHandlerCtx = pipeline.context(handler); 854 Object sslEvent = SslHandshakeCompletionEvent.SUCCESS; 855 856 pipeline.fireUserEventTriggered(sslEvent); 857 858 // No h2 protocol was specified, so there should be an error, (normally handled by WBAEH) 859 assertThat(error.get()).hasMessageThat().contains("Unable to find compatible protocol"); 860 ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler); 861 assertNull(grpcHandlerCtx); 862 } 863 864 @Test clientTlsHandler_userEventTriggeredSslEvent_supportedProtocolH2()865 public void clientTlsHandler_userEventTriggeredSslEvent_supportedProtocolH2() throws Exception { 866 SslHandler goodSslHandler = new SslHandler(engine, false) { 867 @Override 868 public String applicationProtocol() { 869 return "h2"; 870 } 871 }; 872 DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); 873 874 ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, 875 "authority", elg, noopLogger); 876 pipeline.addLast(handler); 877 pipeline.replace(SslHandler.class, null, goodSslHandler); 878 pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); 879 channelHandlerCtx = pipeline.context(handler); 880 Object sslEvent = SslHandshakeCompletionEvent.SUCCESS; 881 882 pipeline.fireUserEventTriggered(sslEvent); 883 884 ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler); 885 assertNotNull(grpcHandlerCtx); 886 } 887 888 @Test clientTlsHandler_userEventTriggeredSslEvent_supportedProtocolCustom()889 public void clientTlsHandler_userEventTriggeredSslEvent_supportedProtocolCustom() 890 throws Exception { 891 SslHandler goodSslHandler = new SslHandler(engine, false) { 892 @Override 893 public String applicationProtocol() { 894 return "managed_mtls"; 895 } 896 }; 897 DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); 898 899 InputStream clientCert = TlsTesting.loadCert("client.pem"); 900 InputStream key = TlsTesting.loadCert("client.key"); 901 List<String> alpnList = Arrays.asList("managed_mtls", "h2"); 902 ApplicationProtocolConfig apn = new ApplicationProtocolConfig( 903 ApplicationProtocolConfig.Protocol.ALPN, 904 ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, 905 ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, 906 alpnList); 907 908 sslContext = GrpcSslContexts.forClient() 909 .keyManager(clientCert, key) 910 .applicationProtocolConfig(apn).build(); 911 912 ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, 913 "authority", elg, noopLogger); 914 pipeline.addLast(handler); 915 pipeline.replace(SslHandler.class, null, goodSslHandler); 916 pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); 917 channelHandlerCtx = pipeline.context(handler); 918 Object sslEvent = SslHandshakeCompletionEvent.SUCCESS; 919 920 pipeline.fireUserEventTriggered(sslEvent); 921 922 ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler); 923 assertNotNull(grpcHandlerCtx); 924 } 925 926 @Test clientTlsHandler_userEventTriggeredSslEvent_unsupportedProtocol()927 public void clientTlsHandler_userEventTriggeredSslEvent_unsupportedProtocol() throws Exception { 928 SslHandler goodSslHandler = new SslHandler(engine, false) { 929 @Override 930 public String applicationProtocol() { 931 return "badproto"; 932 } 933 }; 934 DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); 935 936 ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, 937 "authority", elg, noopLogger); 938 pipeline.addLast(handler); 939 940 final AtomicReference<Throwable> error = new AtomicReference<>(); 941 ChannelHandler errorCapture = new ChannelInboundHandlerAdapter() { 942 @Override 943 public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { 944 error.set(cause); 945 } 946 }; 947 948 pipeline.addLast(errorCapture); 949 pipeline.replace(SslHandler.class, null, goodSslHandler); 950 pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); 951 channelHandlerCtx = pipeline.context(handler); 952 Object sslEvent = SslHandshakeCompletionEvent.SUCCESS; 953 954 pipeline.fireUserEventTriggered(sslEvent); 955 956 // Bad protocol was specified, so there should be an error, (normally handled by WBAEH) 957 assertThat(error.get()).hasMessageThat().contains("Unable to find compatible protocol"); 958 ChannelHandlerContext grpcHandlerCtx = pipeline.context(grpcHandler); 959 assertNull(grpcHandlerCtx); 960 } 961 962 @Test clientTlsHandler_closeDuringNegotiation()963 public void clientTlsHandler_closeDuringNegotiation() throws Exception { 964 ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, 965 "authority", null, noopLogger); 966 pipeline.addLast(new WriteBufferingAndExceptionHandler(handler)); 967 ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE); 968 969 // SslHandler fires userEventTriggered() before channelInactive() 970 pipeline.fireChannelInactive(); 971 972 assertThat(pendingWrite.cause()).isInstanceOf(StatusRuntimeException.class); 973 assertThat(Status.fromThrowable(pendingWrite.cause()).getCode()) 974 .isEqualTo(Status.Code.UNAVAILABLE); 975 } 976 977 @Test engineLog()978 public void engineLog() { 979 ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null); 980 pipeline.addLast(handler); 981 channelHandlerCtx = pipeline.context(handler); 982 983 Logger logger = Logger.getLogger(ProtocolNegotiators.class.getName()); 984 Filter oldFilter = logger.getFilter(); 985 try { 986 logger.setFilter(new Filter() { 987 @Override 988 public boolean isLoggable(LogRecord record) { 989 // We still want to the log method to be exercised, just not printed to stderr. 990 return false; 991 } 992 }); 993 994 ProtocolNegotiators.logSslEngineDetails( 995 Level.INFO, channelHandlerCtx, "message", new Exception("bad")); 996 } finally { 997 logger.setFilter(oldFilter); 998 } 999 } 1000 1001 @Test tls_failsOnNullSslContext()1002 public void tls_failsOnNullSslContext() { 1003 thrown.expect(NullPointerException.class); 1004 1005 Object unused = ProtocolNegotiators.tls(null); 1006 } 1007 1008 @Test tls_hostAndPort()1009 public void tls_hostAndPort() { 1010 HostPort hostPort = ProtocolNegotiators.parseAuthority("authority:1234"); 1011 1012 assertEquals("authority", hostPort.host); 1013 assertEquals(1234, hostPort.port); 1014 } 1015 1016 @Test tls_host()1017 public void tls_host() { 1018 HostPort hostPort = ProtocolNegotiators.parseAuthority("[::1]"); 1019 1020 assertEquals("[::1]", hostPort.host); 1021 assertEquals(-1, hostPort.port); 1022 } 1023 1024 @Test tls_invalidHost()1025 public void tls_invalidHost() throws SSLException { 1026 HostPort hostPort = ProtocolNegotiators.parseAuthority("bad_host:1234"); 1027 1028 // Even though it looks like a port, we treat it as part of the authority, since the host is 1029 // invalid. 1030 assertEquals("bad_host:1234", hostPort.host); 1031 assertEquals(-1, hostPort.port); 1032 } 1033 1034 @Test httpProxy_nullAddressNpe()1035 public void httpProxy_nullAddressNpe() throws Exception { 1036 thrown.expect(NullPointerException.class); 1037 Object unused = 1038 ProtocolNegotiators.httpProxy(null, "user", "pass", ProtocolNegotiators.plaintext()); 1039 } 1040 1041 @Test httpProxy_nullNegotiatorNpe()1042 public void httpProxy_nullNegotiatorNpe() throws Exception { 1043 thrown.expect(NullPointerException.class); 1044 Object unused = ProtocolNegotiators.httpProxy( 1045 InetSocketAddress.createUnresolved("localhost", 80), "user", "pass", null); 1046 } 1047 1048 @Test httpProxy_nullUserPassNoException()1049 public void httpProxy_nullUserPassNoException() throws Exception { 1050 assertNotNull(ProtocolNegotiators.httpProxy( 1051 InetSocketAddress.createUnresolved("localhost", 80), null, null, 1052 ProtocolNegotiators.plaintext())); 1053 } 1054 1055 @Test httpProxy_completes()1056 public void httpProxy_completes() throws Exception { 1057 DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); 1058 // ProxyHandler is incompatible with EmbeddedChannel because when channelRegistered() is called 1059 // the channel is already active. 1060 LocalAddress proxy = new LocalAddress("httpProxy_completes"); 1061 SocketAddress host = InetSocketAddress.createUnresolved("specialHost", 314); 1062 1063 ChannelInboundHandler mockHandler = mock(ChannelInboundHandler.class); 1064 Channel serverChannel = new ServerBootstrap().group(elg).channel(LocalServerChannel.class) 1065 .childHandler(mockHandler) 1066 .bind(proxy).sync().channel(); 1067 1068 ProtocolNegotiator nego = 1069 ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext()); 1070 // normally NettyClientTransport will add WBAEH which kick start the ProtocolNegotiation, 1071 // mocking the behavior using KickStartHandler. 1072 ChannelHandler handler = 1073 new KickStartHandler(nego.newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler())); 1074 Channel channel = new Bootstrap().group(elg).channel(LocalChannel.class).handler(handler) 1075 .register().sync().channel(); 1076 pipeline = channel.pipeline(); 1077 // Wait for initialization to complete 1078 channel.eventLoop().submit(NOOP_RUNNABLE).sync(); 1079 channel.connect(host).sync(); 1080 serverChannel.close(); 1081 ArgumentCaptor<ChannelHandlerContext> contextCaptor = 1082 ArgumentCaptor.forClass(ChannelHandlerContext.class); 1083 Mockito.verify(mockHandler).channelActive(contextCaptor.capture()); 1084 ChannelHandlerContext serverContext = contextCaptor.getValue(); 1085 1086 final String golden = "isThisThingOn?"; 1087 ChannelFuture negotiationFuture = channel.writeAndFlush(bb(golden, channel)); 1088 1089 // Wait for sending initial request to complete 1090 channel.eventLoop().submit(NOOP_RUNNABLE).sync(); 1091 ArgumentCaptor<Object> objectCaptor = ArgumentCaptor.forClass(Object.class); 1092 Mockito.verify(mockHandler) 1093 .channelRead(ArgumentMatchers.<ChannelHandlerContext>any(), objectCaptor.capture()); 1094 ByteBuf b = (ByteBuf) objectCaptor.getValue(); 1095 String request = b.toString(UTF_8); 1096 b.release(); 1097 assertTrue("No trailing newline: " + request, request.endsWith("\r\n\r\n")); 1098 assertTrue("No CONNECT: " + request, request.startsWith("CONNECT specialHost:314 ")); 1099 assertTrue("No host header: " + request, request.contains("host: specialHost:314")); 1100 1101 assertFalse(negotiationFuture.isDone()); 1102 serverContext.writeAndFlush(bb("HTTP/1.1 200 OK\r\n\r\n", serverContext.channel())).sync(); 1103 negotiationFuture.sync(); 1104 1105 channel.eventLoop().submit(NOOP_RUNNABLE).sync(); 1106 objectCaptor = ArgumentCaptor.forClass(Object.class); 1107 Mockito.verify(mockHandler, times(2)) 1108 .channelRead(ArgumentMatchers.<ChannelHandlerContext>any(), objectCaptor.capture()); 1109 b = (ByteBuf) objectCaptor.getAllValues().get(1); 1110 // If we were using the real grpcHandler, this would have been the HTTP/2 preface 1111 String preface = b.toString(UTF_8); 1112 b.release(); 1113 assertEquals(golden, preface); 1114 1115 channel.close(); 1116 } 1117 1118 @Test httpProxy_500()1119 public void httpProxy_500() throws Exception { 1120 DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1); 1121 // ProxyHandler is incompatible with EmbeddedChannel because when channelRegistered() is called 1122 // the channel is already active. 1123 LocalAddress proxy = new LocalAddress("httpProxy_500"); 1124 SocketAddress host = InetSocketAddress.createUnresolved("specialHost", 314); 1125 1126 ChannelInboundHandler mockHandler = mock(ChannelInboundHandler.class); 1127 Channel serverChannel = new ServerBootstrap().group(elg).channel(LocalServerChannel.class) 1128 .childHandler(mockHandler) 1129 .bind(proxy).sync().channel(); 1130 1131 ProtocolNegotiator nego = 1132 ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext()); 1133 // normally NettyClientTransport will add WBAEH which kick start the ProtocolNegotiation, 1134 // mocking the behavior using KickStartHandler. 1135 ChannelHandler handler = 1136 new KickStartHandler(nego.newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler())); 1137 Channel channel = new Bootstrap().group(elg).channel(LocalChannel.class).handler(handler) 1138 .register().sync().channel(); 1139 pipeline = channel.pipeline(); 1140 // Wait for initialization to complete 1141 channel.eventLoop().submit(NOOP_RUNNABLE).sync(); 1142 channel.connect(host).sync(); 1143 serverChannel.close(); 1144 ArgumentCaptor<ChannelHandlerContext> contextCaptor = 1145 ArgumentCaptor.forClass(ChannelHandlerContext.class); 1146 Mockito.verify(mockHandler).channelActive(contextCaptor.capture()); 1147 ChannelHandlerContext serverContext = contextCaptor.getValue(); 1148 1149 final String golden = "isThisThingOn?"; 1150 ChannelFuture negotiationFuture = channel.writeAndFlush(bb(golden, channel)); 1151 1152 // Wait for sending initial request to complete 1153 channel.eventLoop().submit(NOOP_RUNNABLE).sync(); 1154 ArgumentCaptor<Object> objectCaptor = ArgumentCaptor.forClass(Object.class); 1155 Mockito.verify(mockHandler) 1156 .channelRead(any(ChannelHandlerContext.class), objectCaptor.capture()); 1157 ByteBuf request = (ByteBuf) objectCaptor.getValue(); 1158 request.release(); 1159 1160 assertFalse(negotiationFuture.isDone()); 1161 String response = "HTTP/1.1 500 OMG\r\nContent-Length: 4\r\n\r\noops"; 1162 serverContext.writeAndFlush(bb(response, serverContext.channel())).sync(); 1163 thrown.expect(ProxyConnectException.class); 1164 try { 1165 negotiationFuture.sync(); 1166 } finally { 1167 channel.close(); 1168 } 1169 } 1170 1171 @Test waitUntilActiveHandler_firesNegotiation()1172 public void waitUntilActiveHandler_firesNegotiation() throws Exception { 1173 EventLoopGroup elg = new DefaultEventLoopGroup(1); 1174 SocketAddress addr = new LocalAddress("addr"); 1175 final AtomicReference<Object> event = new AtomicReference<>(); 1176 ChannelHandler next = new ChannelInboundHandlerAdapter() { 1177 @Override 1178 public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { 1179 event.set(evt); 1180 ctx.close(); 1181 } 1182 }; 1183 Channel s = new ServerBootstrap() 1184 .childHandler(new ChannelInboundHandlerAdapter()) 1185 .group(elg) 1186 .channel(LocalServerChannel.class) 1187 .bind(addr) 1188 .sync() 1189 .channel(); 1190 Channel c = new Bootstrap() 1191 .handler(new WaitUntilActiveHandler(next, noopLogger)) 1192 .channel(LocalChannel.class).group(group) 1193 .connect(addr) 1194 .sync() 1195 .channel(); 1196 c.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); 1197 SocketAddress localAddr = c.localAddress(); 1198 ProtocolNegotiationEvent expectedEvent = ProtocolNegotiationEvent.DEFAULT 1199 .withAttributes( 1200 Attributes.newBuilder() 1201 .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddr) 1202 .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, addr) 1203 .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.NONE) 1204 .build()); 1205 1206 c.closeFuture().sync(); 1207 assertThat(event.get()).isInstanceOf(ProtocolNegotiationEvent.class); 1208 ProtocolNegotiationEvent actual = (ProtocolNegotiationEvent) event.get(); 1209 assertThat(actual).isEqualTo(expectedEvent); 1210 1211 s.close(); 1212 elg.shutdownGracefully(); 1213 } 1214 1215 @Test clientTlsHandler_firesNegotiation()1216 public void clientTlsHandler_firesNegotiation() throws Exception { 1217 SelfSignedCertificate cert = new SelfSignedCertificate("authority"); 1218 SslContext clientSslContext = 1219 GrpcSslContexts.configure(SslContextBuilder.forClient().trustManager(cert.cert())).build(); 1220 SslContext serverSslContext = 1221 GrpcSslContexts.configure(SslContextBuilder.forServer(cert.key(), cert.cert())).build(); 1222 FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler(); 1223 ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, null); 1224 WriteBufferingAndExceptionHandler clientWbaeh = 1225 new WriteBufferingAndExceptionHandler(pn.newHandler(gh)); 1226 1227 SocketAddress addr = new LocalAddress("addr"); 1228 1229 ChannelHandler sh = 1230 ProtocolNegotiators.serverTls(serverSslContext) 1231 .newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler()); 1232 WriteBufferingAndExceptionHandler serverWbaeh = new WriteBufferingAndExceptionHandler(sh); 1233 Channel s = new ServerBootstrap() 1234 .childHandler(serverWbaeh) 1235 .group(group) 1236 .channel(LocalServerChannel.class) 1237 .bind(addr) 1238 .sync() 1239 .channel(); 1240 Channel c = new Bootstrap() 1241 .handler(clientWbaeh) 1242 .channel(LocalChannel.class) 1243 .group(group) 1244 .register() 1245 .sync() 1246 .channel(); 1247 ChannelFuture write = c.writeAndFlush(NettyClientHandler.NOOP_MESSAGE); 1248 c.connect(addr).sync(); 1249 write.sync(); 1250 1251 boolean completed = gh.negotiated.await(TIMEOUT_SECONDS, TimeUnit.SECONDS); 1252 if (!completed) { 1253 assertTrue("failed to negotiated", write.await(TIMEOUT_SECONDS, TimeUnit.SECONDS)); 1254 // sync should fail if we are in this block. 1255 write.sync(); 1256 throw new AssertionError("neither wrote nor negotiated"); 1257 } 1258 c.close(); 1259 s.close(); 1260 pn.close(); 1261 1262 assertThat(gh.securityInfo).isNotNull(); 1263 assertThat(gh.securityInfo.tls).isNotNull(); 1264 assertThat(gh.attrs.get(GrpcAttributes.ATTR_SECURITY_LEVEL)) 1265 .isEqualTo(SecurityLevel.PRIVACY_AND_INTEGRITY); 1266 assertThat(gh.attrs.get(Grpc.TRANSPORT_ATTR_SSL_SESSION)).isInstanceOf(SSLSession.class); 1267 // This is not part of the ClientTls negotiation, but shows that the negotiation event happens 1268 // in the right order. 1269 assertThat(gh.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)).isEqualTo(addr); 1270 } 1271 1272 @Test plaintextUpgradeNegotiator()1273 public void plaintextUpgradeNegotiator() throws Exception { 1274 LocalAddress addr = new LocalAddress("plaintextUpgradeNegotiator"); 1275 UpgradeCodecFactory ucf = new UpgradeCodecFactory() { 1276 1277 @Override 1278 public UpgradeCodec newUpgradeCodec(CharSequence protocol) { 1279 return new Http2ServerUpgradeCodec(FakeGrpcHttp2ConnectionHandler.newHandler()); 1280 } 1281 }; 1282 final HttpServerCodec serverCodec = new HttpServerCodec(); 1283 final HttpServerUpgradeHandler serverUpgradeHandler = 1284 new HttpServerUpgradeHandler(serverCodec, ucf); 1285 Channel serverChannel = new ServerBootstrap() 1286 .group(group) 1287 .channel(LocalServerChannel.class) 1288 .childHandler(new ChannelInitializer<Channel>() { 1289 1290 @Override 1291 protected void initChannel(Channel ch) throws Exception { 1292 ch.pipeline().addLast(serverCodec, serverUpgradeHandler); 1293 } 1294 }) 1295 .bind(addr) 1296 .sync() 1297 .channel(); 1298 1299 FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler(); 1300 ProtocolNegotiator nego = ProtocolNegotiators.plaintextUpgrade(); 1301 ChannelHandler ch = nego.newHandler(gh); 1302 WriteBufferingAndExceptionHandler wbaeh = new WriteBufferingAndExceptionHandler(ch); 1303 1304 Channel channel = new Bootstrap() 1305 .group(group) 1306 .channel(LocalChannel.class) 1307 .handler(wbaeh) 1308 .register() 1309 .sync() 1310 .channel(); 1311 1312 ChannelFuture write = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE); 1313 channel.connect(serverChannel.localAddress()); 1314 1315 boolean completed = gh.negotiated.await(TIMEOUT_SECONDS, TimeUnit.SECONDS); 1316 if (!completed) { 1317 assertTrue("failed to negotiated", write.await(TIMEOUT_SECONDS, TimeUnit.SECONDS)); 1318 // sync should fail if we are in this block. 1319 write.sync(); 1320 throw new AssertionError("neither wrote nor negotiated"); 1321 } 1322 1323 channel.close().sync(); 1324 serverChannel.close(); 1325 1326 assertThat(gh.securityInfo).isNull(); 1327 assertThat(gh.attrs.get(GrpcAttributes.ATTR_SECURITY_LEVEL)).isEqualTo(SecurityLevel.NONE); 1328 assertThat(gh.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)).isEqualTo(addr); 1329 } 1330 callMeMaybe(Runnable runnable)1331 private static void callMeMaybe(Runnable runnable) { 1332 if (runnable != null) { 1333 runnable.run(); 1334 } 1335 } 1336 1337 private static class FakeGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler { 1338 noopHandler()1339 static FakeGrpcHttp2ConnectionHandler noopHandler() { 1340 return newHandler(true); 1341 } 1342 newHandler()1343 static FakeGrpcHttp2ConnectionHandler newHandler() { 1344 return newHandler(false); 1345 } 1346 newHandler(boolean noop)1347 private static FakeGrpcHttp2ConnectionHandler newHandler(boolean noop) { 1348 DefaultHttp2Connection conn = new DefaultHttp2Connection(/*server=*/ false); 1349 DefaultHttp2ConnectionEncoder encoder = 1350 new DefaultHttp2ConnectionEncoder(conn, new DefaultHttp2FrameWriter()); 1351 DefaultHttp2ConnectionDecoder decoder = 1352 new DefaultHttp2ConnectionDecoder(conn, encoder, new DefaultHttp2FrameReader()); 1353 Http2Settings settings = new Http2Settings(); 1354 return new FakeGrpcHttp2ConnectionHandler( 1355 /*channelUnused=*/ null, decoder, encoder, settings, noop, noopLogger); 1356 } 1357 1358 private final boolean noop; 1359 private Attributes attrs; 1360 private Security securityInfo; 1361 private final CountDownLatch negotiated = new CountDownLatch(1); 1362 private ChannelHandlerContext ctx; 1363 FakeGrpcHttp2ConnectionHandler(ChannelPromise channelUnused, Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder, Http2Settings initialSettings, boolean noop, ChannelLogger negotiationLogger)1364 FakeGrpcHttp2ConnectionHandler(ChannelPromise channelUnused, 1365 Http2ConnectionDecoder decoder, 1366 Http2ConnectionEncoder encoder, 1367 Http2Settings initialSettings, 1368 boolean noop, 1369 ChannelLogger negotiationLogger) { 1370 super(channelUnused, decoder, encoder, initialSettings, negotiationLogger); 1371 this.noop = noop; 1372 } 1373 1374 @Override handleProtocolNegotiationCompleted(Attributes attrs, Security securityInfo)1375 public void handleProtocolNegotiationCompleted(Attributes attrs, Security securityInfo) { 1376 checkNotNull(ctx, "handleProtocolNegotiationCompleted cannot be called before handlerAdded"); 1377 super.handleProtocolNegotiationCompleted(attrs, securityInfo); 1378 this.attrs = attrs; 1379 this.securityInfo = securityInfo; 1380 // Add a temp handler that verifies first message is a NOOP_MESSAGE 1381 ctx.pipeline().addBefore(ctx.name(), null, new ChannelOutboundHandlerAdapter() { 1382 @Override 1383 public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) 1384 throws Exception { 1385 checkState( 1386 msg == NettyClientHandler.NOOP_MESSAGE, "First message should be NOOP_MESSAGE"); 1387 promise.trySuccess(); 1388 ctx.pipeline().remove(this); 1389 } 1390 }); 1391 NettyClientHandler.writeBufferingAndRemove(ctx.channel()); 1392 negotiated.countDown(); 1393 } 1394 1395 @Override handlerAdded(ChannelHandlerContext ctx)1396 public void handlerAdded(ChannelHandlerContext ctx) throws Exception { 1397 if (noop) { 1398 ctx.pipeline().remove(ctx.name()); 1399 } else { 1400 super.handlerAdded(ctx); 1401 } 1402 this.ctx = ctx; 1403 } 1404 1405 @Override getAuthority()1406 public String getAuthority() { 1407 return "authority"; 1408 } 1409 } 1410 bb(String s, Channel c)1411 private static ByteBuf bb(String s, Channel c) { 1412 return ByteBufUtil.writeUtf8(c.alloc(), s); 1413 } 1414 1415 private static final class KickStartHandler extends ChannelDuplexHandler { 1416 1417 private final ChannelHandler next; 1418 KickStartHandler(ChannelHandler next)1419 public KickStartHandler(ChannelHandler next) { 1420 this.next = checkNotNull(next, "next"); 1421 } 1422 1423 @Override handlerAdded(ChannelHandlerContext ctx)1424 public void handlerAdded(ChannelHandlerContext ctx) throws Exception { 1425 ctx.pipeline().replace(ctx.name(), null, next); 1426 ctx.pipeline().fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); 1427 } 1428 } 1429 1430 private static class MockServerListener implements ServerListener { 1431 private final CountDownLatch latch = new CountDownLatch(1); 1432 public Queue<ServerTransport> transports = new ArrayDeque<>(); 1433 1434 @Override transportCreated(ServerTransport transport)1435 public ServerTransportListener transportCreated(ServerTransport transport) { 1436 transports.add(transport); 1437 return new MockServerTransportListener(); 1438 } 1439 1440 @Override serverShutdown()1441 public void serverShutdown() { 1442 latch.countDown(); 1443 } 1444 waitForShutdown(long timeout, TimeUnit unit)1445 public boolean waitForShutdown(long timeout, TimeUnit unit) throws InterruptedException { 1446 return latch.await(timeout, unit); 1447 } 1448 } 1449 1450 private static class MockServerTransportListener implements ServerTransportListener { 1451 @Override streamCreated(ServerStream stream, String method, Metadata headers)1452 public void streamCreated(ServerStream stream, String method, Metadata headers) {} 1453 1454 @Override transportReady(Attributes attributes)1455 public Attributes transportReady(Attributes attributes) { 1456 return attributes; 1457 } 1458 1459 @Override transportTerminated()1460 public void transportTerminated() {} 1461 } 1462 } 1463