# xref: /aosp_15_r20/external/grpc-grpc/test/http2_test/http2_base_server.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
# Copyright 2016 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
import logging
import struct

import h2
import h2.connection
import h2.events
import h2.exceptions
import messages_pb2
import twisted
import twisted.internet
import twisted.internet.protocol
24
25_READ_CHUNK_SIZE = 16384
26_GRPC_HEADER_SIZE = 5
27_MIN_SETTINGS_MAX_FRAME_SIZE = 16384
28
29
class H2ProtocolBaseServer(twisted.internet.protocol.Protocol):
    """Base Twisted protocol for scripted HTTP/2 gRPC-style test servers.

    Connection callbacks and h2 events are dispatched through
    ``self._handlers``, a dict keyed by event name ("ConnectionMade",
    "DataReceived", "WindowUpdated", "RequestReceived", "SendDone",
    "ConnectionLost", "PingAcknowledged"). ``__init__`` installs the
    ``*_default`` methods below; ``set_handlers`` replaces the whole table.
    """

    def __init__(self):
        self._conn = h2.connection.H2Connection(client_side=False)
        # stream_id -> bytes received so far on that stream.
        self._recv_buffer = {}
        self._handlers = {}
        self._handlers["ConnectionMade"] = self.on_connection_made_default
        self._handlers["DataReceived"] = self.on_data_received_default
        self._handlers["WindowUpdated"] = self.on_window_update_default
        self._handlers["RequestReceived"] = self.on_request_received_default
        self._handlers["SendDone"] = self.on_send_done_default
        self._handlers["ConnectionLost"] = self.on_connection_lost
        self._handlers["PingAcknowledged"] = self.on_ping_acknowledged_default
        # stream_id -> True until trailers have been sent on that stream.
        self._stream_status = {}
        # stream_id -> payload bytes still waiting to be sent on that stream.
        self._send_remaining = {}
        # Pings sent by default_ping() that have not been acknowledged yet.
        self._outstanding_pings = 0

    def set_handlers(self, handlers):
        """Replace the entire event-name -> callable dispatch table."""
        self._handlers = handlers

    def connectionMade(self):
        """Twisted callback; forwards to the "ConnectionMade" handler."""
        self._handlers["ConnectionMade"]()

    def connectionLost(self, reason):
        """Twisted callback; forwards to the "ConnectionLost" handler."""
        self._handlers["ConnectionLost"](reason)

    def on_connection_made_default(self):
        """Default "ConnectionMade" handler: send the server preface."""
        logging.info("Connection Made")
        self._conn.initiate_connection()
        self.transport.setTcpNoDelay(True)
        self.transport.write(self._conn.data_to_send())

    def on_connection_lost(self, reason):
        """Default "ConnectionLost" handler: log the reason."""
        logging.info("Disconnected %s", reason)

    def dataReceived(self, data):
        """Feed transport bytes into h2 and dispatch the resulting events.

        Only event types whose name is present in ``self._handlers`` are
        dispatched; other events are ignored.
        """
        try:
            events = self._conn.receive_data(data)
        except h2.exceptions.ProtocolError:
            # this try/except block catches exceptions due to race between sending
            # GOAWAY and processing a response in flight.
            return
        # Flush whatever h2 queued while parsing before running handlers.
        # data_to_send must be *called*; the bound method itself is always
        # truthy.
        pending = self._conn.data_to_send()
        if pending:
            self.transport.write(pending)
        handlers = self._handlers
        for event in events:
            if (
                isinstance(event, h2.events.RequestReceived)
                and "RequestReceived" in handlers
            ):
                logging.info(
                    "RequestReceived Event for stream: %d", event.stream_id
                )
                handlers["RequestReceived"](event)
            elif (
                isinstance(event, h2.events.DataReceived)
                and "DataReceived" in handlers
            ):
                logging.info(
                    "DataReceived Event for stream: %d", event.stream_id
                )
                handlers["DataReceived"](event)
            elif (
                isinstance(event, h2.events.WindowUpdated)
                and "WindowUpdated" in handlers
            ):
                logging.info(
                    "WindowUpdated Event for stream: %d", event.stream_id
                )
                handlers["WindowUpdated"](event)
            elif (
                isinstance(event, h2.events.PingAcknowledged)
                and "PingAcknowledged" in handlers
            ):
                logging.info("PingAcknowledged Event")
                handlers["PingAcknowledged"](event)
        # Flush anything the handlers queued on the connection.
        self.transport.write(self._conn.data_to_send())

    def on_ping_acknowledged_default(self, event):
        """Default "PingAcknowledged" handler: decrement outstanding pings."""
        logging.info("ping acknowledged")
        self._outstanding_pings -= 1

    def on_data_received_default(self, event):
        """Default "DataReceived" handler: return window credit and buffer data.

        Padding consumes flow-control window too (``default_send`` budgets
        for it), so the credit returned is ``event.flow_controlled_length``
        rather than ``len(event.data)``; otherwise padded DATA frames from
        the peer would slowly leak window.
        """
        self._conn.acknowledge_received_data(
            event.flow_controlled_length, event.stream_id
        )
        self._recv_buffer[event.stream_id] += event.data

    def on_request_received_default(self, event):
        """Default "RequestReceived" handler: reply with gRPC response headers.

        Also resets the stream's receive buffer, marks the stream open for
        trailers, and remembers it as the stream for ``send_reset_stream``.
        """
        # bytes (not str) so that appending event.data works on Python 3.
        self._recv_buffer[event.stream_id] = b""
        self._stream_id = event.stream_id
        self._stream_status[event.stream_id] = True
        self._conn.send_headers(
            stream_id=event.stream_id,
            headers=[
                (":status", "200"),
                ("content-type", "application/grpc"),
                ("grpc-encoding", "identity"),
                ("grpc-accept-encoding", "identity,deflate,gzip"),
            ],
        )
        self.transport.write(self._conn.data_to_send())

    def on_window_update_default(
        self, _, pad_length=None, read_chunk_size=_READ_CHUNK_SIZE
    ):
        """Default "WindowUpdated" handler: retry sending on every stream."""
        # try to resume sending on all active streams (update might be for
        # connection). Iterate a snapshot so a handler that calls setup_send
        # for another stream cannot change the dict mid-iteration.
        for stream_id in list(self._send_remaining):
            self.default_send(
                stream_id,
                pad_length=pad_length,
                read_chunk_size=read_chunk_size,
            )

    def send_reset_stream(self):
        """Send RST_STREAM for the most recently received request's stream."""
        self._conn.reset_stream(self._stream_id)
        self.transport.write(self._conn.data_to_send())

    def setup_send(
        self,
        data_to_send,
        stream_id,
        pad_length=None,
        read_chunk_size=_READ_CHUNK_SIZE,
    ):
        """Queue ``data_to_send`` for ``stream_id`` and start sending it.

        NOTE(review): the remaining-byte count is tracked per stream, but the
        payload and offset (``_data_to_send`` / ``_send_offset``) are shared
        by all streams, so concurrent sends on several streams interfere.
        """
        logging.info("Setting up data to send for stream_id: %d", stream_id)
        self._send_remaining[stream_id] = len(data_to_send)
        self._send_offset = 0
        self._data_to_send = data_to_send
        self.default_send(
            stream_id, pad_length=pad_length, read_chunk_size=read_chunk_size
        )

    def default_send(
        self, stream_id, pad_length=None, read_chunk_size=_READ_CHUNK_SIZE
    ):
        """Send as much queued data on ``stream_id`` as flow control allows.

        Each DATA frame carries at most ``read_chunk_size`` payload bytes and
        is padded with ``pad_length`` bytes when given. Stops when the local
        flow-control window cannot fit another frame, or when the stream is
        closed. Invokes the "SendDone" handler once all data has been sent.

        Raises:
          ValueError: if a frame (payload + padding) would exceed the HTTP/2
            minimum SETTINGS_MAX_FRAME_SIZE.
        """
        if stream_id not in self._send_remaining:
            # not setup to send data yet
            return

        # A padded DATA frame spends pad_length padding bytes plus the 1-byte
        # Pad Length field out of the flow-control window.
        padding_bytes = pad_length + 1 if pad_length is not None else 0
        while self._send_remaining[stream_id] > 0:
            lfcw = self._conn.local_flow_control_window(stream_id)
            if lfcw - padding_bytes <= 0:
                logging.info(
                    "Stream %d. lfcw: %d. padding bytes: %d. not enough"
                    " quota yet",
                    stream_id,
                    lfcw,
                    padding_bytes,
                )
                break
            chunk_size = min(lfcw - padding_bytes, read_chunk_size)
            bytes_to_send = min(chunk_size, self._send_remaining[stream_id])
            logging.info(
                "flow_control_window = %d. sending [%d:%d] stream_id %d."
                " includes %d total padding bytes",
                lfcw,
                self._send_offset,
                self._send_offset + bytes_to_send + padding_bytes,
                stream_id,
                padding_bytes,
            )
            # The receiver might allow sending frames larger than the http2 minimum
            # max frame size (16384), but this test should never send more than 16384
            # for simplicity (which is always legal).
            if bytes_to_send + padding_bytes > _MIN_SETTINGS_MAX_FRAME_SIZE:
                raise ValueError(
                    "overload: sending %d" % (bytes_to_send + padding_bytes)
                )
            data = self._data_to_send[
                self._send_offset : self._send_offset + bytes_to_send
            ]
            try:
                self._conn.send_data(
                    stream_id, data, end_stream=False, pad_length=pad_length
                )
            except h2.exceptions.ProtocolError:
                logging.info("Stream %d is closed", stream_id)
                break
            self._send_remaining[stream_id] -= bytes_to_send
            self._send_offset += bytes_to_send
            if self._send_remaining[stream_id] == 0:
                self._handlers["SendDone"](stream_id)

    def default_ping(self):
        """Send a PING with an 8-byte zero payload and count it outstanding."""
        logging.info("sending ping")
        self._outstanding_pings += 1
        self._conn.ping(b"\x00" * 8)
        self.transport.write(self._conn.data_to_send())

    def on_send_done_default(self, stream_id):
        """Default "SendDone" handler: send trailers once per stream."""
        if self._stream_status[stream_id]:
            self._stream_status[stream_id] = False
            self.default_send_trailer(stream_id)
        else:
            logging.error("Stream %d is already closed", stream_id)

    def default_send_trailer(self, stream_id):
        """Send ``grpc-status: 0`` trailers and end the stream."""
        logging.info("Sending trailer for stream id %d", stream_id)
        self._conn.send_headers(
            stream_id, headers=[("grpc-status", "0")], end_stream=True
        )
        self.transport.write(self._conn.data_to_send())

    @staticmethod
    def default_response_data(response_size):
        """Return a gRPC-framed SimpleResponse with a zero-filled body.

        Frame layout: 1-byte compressed flag (0), 4-byte big-endian message
        length, then the serialized proto.
        """
        sresp = messages_pb2.SimpleResponse()
        sresp.payload.body = b"\x00" * response_size
        serialized_resp_proto = sresp.SerializeToString()
        # Explicit big-endian ">I"; reversing native-order bytes is only
        # big-endian on little-endian hosts.
        return (
            b"\x00"
            + struct.pack(">I", len(serialized_resp_proto))
            + serialized_resp_proto
        )

    def parse_received_data(self, stream_id):
        """Parse the single gRPC message buffered on ``stream_id``.

        Returns:
          A ``messages_pb2.SimpleRequest`` parsed from the buffered message,
          or ``None`` when the buffer does not hold exactly one complete
          length-prefixed message (including when the 5-byte prefix itself
          has not fully arrived yet).
        """
        recv_buffer = self._recv_buffer[stream_id]
        if len(recv_buffer) < _GRPC_HEADER_SIZE:
            # Length prefix not fully received yet.
            return None
        # Bytes 1..4 hold the big-endian message length (byte 0 is the
        # compressed flag).
        grpc_msg_size = struct.unpack(
            ">I", recv_buffer[1:_GRPC_HEADER_SIZE]
        )[0]
        if len(recv_buffer) != _GRPC_HEADER_SIZE + grpc_msg_size:
            return None
        req_proto_str = recv_buffer[
            _GRPC_HEADER_SIZE : _GRPC_HEADER_SIZE + grpc_msg_size
        ]
        sr = messages_pb2.SimpleRequest()
        sr.ParseFromString(req_proto_str)
        logging.info("Parsed simple request for stream %d", stream_id)
        return sr
250