1 package software.amazon.awssdk.crt.test; 2 3 import org.junit.Test; 4 import software.amazon.awssdk.crt.eventstream.*; 5 import software.amazon.awssdk.crt.io.ClientBootstrap; 6 import software.amazon.awssdk.crt.io.EventLoopGroup; 7 import software.amazon.awssdk.crt.io.ServerBootstrap; 8 import software.amazon.awssdk.crt.io.SocketOptions; 9 10 import java.io.IOException; 11 import java.nio.charset.StandardCharsets; 12 import java.util.ArrayList; 13 import java.util.List; 14 import java.util.concurrent.*; 15 import java.util.concurrent.locks.Condition; 16 import java.util.concurrent.locks.Lock; 17 import java.util.concurrent.locks.ReentrantLock; 18 19 import static org.junit.Assert.*; 20 21 public class EventStreamClientConnectionTest extends CrtTestFixture { EventStreamClientConnectionTest()22 public EventStreamClientConnectionTest() {} 23 24 @Test testConnectionHandling()25 public void testConnectionHandling() throws ExecutionException, InterruptedException, IOException, TimeoutException { 26 SocketOptions socketOptions = new SocketOptions(); 27 socketOptions.connectTimeoutMs = 3000; 28 socketOptions.domain = SocketOptions.SocketDomain.IPv4; 29 socketOptions.type = SocketOptions.SocketType.STREAM; 30 31 EventLoopGroup elGroup = new EventLoopGroup(1); 32 ServerBootstrap bootstrap = new ServerBootstrap(elGroup); 33 ClientBootstrap clientBootstrap = new ClientBootstrap(elGroup, null); 34 final boolean[] connectionReceived = {false}; 35 final boolean[] connectionShutdown = {false}; 36 final ServerConnection[] serverConnections = {null}; 37 final CompletableFuture<ServerConnection> serverConnectionAccepted = new CompletableFuture<>(); 38 39 ServerListener listener = new ServerListener("127.0.0.1", (short)8033, socketOptions, null, bootstrap, new ServerListenerHandler() { 40 private ServerConnectionHandler connectionHandler = null; 41 42 public ServerConnectionHandler onNewConnection(ServerConnection serverConnection, int errorCode) { 43 serverConnections[0] = serverConnection; 44 connectionReceived[0] = true; 45 46 connectionHandler = new ServerConnectionHandler(serverConnection) { 47 48 @Override 49 protected void onProtocolMessage(List<Header> headers, byte[] payload, MessageType messageType, int messageFlags) { 50 } 51 52 @Override 53 protected ServerConnectionContinuationHandler onIncomingStream(ServerConnectionContinuation continuation, String operationName) { 54 return null; 55 } 56 }; 57 58 serverConnectionAccepted.complete(serverConnection); 59 return connectionHandler; 60 } 61 62 public void onConnectionShutdown(ServerConnection serverConnection, int errorCode) { 63 connectionShutdown[0] = true; 64 } 65 }); 66 67 final boolean[] clientConnected = {false}; 68 final ClientConnection[] clientConnectionArray = {null}; 69 70 CompletableFuture<Void> connectFuture = ClientConnection.connect("127.0.0.1", (short)8033, socketOptions, null, clientBootstrap, new ClientConnectionHandler() { 71 @Override 72 protected void onConnectionSetup(ClientConnection connection, int errorCode) { 73 clientConnected[0] = true; 74 clientConnectionArray[0] = connection; 75 } 76 77 @Override 78 protected void onProtocolMessage(List<Header> headers, byte[] payload, MessageType messageType, int messageFlags) { 79 80 } 81 }); 82 83 connectFuture.get(1, TimeUnit.SECONDS); 84 assertNotNull(clientConnectionArray[0]); 85 serverConnectionAccepted.get(1, TimeUnit.SECONDS); 86 assertNotNull(serverConnections[0]); 87 clientConnectionArray[0].closeConnection(0); 88 clientConnectionArray[0].getClosedFuture().get(1, TimeUnit.SECONDS); 89 serverConnections[0].getClosedFuture().get(1, TimeUnit.SECONDS); 90 assertTrue(connectionReceived[0]); 91 assertTrue(connectionShutdown[0]); 92 assertTrue(clientConnected[0]); 93 listener.close(); 94 listener.getShutdownCompleteFuture().get(1, TimeUnit.SECONDS); 95 bootstrap.close(); 96 clientBootstrap.close(); 97 clientBootstrap.getShutdownCompleteFuture().get(1, TimeUnit.SECONDS); 98 elGroup.close(); 99 elGroup.getShutdownCompleteFuture().get(1, TimeUnit.SECONDS); 100 socketOptions.close(); 101 } 102 103 @Test testConnectionProtocolMessageHandling()104 public void testConnectionProtocolMessageHandling() throws ExecutionException, InterruptedException, IOException, TimeoutException { 105 SocketOptions socketOptions = new SocketOptions(); 106 socketOptions.connectTimeoutMs = 3000; 107 socketOptions.domain = SocketOptions.SocketDomain.IPv4; 108 socketOptions.type = SocketOptions.SocketType.STREAM; 109 110 EventLoopGroup elGroup = new EventLoopGroup(1); 111 ServerBootstrap bootstrap = new ServerBootstrap(elGroup); 112 ClientBootstrap clientBootstrap = new ClientBootstrap(elGroup, null); 113 final List<Header>[] receivedMessageHeaders = new List[]{null}; 114 final byte[][] receivedPayload = {null}; 115 final MessageType[] receivedMessageType = {null}; 116 final int[] receivedMessageFlags = {-1}; 117 final ServerConnection[] serverConnectionArray = {null}; 118 final boolean[] serverConnShutdown = {false}; 119 120 final byte[] responseMessage = "{ \"message\": \"connect ack\" }".getBytes(StandardCharsets.UTF_8); 121 final CompletableFuture<ServerConnection> serverConnectionAccepted = new CompletableFuture<>(); 122 final CompletableFuture<ServerConnection> serverMessageSent = new CompletableFuture<>(); 123 124 ServerListener listener = new ServerListener("127.0.0.1", (short)8034, socketOptions, null, bootstrap, new ServerListenerHandler() { 125 private ServerConnectionHandler connectionHandler = null; 126 127 public ServerConnectionHandler onNewConnection(ServerConnection serverConnection, int errorCode) { 128 serverConnectionArray[0] = serverConnection; 129 130 connectionHandler = new ServerConnectionHandler(serverConnection) { 131 132 @Override 133 protected void onProtocolMessage(List<Header> headers, byte[] payload, MessageType messageType, int messageFlags) { 134 receivedMessageHeaders[0] = headers; 135 receivedPayload[0] = payload; 136 receivedMessageType[0] = messageType; 137 receivedMessageFlags[0] = messageFlags; 138 serverConnection.sendProtocolMessage(null, responseMessage, MessageType.ConnectAck, MessageFlags.ConnectionAccepted.getByteValue()) 139 .whenComplete((res, ex) -> { 140 serverMessageSent.complete(serverConnection); 141 }); 142 } 143 144 @Override 145 protected ServerConnectionContinuationHandler onIncomingStream(ServerConnectionContinuation continuation, String operationName) { 146 return null; 147 } 148 }; 149 serverConnectionAccepted.complete(serverConnection); 150 return connectionHandler; 151 } 152 153 @Override 154 protected void onConnectionShutdown(ServerConnection serverConnection, int errorCode) { 155 serverConnShutdown[0] = true; 156 } 157 }); 158 159 final ClientConnection[] clientConnectionArray = {null}; 160 final List<Header>[] clientReceivedMessageHeaders = new List[]{null}; 161 final byte[][] clientReceivedPayload = {null}; 162 final MessageType[] clientReceivedMessageType = {null}; 163 final int[] clientReceivedMessageFlags = {-1}; 164 final CompletableFuture<Void> clientMessageReceived = new CompletableFuture<>(); 165 166 CompletableFuture<Void> connectFuture = ClientConnection.connect("127.0.0.1", (short)8034, socketOptions, null, clientBootstrap, new ClientConnectionHandler() { 167 @Override 168 protected void onConnectionSetup(ClientConnection connection, int errorCode) { 169 clientConnectionArray[0] = connection; 170 } 171 172 @Override 173 protected void onProtocolMessage(List<Header> headers, byte[] payload, MessageType messageType, int messageFlags) { 174 clientReceivedMessageHeaders[0] = headers; 175 clientReceivedPayload[0] = payload; 176 clientReceivedMessageType[0] = messageType; 177 clientReceivedMessageFlags[0] = messageFlags; 178 clientMessageReceived.complete(null); 179 } 180 }); 181 182 final byte[] connectPayload = "test connect payload".getBytes(StandardCharsets.UTF_8); 183 connectFuture.get(1, TimeUnit.SECONDS); 184 assertNotNull(clientConnectionArray[0]); 185 serverConnectionAccepted.get(1, TimeUnit.SECONDS); 186 assertNotNull(serverConnectionArray[0]); 187 188 clientConnectionArray[0].sendProtocolMessage(null, connectPayload, MessageType.Connect, 0); 189 190 clientMessageReceived.get(1, TimeUnit.SECONDS); 191 assertEquals(MessageType.Connect, receivedMessageType[0]); 192 assertArrayEquals(connectPayload, receivedPayload[0]); 193 assertEquals(MessageType.ConnectAck, clientReceivedMessageType[0]); 194 assertEquals(MessageFlags.ConnectionAccepted.getByteValue(), clientReceivedMessageFlags[0]); 195 assertArrayEquals(responseMessage, clientReceivedPayload[0]); 196 clientConnectionArray[0].closeConnection(0); 197 clientConnectionArray[0].getClosedFuture().get(1, TimeUnit.SECONDS); 198 serverConnectionArray[0].getClosedFuture().get(1, TimeUnit.SECONDS); 199 200 assertTrue(serverConnShutdown[0]); 201 listener.close(); 202 listener.getShutdownCompleteFuture().get(1, TimeUnit.SECONDS); 203 bootstrap.close(); 204 clientBootstrap.close(); 205 clientBootstrap.getShutdownCompleteFuture().get(1, TimeUnit.SECONDS); 206 elGroup.close(); 207 elGroup.getShutdownCompleteFuture().get(1, TimeUnit.SECONDS); 208 socketOptions.close(); 209 } 210 211 @Test testConnectionProtocolMessageWithExtraHeadersHandling()212 public void testConnectionProtocolMessageWithExtraHeadersHandling() throws ExecutionException, InterruptedException, IOException, TimeoutException { 213 SocketOptions socketOptions = new SocketOptions(); 214 socketOptions.connectTimeoutMs = 3000; 215 socketOptions.domain = SocketOptions.SocketDomain.IPv4; 216 socketOptions.type = SocketOptions.SocketType.STREAM; 217 218 EventLoopGroup elGroup = new EventLoopGroup(1); 219 ServerBootstrap bootstrap = new ServerBootstrap(elGroup); 220 ClientBootstrap clientBootstrap = new ClientBootstrap(elGroup, null); 221 final boolean[] connectionShutdown = {false}; 222 final List<Header>[] receivedMessageHeaders = new List[]{null}; 223 final byte[][] receivedPayload = {null}; 224 final MessageType[] receivedMessageType = {null}; 225 final int[] receivedMessageFlags = {-1}; 226 final ServerConnection[] serverConnections = {null}; 227 228 final byte[] responseMessage = "{ \"message\": \"connect ack\" }".getBytes(StandardCharsets.UTF_8); 229 230 Header serverStrHeader = Header.createHeader("serverStrHeaderName", "serverStrHeaderValue"); 231 Header serverIntHeader = Header.createHeader("serverIntHeaderName", 25); 232 233 Lock semaphoreLock = new ReentrantLock(); 234 Condition semaphore = semaphoreLock.newCondition(); 235 236 ServerListener listener = new ServerListener("127.0.0.1", (short)8035, socketOptions, null, bootstrap, new ServerListenerHandler() { 237 private ServerConnectionHandler connectionHandler = null; 238 239 public ServerConnectionHandler onNewConnection(ServerConnection serverConnection, int errorCode) { 240 serverConnections[0] = serverConnection; 241 connectionHandler = new ServerConnectionHandler(serverConnection) { 242 243 @Override 244 protected void onProtocolMessage(List<Header> headers, byte[] payload, MessageType messageType, int messageFlags) { 245 receivedMessageHeaders[0] = headers; 246 receivedPayload[0] = payload; 247 receivedMessageType[0] = messageType; 248 receivedMessageFlags[0] = messageFlags; 249 250 List<Header> respHeaders = new ArrayList<>(); 251 respHeaders.add(serverStrHeader); 252 respHeaders.add(serverIntHeader); 253 serverConnection.sendProtocolMessage(respHeaders, responseMessage, MessageType.ConnectAck, MessageFlags.ConnectionAccepted.getByteValue()); 254 } 255 256 @Override 257 protected ServerConnectionContinuationHandler onIncomingStream(ServerConnectionContinuation continuation, String operationName) { 258 return null; 259 } 260 }; 261 262 semaphoreLock.lock(); 263 semaphore.signal(); 264 semaphoreLock.unlock(); 265 return connectionHandler; 266 } 267 268 public void onConnectionShutdown(ServerConnection serverConnection, int errorCode) { 269 connectionShutdown[0] = true; 270 } 271 }); 272 273 final boolean[] clientConnected = {false}; 274 final ClientConnection[] clientConnectionArray = {null}; 275 final List<Header>[] clientReceivedMessageHeaders = new List[]{null}; 276 final byte[][] clientReceivedPayload = {null}; 277 final MessageType[] clientReceivedMessageType = {null}; 278 final int[] clientReceivedMessageFlags = {-1}; 279 280 CompletableFuture<Void> connectFuture = ClientConnection.connect("127.0.0.1", (short)8035, socketOptions, null, clientBootstrap, new ClientConnectionHandler() { 281 @Override 282 protected void onConnectionSetup(ClientConnection connection, int errorCode) { 283 clientConnected[0] = true; 284 clientConnectionArray[0] = connection; 285 } 286 287 @Override 288 protected void onProtocolMessage(List<Header> headers, byte[] payload, MessageType messageType, int messageFlags) { 289 semaphoreLock.lock(); 290 clientReceivedMessageHeaders[0] = headers; 291 clientReceivedPayload[0] = payload; 292 clientReceivedMessageType[0] = messageType; 293 clientReceivedMessageFlags[0] = messageFlags; 294 semaphore.signal(); 295 semaphoreLock.unlock(); 296 } 297 }); 298 299 final byte[] connectPayload = "test connect payload".getBytes(StandardCharsets.UTF_8); 300 connectFuture.get(1, TimeUnit.SECONDS); 301 assertNotNull(clientConnectionArray[0]); 302 semaphoreLock.lock(); 303 semaphore.await(1, TimeUnit.SECONDS); 304 assertNotNull(serverConnections[0]); 305 306 Header clientStrHeader = Header.createHeader("clientStrHeaderName", "clientStrHeaderValue"); 307 Header clientIntHeader = Header.createHeader("clientIntHeaderName", 35); 308 List<Header> clientHeaders = new ArrayList<>(); 309 clientHeaders.add(clientStrHeader); 310 clientHeaders.add(clientIntHeader); 311 312 clientConnectionArray[0].sendProtocolMessage(clientHeaders, connectPayload, MessageType.Connect, 0); 313 314 semaphore.await(1, TimeUnit.SECONDS); 315 semaphoreLock.unlock(); 316 assertEquals(MessageType.Connect, receivedMessageType[0]); 317 assertArrayEquals(connectPayload, receivedPayload[0]); 318 assertNotNull(receivedMessageHeaders[0]); 319 assertEquals(clientStrHeader.getName(), receivedMessageHeaders[0].get(0).getName()); 320 assertEquals(clientStrHeader.getValueAsString(), receivedMessageHeaders[0].get(0).getValueAsString()); 321 assertEquals(clientIntHeader.getName(), receivedMessageHeaders[0].get(1).getName()); 322 assertEquals(clientIntHeader.getValueAsInt(), receivedMessageHeaders[0].get(1).getValueAsInt()); 323 assertEquals(MessageType.ConnectAck, clientReceivedMessageType[0]); 324 assertEquals(MessageFlags.ConnectionAccepted.getByteValue(), clientReceivedMessageFlags[0]); 325 assertArrayEquals(responseMessage, clientReceivedPayload[0]); 326 assertEquals(serverStrHeader.getName(), clientReceivedMessageHeaders[0].get(0).getName()); 327 assertEquals(serverStrHeader.getValueAsString(), clientReceivedMessageHeaders[0].get(0).getValueAsString()); 328 assertEquals(serverIntHeader.getName(), clientReceivedMessageHeaders[0].get(1).getName()); 329 assertEquals(serverIntHeader.getValueAsInt(), clientReceivedMessageHeaders[0].get(1).getValueAsInt()); 330 clientConnectionArray[0].closeConnection(0); 331 clientConnectionArray[0].getClosedFuture().get(1, TimeUnit.SECONDS); 332 serverConnections[0].getClosedFuture().get(1, TimeUnit.SECONDS); 333 334 assertTrue(connectionShutdown[0]); 335 assertTrue(clientConnected[0]); 336 listener.close(); 337 listener.getShutdownCompleteFuture().get(1, TimeUnit.SECONDS); 338 bootstrap.close(); 339 clientBootstrap.close(); 340 clientBootstrap.getShutdownCompleteFuture().get(1, TimeUnit.SECONDS); 341 elGroup.close(); 342 elGroup.getShutdownCompleteFuture().get(1, TimeUnit.SECONDS); 343 socketOptions.close(); 344 } 345 346 @Test testContinuationMessageHandling()347 public void testContinuationMessageHandling() throws ExecutionException, InterruptedException, IOException, TimeoutException { 348 SocketOptions socketOptions = new SocketOptions(); 349 socketOptions.connectTimeoutMs = 3000; 350 socketOptions.domain = SocketOptions.SocketDomain.IPv4; 351 socketOptions.type = SocketOptions.SocketType.STREAM; 352 353 EventLoopGroup elGroup = new EventLoopGroup(1); 354 ServerBootstrap bootstrap = new ServerBootstrap(elGroup); 355 ClientBootstrap clientBootstrap = new ClientBootstrap(elGroup, null); 356 357 final boolean[] connectionShutdown = {false}; 358 final String[] receivedOperationName = new String[]{null}; 359 final String[] receivedContinuationPayload = new String[]{null}; 360 361 final byte[] responsePayload = "{ \"message\": \"this is a response message\" }".getBytes(StandardCharsets.UTF_8); 362 final ServerConnection[] serverConnections = {null}; 363 Lock semaphoreLock = new ReentrantLock(); 364 Condition semaphore = semaphoreLock.newCondition(); 365 366 ServerListener listener = new ServerListener("127.0.0.1", (short)8036, socketOptions, null, bootstrap, new ServerListenerHandler() { 367 private ServerConnectionHandler connectionHandler = null; 368 369 public ServerConnectionHandler onNewConnection(ServerConnection serverConnection, int errorCode) { 370 serverConnections[0] = serverConnection; 371 connectionHandler = new ServerConnectionHandler(serverConnection) { 372 373 @Override 374 protected void onProtocolMessage(List<Header> headers, byte[] payload, MessageType messageType, int messageFlags) { 375 int responseMessageFlag = MessageFlags.ConnectionAccepted.getByteValue(); 376 MessageType acceptResponseType = MessageType.ConnectAck; 377 378 connection.sendProtocolMessage(null, null, acceptResponseType, responseMessageFlag); 379 } 380 381 @Override 382 protected ServerConnectionContinuationHandler onIncomingStream(ServerConnectionContinuation continuation, String operationName) { 383 receivedOperationName[0] = operationName; 384 385 return new ServerConnectionContinuationHandler(continuation) { 386 @Override 387 protected void onContinuationMessage(List<Header> headers, byte[] payload, MessageType messageType, int messageFlags) { 388 receivedContinuationPayload[0] = new String(payload, StandardCharsets.UTF_8); 389 390 continuation.sendMessage(null, responsePayload, 391 MessageType.ApplicationError, 392 MessageFlags.TerminateStream.getByteValue()) 393 .whenComplete((res, ex) -> { 394 connection.closeConnection(0); 395 this.close(); 396 }); 397 } 398 }; 399 } 400 }; 401 402 semaphoreLock.lock(); 403 semaphore.signal(); 404 semaphoreLock.unlock(); 405 return connectionHandler; 406 } 407 408 public void onConnectionShutdown(ServerConnection serverConnection, int errorCode) { 409 connectionShutdown[0] = true; 410 } 411 }); 412 413 final ClientConnection[] clientConnectionArray = {null}; 414 final List<Header>[] clientReceivedMessageHeaders = new List[]{null}; 415 final byte[][] clientReceivedPayload = {null}; 416 final MessageType[] clientReceivedMessageType = {null}; 417 final int[] clientReceivedMessageFlags = {-1}; 418 final boolean[] clientContinuationClosed = {false}; 419 420 CompletableFuture<Void> connectFuture = ClientConnection.connect("127.0.0.1", (short)8036, socketOptions, null, clientBootstrap, new ClientConnectionHandler() { 421 @Override 422 protected void onConnectionSetup(ClientConnection connection, int errorCode) { 423 clientConnectionArray[0] = connection; 424 } 425 426 @Override 427 protected void onProtocolMessage(List<Header> headers, byte[] payload, MessageType messageType, int messageFlags) { 428 semaphoreLock.lock(); 429 semaphore.signal(); 430 semaphoreLock.unlock(); 431 } 432 }); 433 434 final byte[] connectPayload = "test connect payload".getBytes(StandardCharsets.UTF_8); 435 connectFuture.get(1, TimeUnit.SECONDS); 436 assertNotNull(clientConnectionArray[0]); 437 semaphoreLock.lock(); 438 semaphore.await(1, TimeUnit.SECONDS); 439 assertNotNull(serverConnections[0]); 440 clientConnectionArray[0].sendProtocolMessage(null, connectPayload, MessageType.Connect, 0); 441 semaphore.await(1, TimeUnit.SECONDS); 442 String operationName = "testOperation"; 443 444 ClientConnectionContinuation continuation = clientConnectionArray[0].newStream(new ClientConnectionContinuationHandler() { 445 @Override 446 protected void onContinuationMessage(List<Header> headers, byte[] payload, MessageType messageType, int messageFlags) { 447 semaphoreLock.lock(); 448 clientReceivedMessageHeaders[0] = headers; 449 clientReceivedMessageType[0] = messageType; 450 clientReceivedMessageFlags[0] = messageFlags; 451 clientReceivedPayload[0] = payload; 452 semaphoreLock.unlock(); 453 } 454 455 @Override 456 protected void onContinuationClosed() { 457 semaphoreLock.lock(); 458 clientContinuationClosed[0] = true; 459 semaphore.signal(); 460 semaphoreLock.unlock(); 461 super.onContinuationClosed(); 462 } 463 }); 464 assertNotNull(continuation); 465 466 final byte[] operationPayload = "{\"message\": \"message payload\"}".getBytes(StandardCharsets.UTF_8); 467 continuation.activate(operationName, null, operationPayload, MessageType.ApplicationMessage, 0); 468 semaphore.await(1, TimeUnit.SECONDS); 469 470 assertArrayEquals(responsePayload, clientReceivedPayload[0]); 471 assertEquals(MessageType.ApplicationError, clientReceivedMessageType[0]); 472 assertEquals(MessageFlags.TerminateStream.getByteValue(), clientReceivedMessageFlags[0]); 473 assertTrue(clientContinuationClosed[0]); 474 475 clientConnectionArray[0].getClosedFuture().get(1, TimeUnit.SECONDS); 476 serverConnections[0].getClosedFuture().get(1, TimeUnit.SECONDS); 477 semaphoreLock.unlock(); 478 479 assertTrue(connectionShutdown[0]); 480 assertNotNull(receivedOperationName[0]); 481 assertEquals(operationName, receivedOperationName[0]); 482 assertEquals(new String(operationPayload, StandardCharsets.UTF_8), receivedContinuationPayload[0]); 483 listener.close(); 484 listener.getShutdownCompleteFuture().get(1, TimeUnit.SECONDS); 485 bootstrap.close(); 486 clientBootstrap.close(); 487 clientBootstrap.getShutdownCompleteFuture().get(1, TimeUnit.SECONDS); 488 elGroup.close(); 489 elGroup.getShutdownCompleteFuture().get(1, TimeUnit.SECONDS); 490 socketOptions.close(); 491 } 492 493 @Test testContinuationMessageWithExtraHeadersHandling()494 public void testContinuationMessageWithExtraHeadersHandling() throws ExecutionException, InterruptedException, IOException, TimeoutException { 495 SocketOptions socketOptions = new SocketOptions(); 496 socketOptions.connectTimeoutMs = 3000; 497 socketOptions.domain = SocketOptions.SocketDomain.IPv4; 498 socketOptions.type = SocketOptions.SocketType.STREAM; 499 500 EventLoopGroup elGroup = new EventLoopGroup(1); 501 ServerBootstrap bootstrap = new ServerBootstrap(elGroup); 502 ClientBootstrap clientBootstrap = new ClientBootstrap(elGroup, null); 503 504 final boolean[] connectionShutdown = {false}; 505 506 final String[] receivedOperationName = new String[]{null}; 507 final String[] receivedContinuationPayload = new String[]{null}; 508 final List<Header>[] receivedHeadersServer = new List[]{null}; 509 510 Header serverStrHeader = Header.createHeader("serverStrHeaderName", "serverStrHeaderValue"); 511 Header serverIntHeader = Header.createHeader("serverIntHeaderName", 25); 512 513 final byte[] responsePayload = "{ \"message\": \"this is a response message\" }".getBytes(StandardCharsets.UTF_8); 514 final ServerConnection[] serverConnections = {null}; 515 Lock semaphoreLock = new ReentrantLock(); 516 Condition semaphore = semaphoreLock.newCondition(); 517 518 ServerListener listener = new ServerListener("127.0.0.1", (short)8037, socketOptions, null, bootstrap, new ServerListenerHandler() { 519 private ServerConnectionHandler connectionHandler = null; 520 521 public ServerConnectionHandler onNewConnection(ServerConnection serverConnection, int errorCode) { 522 serverConnections[0] = serverConnection; 523 connectionHandler = new ServerConnectionHandler(serverConnection) { 524 525 @Override 526 protected void onProtocolMessage(List<Header> headers, byte[] payload, MessageType messageType, int messageFlags) { 527 int responseMessageFlag = MessageFlags.ConnectionAccepted.getByteValue(); 528 MessageType acceptResponseType = MessageType.ConnectAck; 529 530 connection.sendProtocolMessage(null, null, acceptResponseType, responseMessageFlag); 531 } 532 533 @Override 534 protected ServerConnectionContinuationHandler onIncomingStream(ServerConnectionContinuation continuation, String operationName) { 535 receivedOperationName[0] = operationName; 536 537 return new ServerConnectionContinuationHandler(continuation) { 538 @Override 539 protected void onContinuationMessage(List<Header> headers, byte[] payload, MessageType messageType, int messageFlags) { 540 receivedContinuationPayload[0] = new String(payload, StandardCharsets.UTF_8); 541 receivedHeadersServer[0] = headers; 542 List<Header> responseHeaders = new ArrayList<>(); 543 responseHeaders.add(serverStrHeader); 544 responseHeaders.add(serverIntHeader); 545 546 continuation.sendMessage(responseHeaders, responsePayload, 547 MessageType.ApplicationError, 548 MessageFlags.TerminateStream.getByteValue()) 549 .whenComplete((res, ex) -> { 550 connection.closeConnection(0); 551 this.close(); 552 }); 553 } 554 }; 555 } 556 }; 557 558 semaphoreLock.lock(); 559 semaphore.signal(); 560 semaphoreLock.unlock(); 561 return connectionHandler; 562 } 563 564 public void onConnectionShutdown(ServerConnection serverConnection, int errorCode) { 565 connectionShutdown[0] = true; 566 } 567 }); 568 569 final ClientConnection[] clientConnectionArray = {null}; 570 final List<Header>[] clientReceivedMessageHeaders = new List[]{null}; 571 final byte[][] clientReceivedPayload = {null}; 572 final MessageType[] clientReceivedMessageType = {null}; 573 final int[] clientReceivedMessageFlags = {-1}; 574 final boolean[] clientContinuationClosed = {false}; 575 576 CompletableFuture<Void> connectFuture = ClientConnection.connect("127.0.0.1", (short)8037, socketOptions, null, clientBootstrap, new ClientConnectionHandler() { 577 @Override 578 protected void onConnectionSetup(ClientConnection connection, int errorCode) { 579 clientConnectionArray[0] = connection; 580 } 581 582 @Override 583 protected void onProtocolMessage(List<Header> headers, byte[] payload, MessageType messageType, int messageFlags) { 584 semaphoreLock.lock(); 585 semaphore.signal(); 586 semaphoreLock.unlock(); 587 } 588 }); 589 590 final byte[] connectPayload = "test connect payload".getBytes(StandardCharsets.UTF_8); 591 connectFuture.get(1, TimeUnit.SECONDS); 592 assertNotNull(clientConnectionArray[0]); 593 semaphoreLock.lock(); 594 semaphore.await(1, TimeUnit.SECONDS); 595 assertNotNull(serverConnections[0]); 596 clientConnectionArray[0].sendProtocolMessage(null, connectPayload, MessageType.Connect, 0); 597 semaphore.await(1, TimeUnit.SECONDS); 598 String operationName = "testOperation"; 599 600 ClientConnectionContinuation continuation = clientConnectionArray[0].newStream(new ClientConnectionContinuationHandler() { 601 @Override 602 protected void onContinuationMessage(List<Header> headers, byte[] payload, MessageType messageType, int messageFlags) { 603 semaphoreLock.lock(); 604 clientReceivedMessageHeaders[0] = headers; 605 clientReceivedMessageType[0] = messageType; 606 clientReceivedMessageFlags[0] = messageFlags; 607 clientReceivedPayload[0] = payload; 608 semaphoreLock.unlock(); 609 } 610 611 @Override 612 protected void onContinuationClosed() { 613 semaphoreLock.lock(); 614 clientContinuationClosed[0] = true; 615 semaphore.signal(); 616 semaphoreLock.unlock(); 617 super.onContinuationClosed(); 618 } 619 }); 620 assertNotNull(continuation); 621 622 final byte[] operationPayload = "{\"message\": \"message payload\"}".getBytes(StandardCharsets.UTF_8); 623 Header clientStrHeader = Header.createHeader("clientStrHeaderName", "clientStrHeaderValue"); 624 Header clientIntHeader = Header.createHeader("clientIntHeaderName", 35); 625 List<Header> clientHeaders = new ArrayList<>(); 626 clientHeaders.add(clientStrHeader); 627 clientHeaders.add(clientIntHeader); 628 continuation.activate(operationName, clientHeaders, operationPayload, MessageType.ApplicationMessage, 0).get(1, TimeUnit.SECONDS); 629 semaphore.await(1, TimeUnit.SECONDS); 630 631 assertArrayEquals(responsePayload, clientReceivedPayload[0]); 632 assertEquals(MessageType.ApplicationError, clientReceivedMessageType[0]); 633 assertEquals(MessageFlags.TerminateStream.getByteValue(), clientReceivedMessageFlags[0]); 634 assertNotNull(receivedHeadersServer[0]); 635 assertEquals(clientStrHeader.getName(), receivedHeadersServer[0].get(0).getName()); 636 assertEquals(clientStrHeader.getValueAsString(), receivedHeadersServer[0].get(0).getValueAsString()); 637 assertEquals(clientIntHeader.getName(), receivedHeadersServer[0].get(1).getName()); 638 assertEquals(clientIntHeader.getValueAsInt(), receivedHeadersServer[0].get(1).getValueAsInt()); 639 assertEquals(serverStrHeader.getName(), clientReceivedMessageHeaders[0].get(0).getName()); 640 assertEquals(serverStrHeader.getValueAsString(), clientReceivedMessageHeaders[0].get(0).getValueAsString()); 641 assertEquals(serverIntHeader.getName(), clientReceivedMessageHeaders[0].get(1).getName()); 642 assertEquals(serverIntHeader.getValueAsInt(), clientReceivedMessageHeaders[0].get(1).getValueAsInt()); 643 assertTrue(clientContinuationClosed[0]); 644 645 clientConnectionArray[0].getClosedFuture().get(1, TimeUnit.SECONDS); 646 serverConnections[0].getClosedFuture().get(1, TimeUnit.SECONDS); 647 semaphoreLock.unlock(); 648 assertTrue(connectionShutdown[0]); 649 assertNotNull(receivedOperationName[0]); 650 assertEquals(operationName, receivedOperationName[0]); 651 assertEquals(new String(operationPayload, StandardCharsets.UTF_8), receivedContinuationPayload[0]); 652 listener.close(); 653 listener.getShutdownCompleteFuture().get(1, TimeUnit.SECONDS); 654 bootstrap.close(); 655 clientBootstrap.close(); 656 clientBootstrap.getShutdownCompleteFuture().get(1, TimeUnit.SECONDS); 657 elGroup.close(); 658 elGroup.getShutdownCompleteFuture().get(1, TimeUnit.SECONDS); 659 socketOptions.close(); 660 } 661 } 662