# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HTTP2 Test Server"""

import argparse
import logging
import sys

import http2_base_server
import test_data_frame_padding
import test_goaway
import test_max_streams
import test_ping
import test_rst_after_data
import test_rst_after_header
import test_rst_during_data
import twisted
import twisted.internet
import twisted.internet.endpoints
import twisted.internet.protocol
import twisted.internet.reactor

# Maps the test-case name to the class that implements it. Each class is
# instantiated per connection and must provide get_base_server().
_TEST_CASE_MAPPING = {
    "rst_after_header": test_rst_after_header.TestcaseRstStreamAfterHeader,
    "rst_after_data": test_rst_after_data.TestcaseRstStreamAfterData,
    "rst_during_data": test_rst_during_data.TestcaseRstStreamDuringData,
    "goaway": test_goaway.TestcaseGoaway,
    "ping": test_ping.TestcasePing,
    "max_streams": test_max_streams.TestcaseSettingsMaxStreams,
    # Positive tests below:
    "data_frame_padding": test_data_frame_padding.TestDataFramePadding,
    "no_df_padding_sanity_test": test_data_frame_padding.TestDataFramePadding,
}

# Process exit status. listen() sets it to 1 when a server fails to start
# listening; it is passed to sys.exit() once the reactor returns.
_exit_code = 0


class H2Factory(twisted.internet.protocol.Factory):
    """Twisted factory that builds a fresh test-case server per connection.

    Args:
        testcase: Key into _TEST_CASE_MAPPING selecting the test to serve.
    """

    def __init__(self, testcase):
        logging.info("Creating H2Factory for new connection (%s)", testcase)
        # Counts connections accepted by this factory (despite the name, it
        # is incremented once per buildProtocol call, i.e. per connection).
        self._num_streams = 0
        self._testcase = testcase

    def buildProtocol(self, addr):
        """Return the base server protocol for a newly accepted connection.

        Raises:
            ValueError: if the factory's test case is not in
                _TEST_CASE_MAPPING.
        """
        self._num_streams += 1
        logging.info("New Connection: %d", self._num_streams)
        # Explicit membership test: dict.has_key() does not exist on
        # Python 3, and an `assert` would be stripped under `python -O`.
        if self._testcase not in _TEST_CASE_MAPPING:
            logging.error("Unknown test case: %s", self._testcase)
            raise ValueError("Unknown test case: %s" % self._testcase)
        testcase_cls = _TEST_CASE_MAPPING[self._testcase]

        if self._testcase == "goaway":
            # The goaway test is given the running connection count.
            return testcase_cls(self._num_streams).get_base_server()
        elif self._testcase == "no_df_padding_sanity_test":
            # Same class as data_frame_padding, but with padding disabled.
            return testcase_cls(use_padding=False).get_base_server()
        else:
            return testcase_cls().get_base_server()


def parse_arguments():
    """Parse the command-line flags (currently only --base_port)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base_port",
        type=int,
        default=8080,
        # The order listed is derived from the same sorted keys that
        # start_test_servers() uses, so the help text cannot drift.
        help=(
            "base port to run the servers (default: 8080). One test server is "
            "started on each incrementing port, beginning with base_port, in "
            "the following order: " + ",".join(sorted(_TEST_CASE_MAPPING))
        ),
    )
    return parser.parse_args()


def listen(endpoint, test_case):
    """Start listening on `endpoint` with an H2Factory for `test_case`.

    On failure, the process exit code is set to 1 and the reactor is stopped.
    """
    deferred = endpoint.listen(H2Factory(test_case))

    def listen_error(reason):
        # If listening fails, we stop the reactor and exit the program
        # with exit code 1.
        global _exit_code
        # Several listens can fail during the same startup (e.g. a range of
        # ports already in use). Calling reactor.stop() a second time raises
        # ReactorNotRunning, so only the first failure stops the reactor.
        already_stopping = _exit_code != 0
        _exit_code = 1
        logging.error("Listening failed: %s", reason.value)
        if not already_stopping:
            twisted.internet.reactor.stop()

    deferred.addErrback(listen_error)


def start_test_servers(base_port):
    """Start one server per test case on incrementing port numbers
    beginning with base_port, in sorted test-case-name order."""
    for index, test_case in enumerate(sorted(_TEST_CASE_MAPPING)):
        portnum = base_port + index
        logging.warning("serving on port %d : %s", portnum, test_case)
        endpoint = twisted.internet.endpoints.TCP4ServerEndpoint(
            twisted.internet.reactor, portnum, backlog=128
        )
        # Wait until the reactor is running before calling endpoint.listen().
        twisted.internet.reactor.callWhenRunning(listen, endpoint, test_case)


if __name__ == "__main__":
    logging.basicConfig(
        format=(
            "%(levelname) -10s %(asctime)s %(module)s:%(lineno)s | %(message)s"
        ),
        level=logging.INFO,
    )
    args = parse_arguments()
    start_test_servers(args.base_port)
    twisted.internet.reactor.run()
    sys.exit(_exit_code)