xref: /aosp_15_r20/external/federated-compute/fcp/demo/http_actions.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Utilities for creating proto service and HTTP action handlers.
15
16The `@proto_action` function annotates a method as implementing a proto service
17method. The annotated method should have the type
18`Callable[[RequestMessage], ResponseMessage]`. The decorator will take care of
transcoding to/from an HTTP request, similar to
https://cloud.google.com/endpoints/docs/grpc/transcoding. The transcoding only
supports proto-over-http ('?$alt=proto', matched in its URL-encoded form
'?%24alt=proto').
22
The `@http_action` function annotates a method as implementing an HTTP action at
some request path. The annotated method will receive the request body and any
path parameters (as keyword arguments), and should return an `HttpResponse`.
26
27The `create_handler` function merges one or more objects with decorated methods
28into a single request handler that's compatible with `http.server`.
29"""
30
31import collections
32import dataclasses
33import enum
34import gzip
35import http
36import http.server
37import re
38from typing import Any, Callable, Mapping, Match, Pattern, Type, TypeVar
39import urllib.parse
40import zlib
41
42from absl import logging
43
44from google.api import annotations_pb2
45from google.protobuf import descriptor_pool
46from google.protobuf import message
47from google.protobuf import message_factory
48
# Type variable for decorated functions; the decorators return the function
# they were given, so its type is preserved.
_CallableT = TypeVar('_CallableT', bound=Callable)

# Message factory over the default descriptor pool. It is used to look up the
# service/method descriptors and to build request message classes.
_FACTORY = message_factory.MessageFactory(descriptor_pool.Default())

# Attribute set on decorated methods; `create_handler` looks for it when
# collecting action handlers.
_HTTP_ACTION_ATTR = '_http_action_data'
53
54
@dataclasses.dataclass(frozen=True)
class HttpError(Exception):
  """An Exception specifying the HTTP error to return.

  Attributes:
    code: The HTTP status code to send back to the client.
  """
  code: http.HTTPStatus

  def __post_init__(self):
    # The generated dataclass `__init__` does not call `Exception.__init__`, so
    # `args` (and therefore `str(e)`, e.g. in log messages) would be empty
    # whenever the error is constructed as `HttpError(code=...)`.
    # `BaseException.__init__` sets `args` directly, so it is not blocked by
    # the frozen dataclass `__setattr__`.
    super().__init__(self.code)
59
60
@dataclasses.dataclass(frozen=True)
class HttpResponse:
  """A successful HTTP response to send back to the client.

  Attributes:
    body: Raw bytes written as the response body.
    headers: Extra response headers, sent in iteration order. Each instance
      gets its own empty dict by default.
  """
  body: bytes
  headers: Mapping[str, str] = dataclasses.field(default_factory=dict)
66
67
def proto_action(*, service: str,
                 method: str) -> Callable[[_CallableT], _CallableT]:
  """Decorator annotating a method as handling a proto service method.

  The `google.api.http` annotation on the method will determine what requests
  will be handled by the decorated function. Only GET and POST rules with
  simple `{field}` path variables are currently supported.

  The decorated method will be called with the request message; it should return
  a response message or raise an `HttpError`.

  Args:
    service: The full name of the proto service.
    method: The name of the method.

  Returns:
    An annotated function.

  Raises:
    ValueError: If the service method cannot be found, or its
      `google.api.http` annotation is missing, invalid, or unsupported.
  """
  try:
    desc = _FACTORY.pool.FindServiceByName(service).FindMethodByName(method)
  except KeyError as e:
    raise ValueError(f'Unable to find /{service}.{method}.') from e
  # NOTE(review): some protobuf releases return None from `FindMethodByName`
  # for an unknown method instead of raising KeyError; treat both the same.
  if desc is None:
    raise ValueError(f'Unable to find /{service}.{method}.')

  rule = desc.GetOptions().Extensions[annotations_pb2.http]
  # `WhichOneof` returns None when the method has no `google.api.http` rule;
  # otherwise the rule kind must name a supported `_HttpMethod`.
  pattern_kind = rule.WhichOneof('pattern')
  http_method = None
  if pattern_kind is not None:
    http_method = _HttpMethod.__members__.get(pattern_kind.upper())
  if http_method is None:
    raise ValueError(
        f'The google.api.http annotation on /{service}.{method} is invalid '
        'or unsupported.')
  path = _convert_pattern(getattr(rule, pattern_kind), alt_proto=True)

  # Resolve the request message class once, when the decorator is applied,
  # rather than on every request.
  # NOTE(review): `MessageFactory.GetPrototype` is deprecated in newer protobuf
  # releases in favor of `message_factory.GetMessageClass` — confirm against
  # the protobuf version in use.
  request_class = _FACTORY.GetPrototype(desc.input_type)

  def handler(match: Match[str], body: bytes,
              fn: Callable[[message.Message], message.Message]) -> HttpResponse:
    request = request_class()
    if rule.body == '*':
      # The whole request body is the serialized request message.
      try:
        request.ParseFromString(body)
      except message.DecodeError as e:
        raise HttpError(code=http.HTTPStatus.BAD_REQUEST) from e
    elif rule.body:
      # The request body is assigned to the field named by the rule.
      setattr(request, rule.body, body)
    # Set any fields from the request path. `errors='strict'` makes invalid
    # UTF-8 percent-escapes raise instead of being silently replaced with
    # U+FFFD (the `unquote` default), so they are rejected as bad requests.
    for prop, value in match.groupdict().items():
      try:
        unescaped = urllib.parse.unquote(value, errors='strict')
      except UnicodeError as e:
        raise HttpError(code=http.HTTPStatus.BAD_REQUEST) from e
      setattr(request, prop, unescaped)

    response_body = fn(request).SerializeToString()
    return HttpResponse(
        body=response_body,
        headers={
            # Header values are strings, per `HttpResponse.headers`.
            'Content-Length': str(len(response_body)),
            'Content-Type': 'application/x-protobuf'
        })

  def annotate_method(func: _CallableT) -> _CallableT:
    setattr(func, _HTTP_ACTION_ATTR,
            _HttpActionData(method=http_method, path=path, handler=handler))
    return func

  return annotate_method
134
135
def http_action(*, method: str,
                pattern: str) -> Callable[[_CallableT], _CallableT]:
  """Decorator annotating a method as an HTTP action handler.

  Requests matching the method and pattern will be handled by the decorated
  method. The pattern may contain brace-enclosed keywords (e.g.,
  '/data/{path}'), which will be matched against the request and passed
  to the decorated function as keyword arguments.

  The decorated method will be called with the request body (if any) and any
  keyword args from the path pattern; it should return an `HttpResponse` or
  raise an `HttpError`. Path parameters are percent-decoded as UTF-8; a
  parameter that does not decode cleanly results in a 400 (Bad Request) error.

  Args:
    method: The type of HTTP method ('GET' or 'POST'; case-insensitive).
    pattern: The url pattern to match.

  Returns:
    An annotated function.

  Raises:
    ValueError: If `method` is unsupported or `pattern` cannot be converted to
      a regular expression.
  """
  try:
    http_method = _HttpMethod[method.upper()]
  except KeyError as e:
    raise ValueError(f'unsupported HTTP method `{method}`') from e
  path = _convert_pattern(pattern)

  def handler(match: Match[str], body: bytes,
              fn: Callable[..., HttpResponse]) -> HttpResponse:
    # `errors='strict'` makes invalid UTF-8 percent-escapes raise instead of
    # being silently replaced with U+FFFD (the `unquote` default).
    try:
      args = {
          k: urllib.parse.unquote(v, errors='strict')
          for k, v in match.groupdict().items()
      }
    except UnicodeError as e:
      raise HttpError(code=http.HTTPStatus.BAD_REQUEST) from e
    return fn(body, **args)

  def annotate_method(func: _CallableT) -> _CallableT:
    setattr(func, _HTTP_ACTION_ATTR,
            _HttpActionData(method=http_method, path=path, handler=handler))
    return func

  return annotate_method
176
177
def create_handler(*services: Any) -> Type[http.server.BaseHTTPRequestHandler]:
  """Builds a BaseHTTPRequestHandler that delegates to decorated methods.

  The returned BaseHTTPRequestHandler class will route requests to decorated
  methods of the provided services, or return 404 if the request path does not
  match any action handlers. If the request path matches multiple registered
  action handlers, it's unspecified which will be invoked.

  Args:
    *services: A list of objects with methods decorated with `@proto_action` or
      `@http_action`.

  Returns:
    A BaseHTTPRequestHandler subclass.
  """

  # Collect all handlers, keyed by HTTP method.
  handlers = collections.defaultdict(list)
  for service in services:
    for attr_name in dir(service):
      attr = getattr(service, attr_name)
      if not callable(attr):
        continue
      data = getattr(attr, _HTTP_ACTION_ATTR, None)
      if isinstance(data, _HttpActionData):
        handlers[data.method].append((data, attr))

  def format_handlers(entries) -> str:
    """Formats the path patterns of `entries` as a bulleted list."""
    return ''.join(f'\n  * {data.path.pattern}' for data, _ in entries)

  logging.debug(
      'Creating HTTP request handler for path patterns:\nGET:%s\nPOST:%s',
      format_handlers(handlers[_HttpMethod.GET]),
      format_handlers(handlers[_HttpMethod.POST]))

  class RequestHandler(http.server.BaseHTTPRequestHandler):
    """Handler that delegates to `handlers`."""

    def do_GET(self) -> None:  # pylint:disable=invalid-name (override)
      self._handle_request(_HttpMethod.GET, read_body=False)

    def do_POST(self) -> None:  # pylint:disable=invalid-name (override)
      self._handle_request(_HttpMethod.POST)

    def _handle_request(self,
                        method: _HttpMethod,
                        read_body: bool = True) -> None:
      """Reads and delegates an incoming request to a registered handler.

      The first entry in `handlers[method]` whose path pattern fully matches
      the request path is invoked. An `HttpError` raised while reading the body
      or handling the request is sent back as an error response; if no pattern
      matches, a 404 is sent.
      """
      for data, fn in handlers[method]:
        match = data.path.fullmatch(self.path)
        if match is None:
          continue

        try:
          body = self._read_body() if read_body else b''
          response = data.handler(match, body, fn)
        except HttpError as e:
          logging.debug('%s error: %s', self.path, e)
          self.send_error(e.code)
          return
        self._send_response(response)
        return

      # If no handler matched the path, return an error.
      self.send_error(http.HTTPStatus.NOT_FOUND)

    def _read_content_length(self) -> int:
      """Returns the request body length from the `Content-Length` header.

      Raises:
        HttpError: LENGTH_REQUIRED if the header is missing, or BAD_REQUEST if
          it is not a non-negative decimal integer.
      """
      value = self.headers['Content-Length']
      if value is None:
        raise HttpError(http.HTTPStatus.LENGTH_REQUIRED)
      # Only accept plain decimal digits: `int()` alone would also accept signs
      # and `_` separators, and a negative length would make `rfile.read()`
      # read until the client closes the connection.
      if not re.fullmatch(r'\s*[0-9]+\s*', value):
        raise HttpError(http.HTTPStatus.BAD_REQUEST)
      return int(value)

    def _read_body(self) -> bytes:
      """Reads the body of the request, undoing any gzip content coding.

      Raises:
        HttpError: If the body length is missing/invalid, the gzip data is
          corrupt or truncated, or the content coding is unsupported.
      """
      body = self.rfile.read(self._read_content_length())
      raw_encoding = self.headers['Content-Encoding']
      # Content codings are case-insensitive tokens.
      encoding = (raw_encoding or '').strip().lower()
      if encoding == 'gzip':
        try:
          body = gzip.decompress(body)
        except (gzip.BadGzipFile, zlib.error, EOFError) as e:
          # EOFError is raised for truncated gzip data.
          raise HttpError(http.HTTPStatus.BAD_REQUEST) from e
      elif encoding:
        logging.warning('Unsupported content encoding %s', raw_encoding)
        raise HttpError(http.HTTPStatus.BAD_REQUEST)
      return body

    def _send_response(self, response: HttpResponse) -> None:
      """Sends a 200 OK response with the given headers and body."""
      self.send_response(http.HTTPStatus.OK)
      for keyword, value in response.headers.items():
        self.send_header(keyword, value)
      self.end_headers()
      self.wfile.write(response.body)

  return RequestHandler
263
264
class _HttpMethod(enum.Enum):
  """HTTP request methods that actions can be registered for."""
  GET = enum.auto()
  POST = enum.auto()
268
269
@dataclasses.dataclass(frozen=True)
class _HttpActionData:
  """Routing information attached to a decorated method.

  Attributes:
    method: The `_HttpMethod` this action responds to.
    path: Compiled regexp; requests whose path fully matches it are routed to
      this action.
    handler: Adapter invoked as `handler(match, body, fn)` with the path match,
      the request body, and the decorated function; it returns the response.
  """
  method: _HttpMethod
  path: Pattern[str]
  handler: Callable[[Match[str], bytes, Callable[..., Any]], HttpResponse]
283
284
def _convert_pattern(pattern: str, alt_proto=False) -> Pattern[str]:
  """Turns a Google API path template into a regexp with named groups.

  Each `{field}` becomes a named group matching one path segment (no '/' or
  '?'). When `alt_proto` is set, the regexp additionally requires the
  URL-encoded query suffix '?%24alt=proto'.

  Raises:
    ValueError: If the template cannot be compiled, e.g. because it uses
      subfields such as `{a.b}`, which are not supported.
  """
  escaped = re.escape(pattern)
  # After escaping, each `{field}` appears as `\{field\}`.
  regexp = re.sub(r'\\\{(.+?)\\\}', r'(?P<\1>[^/?]*)', escaped)
  suffix = r'\?%24alt=proto' if alt_proto else ''
  try:
    compiled = re.compile(regexp + suffix)
  except re.error as e:
    raise ValueError(f'unable to convert `{pattern}` to a regexp') from e
  return compiled
296