1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""HTTP2 Test Server"""
15
16import argparse
17import logging
18import sys
19
20import http2_base_server
21import test_data_frame_padding
22import test_goaway
23import test_max_streams
24import test_ping
25import test_rst_after_data
26import test_rst_after_header
27import test_rst_during_data
28import twisted
29import twisted.internet
30import twisted.internet.endpoints
31import twisted.internet.reactor
32
# Test case name -> class implementing that test case. One server is started
# per entry (see start_test_servers); H2Factory instantiates the class for
# every accepted connection.
_TEST_CASE_MAPPING = dict(
    rst_after_header=test_rst_after_header.TestcaseRstStreamAfterHeader,
    rst_after_data=test_rst_after_data.TestcaseRstStreamAfterData,
    rst_during_data=test_rst_during_data.TestcaseRstStreamDuringData,
    goaway=test_goaway.TestcaseGoaway,
    ping=test_ping.TestcasePing,
    max_streams=test_max_streams.TestcaseSettingsMaxStreams,
    # Positive tests. Both entries share one class; H2Factory builds the
    # sanity-test variant with use_padding=False.
    data_frame_padding=test_data_frame_padding.TestDataFramePadding,
    no_df_padding_sanity_test=test_data_frame_padding.TestDataFramePadding,
)

# Process exit status reported after the reactor stops; listen() sets it to
# 1 when a server fails to start listening.
_exit_code = 0
47
48
class H2Factory(twisted.internet.protocol.Factory):
    """Twisted protocol factory that serves a single HTTP/2 test case.

    For every accepted connection, the factory instantiates the class that
    _TEST_CASE_MAPPING registers for its test case. It returns that object's
    get_base_server() result as the connection's protocol.
    """

    def __init__(self, testcase):
        """Create a factory for one test case.

        Args:
          testcase: name of the test case; expected to be a key of
            _TEST_CASE_MAPPING (validated when a connection is accepted).
        """
        logging.info('Creating H2Factory for new connection (%s)', testcase)
        # Despite the name, this counts connections: buildProtocol increments
        # it once per accepted connection.
        self._num_streams = 0
        self._testcase = testcase

    def buildProtocol(self, addr):
        """Build the protocol object for a newly accepted connection.

        Args:
          addr: address of the connecting peer (not used).

        Returns:
          The value of get_base_server() on a fresh test case instance.

        Raises:
          ValueError: if this factory's test case is not in
            _TEST_CASE_MAPPING.
        """
        self._num_streams += 1
        logging.info('New Connection: %d', self._num_streams)
        # Explicit raise instead of `assert`: asserts are stripped under -O,
        # which would leave the test case class unbound.
        if self._testcase not in _TEST_CASE_MAPPING:
            logging.error('Unknown test case: %s', self._testcase)
            raise ValueError('Unknown test case: %s' % self._testcase)
        testcase_cls = _TEST_CASE_MAPPING[self._testcase]

        if self._testcase == 'goaway':
            # The goaway test case is constructed with the running connection
            # count of this factory.
            return testcase_cls(self._num_streams).get_base_server()
        if self._testcase == 'no_df_padding_sanity_test':
            # Same class as data_frame_padding, but with padding disabled.
            return testcase_cls(use_padding=False).get_base_server()
        return testcase_cls().get_base_server()
71
72
def parse_arguments(argv=None):
    """Parse the command-line arguments of the test server.

    Args:
      argv: list of arguments to parse, excluding the program name. When
        None (the default), argparse reads sys.argv[1:].

    Returns:
      argparse.Namespace with an int ``base_port`` attribute.
    """
    parser = argparse.ArgumentParser()
    # The port order in the help text is derived from the same sorted mapping
    # that start_test_servers iterates, so the two cannot drift apart.
    parser.add_argument(
        '--base_port',
        type=int,
        default=8080,
        help='base port to run the servers (default: 8080). One test server is '
        'started on each incrementing port, beginning with base_port, in the '
        'following order: %s' % ','.join(sorted(_TEST_CASE_MAPPING)))
    return parser.parse_args(argv)
85
86
def listen(endpoint, test_case):
    """Start listening on ``endpoint`` with an H2Factory for ``test_case``.

    If listening fails, the module-level exit code is set to 1 and the
    reactor is stopped.

    Args:
      endpoint: Twisted server endpoint to listen on.
      test_case: name of the test case to serve (key of _TEST_CASE_MAPPING).
    """
    from twisted.internet.error import ReactorNotRunning

    deferred = endpoint.listen(H2Factory(test_case))

    def listen_error(reason):
        # Record the failure so the process exits with status 1, then stop
        # the reactor.
        global _exit_code
        _exit_code = 1
        logging.error('Listening failed: %s', reason.value)
        try:
            twisted.internet.reactor.stop()
        except ReactorNotRunning:
            # Another server already failed and stopped the reactor; a second
            # stop() raises, which would otherwise surface as an unhandled
            # error in this Deferred.
            pass

    deferred.addErrback(listen_error)
99
100
def start_test_servers(base_port):
    """Schedule one server per test case on consecutive ports.

    Test cases are taken in sorted-name order. The first one gets
    ``base_port``, the next ``base_port + 1``, and so on. The actual
    listen() call is deferred until the reactor is running.

    Args:
      base_port: port number assigned to the first test case.
    """
    for portnum, test_case in enumerate(sorted(_TEST_CASE_MAPPING),
                                        start=base_port):
        logging.warning('serving on port %d : %s', portnum, test_case)
        endpoint = twisted.internet.endpoints.TCP4ServerEndpoint(
            twisted.internet.reactor, portnum, backlog=128)
        # Wait until the reactor is running before calling endpoint.listen().
        twisted.internet.reactor.callWhenRunning(listen, endpoint, test_case)
114
115
if __name__ == '__main__':
    # Prefix each record with level, timestamp and module:line.
    logging.basicConfig(
        level=logging.INFO,
        format=
        '%(levelname) -10s %(asctime)s %(module)s:%(lineno)s | %(message)s')
    start_test_servers(parse_arguments().base_port)
    # run() blocks until the reactor is stopped, e.g. by a failed listen().
    twisted.internet.reactor.run()
    sys.exit(_exit_code)
125