xref: /aosp_15_r20/external/federated-compute/fcp/demo/media.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Action handlers for file upload and download.
15
16In a production system, download would likely be handled by an external service;
17it's important that uploads are not handled separately to help ensure that
18unaggregated client data is only held ephemerally.
19"""
20
21import contextlib
22import http
23import threading
24from typing import Callable, Iterator, Optional
25import uuid
26
27from fcp.demo import http_actions
28from fcp.protos.federatedcompute import common_pb2
29
30
class DownloadGroup:
  """A set of downloadable files that share one URL path prefix.

  The group does not store file data itself: every `add` call is forwarded to
  the callback supplied at construction time, which performs the actual
  registration.
  """

  def __init__(self, prefix: str, add_fn: Callable[[str, bytes, str], None]):
    """Initializes the group.

    Args:
      prefix: Path prefix prepended to the name of every file in this group.
      add_fn: Callback invoked as `add_fn(name, data, content_type)` to
        register each file.
    """
    self._prefix, self._add_fn = prefix, add_fn

  @property
  def prefix(self) -> str:
    """The path prefix shared by every file in this group."""
    return self._prefix

  def add(
      self,
      name: str,
      data: bytes,
      content_type: str = 'application/octet-stream',
  ) -> str:
    """Registers a new file with the group.

    Args:
      name: The name of the new file.
      data: The bytes to make available.
      content_type: The content type to include in the response.

    Returns:
      The full path to the new file: the group prefix followed by `name`.

    Raises:
      KeyError if a file with that name has already been registered (as
      reported by the registration callback).
    """
    register = self._add_fn
    # Register first; the path is only computed once registration succeeded.
    register(name, data, content_type)
    full_path = self.prefix + name
    return full_path
62
63
class Service:
  """Implements a service for uploading and downloading data over HTTP.

  Downloads are organized into groups (see `create_download_group`); each
  group's files are served by `download` at `/data/{group}/{name}` until the
  group's context manager exits. Uploads are single-use: a resource name is
  reserved with `register_upload`, filled by one POST to `upload`, and then
  retrieved and forgotten with `finalize_upload`. All mutable state below is
  accessed only while holding `self._lock`.
  """

  def __init__(self, forwarding_info: Callable[[], common_pb2.ForwardingInfo]):
    """Initializes the service.

    Args:
      forwarding_info: Returns the ForwardingInfo whose `target_uri_prefix` is
        used to build download paths. It is called each time a download group
        is created.
    """
    self._forwarding_info = forwarding_info
    self._lock = threading.Lock()
    # Maps download group id -> file name -> prebuilt response for that file.
    self._downloads: dict[str, dict[str, http_actions.HttpResponse]] = {}
    # Maps upload resource name -> uploaded bytes, or None while the upload is
    # still pending.
    self._uploads: dict[str, Optional[bytes]] = {}

  @contextlib.contextmanager
  def create_download_group(self) -> Iterator[DownloadGroup]:
    """Creates a new group of downloadable files.

    Files can be added to this group using `DownloadGroup.add`. All files in
    the group will be unregistered when the ContextManager goes out of scope.

    Yields:
      The download group to which files should be added.
    """
    group = str(uuid.uuid4())

    def add_file(name: str, data: bytes, content_type: str) -> None:
      with self._lock:
        files = self._downloads.get(group)
        if files is None:
          # The context manager has already exited and unregistered the group.
          # Say so explicitly instead of surfacing a bare KeyError(group),
          # which would look like a duplicate-name error to callers.
          raise KeyError(f'download group {group} is no longer registered')
        if name in files:
          raise KeyError(f'{name} already exists')
        files[name] = http_actions.HttpResponse(
            body=data,
            headers={
                'Content-Length': len(data),
                'Content-Type': content_type,
            })

    with self._lock:
      self._downloads[group] = {}
    try:
      yield DownloadGroup(
          f'{self._forwarding_info().target_uri_prefix}data/{group}/', add_file)
    finally:
      # Always unregister the group, even if the caller's block raised.
      with self._lock:
        del self._downloads[group]

  def register_upload(self) -> str:
    """Registers a path for single-use upload, returning the resource name."""
    name = str(uuid.uuid4())
    with self._lock:
      self._uploads[name] = None
    return name

  def finalize_upload(self, name: str) -> Optional[bytes]:
    """Returns the data from an upload, if any, and unregisters the upload.

    Args:
      name: A resource name previously returned by `register_upload`.

    Returns:
      The uploaded bytes, or None if no data has been uploaded.

    Raises:
      KeyError: if `name` is not registered (including when it has already
        been finalized).
    """
    with self._lock:
      return self._uploads.pop(name)

  @http_actions.http_action(method='GET', pattern='/data/{group}/{name}')
  def download(self, body: bytes, group: str,
               name: str) -> http_actions.HttpResponse:
    """Handles a download request.

    Raises:
      HttpError(NOT_FOUND): if the group or the file within it is unknown.
    """
    del body  # GET requests carry no meaningful body.
    try:
      with self._lock:
        return self._downloads[group][name]
    except KeyError as e:
      raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) from e

  @http_actions.http_action(
      method='POST', pattern='/upload/v1/media/{name}?upload_protocol=raw')
  def upload(self, body: bytes, name: str) -> http_actions.HttpResponse:
    """Handles a raw upload to a registered, not-yet-filled resource name.

    Args:
      body: The uploaded bytes.
      name: The resource name returned by `register_upload`.

    Returns:
      An empty response on success.

    Raises:
      HttpError(UNAUTHORIZED): if `name` was never registered, was already
        finalized, or has already received an upload (uploads are single-use).
    """
    with self._lock:
      if name not in self._uploads or self._uploads[name] is not None:
        raise http_actions.HttpError(http.HTTPStatus.UNAUTHORIZED)
      self._uploads[name] = body
    return http_actions.HttpResponse(b'')
136