# xref: /aosp_15_r20/external/federated-compute/fcp/demo/task_assignments.py (revision 14675a029014e728ec732f129a32e299b2da0601)
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Action handlers for the TaskAssignments service."""
15
16import collections
17import dataclasses
18import http
19import threading
20from typing import Callable, Optional
21import uuid
22
23from absl import logging
24
25from google.longrunning import operations_pb2
26from google.rpc import code_pb2
27from google.protobuf import text_format
28from fcp.demo import aggregations
29from fcp.demo import http_actions
30from fcp.protos.federatedcompute import common_pb2
31from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
32from fcp.protos.federatedcompute import task_assignments_pb2
33
# Short alias for the deeply nested TaskAssignmentMode proto type; used below
# both as a type annotation and to reach the TASK_ASSIGNMENT_MODE_* values.
_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
)
37
38
@dataclasses.dataclass(frozen=True)
class _Task:
  """Immutable record of one task that can be assigned to clients.

  Attributes:
    task_name: Name reported to clients in `TaskAssignment.task_name`; also the
      key clients use to request multiple-assignment tasks.
    aggregation_session_id: Id of the aggregation session backing the task. It
      is sent to clients as `TaskAssignment.aggregation_id`, used to
      pre-authorize clients, and used to look the task up on removal.
    init_checkpoint: Resource forwarded to clients as
      `TaskAssignment.init_checkpoint`.
    plan: Resource forwarded to clients as `TaskAssignment.plan`.
    federated_select_uri_template: URI template forwarded to clients in
      `TaskAssignment.federated_select_uri_info`.
  """
  task_name: str
  aggregation_session_id: str
  init_checkpoint: common_pb2.Resource
  plan: common_pb2.Resource
  federated_select_uri_template: str
46
47
class Service:
  """Implements the TaskAssignments service.

  Tasks are registered with `add_task` and unregistered with `remove_task`.
  Single-assignment tasks are kept in a FIFO queue and only the task at the
  front of the queue is handed out by `start_task_assignment`.
  Multiple-assignment tasks are handed out by name via
  `perform_multiple_task_assignments`. Both task collections are guarded by a
  single lock.
  """

  def __init__(self, population_name: str,
               forwarding_info: Callable[[], common_pb2.ForwardingInfo],
               aggregations_service: aggregations.Service):
    """Initializes the service with no registered tasks.

    Args:
      population_name: The only population this service answers for; requests
        naming any other population are rejected with NOT_FOUND.
      forwarding_info: Callable invoked once per task assignment to obtain the
        ForwardingInfo sent to the client for aggregation data.
      aggregations_service: Aggregations service used to pre-authorize a client
        for a task's aggregation session.
    """
    self._population_name = population_name
    self._forwarding_info = forwarding_info
    self._aggregations_service = aggregations_service
    self._single_assignment_tasks: collections.deque[_Task] = (
        collections.deque()
    )
    self._multiple_assignment_tasks: list[_Task] = []
    # Guards both task collections above. This is a plain (non-reentrant) Lock,
    # so code that already holds it must not call anything that acquires it.
    self._tasks_lock = threading.Lock()

  def add_task(
      self,
      task_name: str,
      task_assignment_mode: _TaskAssignmentMode,
      aggregation_session_id: str,
      plan: common_pb2.Resource,
      init_checkpoint: common_pb2.Resource,
      federated_select_uri_template: str,
  ):
    """Adds a new task to the service.

    Args:
      task_name: The name of the task.
      task_assignment_mode: Whether the task is handed out through the
        single-assignment queue or the multiple-assignment list.
      aggregation_session_id: Id of the aggregation session backing the task.
      plan: The plan resource sent to assigned clients.
      init_checkpoint: The initial checkpoint resource sent to assigned clients.
      federated_select_uri_template: The federated select URI template sent to
        assigned clients.

    Raises:
      ValueError: If `task_assignment_mode` is neither SINGLE nor MULTIPLE.
    """
    task = _Task(
        task_name=task_name,
        aggregation_session_id=aggregation_session_id,
        init_checkpoint=init_checkpoint,
        plan=plan,
        federated_select_uri_template=federated_select_uri_template,
    )
    # Pick the destination first so the lock is only taken for a valid mode.
    if task_assignment_mode == _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE:
      tasks = self._single_assignment_tasks
    elif (
        task_assignment_mode
        == _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE
    ):
      tasks = self._multiple_assignment_tasks
    else:
      raise ValueError(f'Unsupport TaskAssignmentMode {task_assignment_mode}.')
    with self._tasks_lock:
      tasks.append(task)

  def remove_task(self, aggregation_session_id: str):
    """Removes a task from the service.

    The single-assignment queue is searched before the multiple-assignment
    list, and only the first task with a matching session id is removed.

    Args:
      aggregation_session_id: Aggregation session id of the task to remove.

    Raises:
      KeyError: If no registered task has the given aggregation session id.
    """
    with self._tasks_lock:
      for tasks in (self._single_assignment_tasks,
                    self._multiple_assignment_tasks):
        for index, task in enumerate(tasks):
          if task.aggregation_session_id == aggregation_session_id:
            # Delete by position rather than by value: no second scan and no
            # reliance on _Task equality. Returning immediately means the
            # container is never iterated after being mutated.
            del tasks[index]
            return
      raise KeyError(aggregation_session_id)

  @property
  def _current_task(self) -> Optional[_Task]:
    """The task at the front of the single-assignment queue, or None if empty."""
    with self._tasks_lock:
      return (
          self._single_assignment_tasks[0]
          if self._single_assignment_tasks
          else None
      )

  def _create_task_assignment(
      self, task: _Task, session_id: str
  ) -> task_assignments_pb2.TaskAssignment:
    """Builds the TaskAssignment that assigns `task` to one client session.

    This method does not acquire the tasks lock, so it may be called with or
    without that lock held.

    Args:
      task: The task being assigned.
      session_id: Session id taken from the client's request; it is echoed back
        in the assignment.

    Returns:
      A TaskAssignment carrying a newly pre-authorized aggregation token, the
      current forwarding info, and the task's resources.
    """
    # NOTE: If a production implementation of the Aggregations service cannot
    # always pre-authorize clients (e.g., due to rate-limiting incoming
    # clients), this code should either retry the operation or return a
    # non-permanent error to the client (e.g., UNAVAILABLE).
    authorization_token = self._aggregations_service.pre_authorize_clients(
        task.aggregation_session_id, num_tokens=1)[0]
    return task_assignments_pb2.TaskAssignment(
        aggregation_data_forwarding_info=self._forwarding_info(),
        aggregation_info=(
            task_assignments_pb2.TaskAssignment.AggregationInfo()
        ),
        session_id=session_id,
        aggregation_id=task.aggregation_session_id,
        authorization_token=authorization_token,
        task_name=task.task_name,
        init_checkpoint=task.init_checkpoint,
        plan=task.plan,
        federated_select_uri_info=(
            task_assignments_pb2.FederatedSelectUriInfo(
                uri_template=task.federated_select_uri_template
            )
        ),
    )

  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.TaskAssignments',
      method='StartTaskAssignment')
  def start_task_assignment(
      self, request: task_assignments_pb2.StartTaskAssignmentRequest
  ) -> operations_pb2.Operation:
    """Handles a StartTaskAssignment request.

    Args:
      request: The client's request.

    Returns:
      An already-completed Operation whose response is a
      StartTaskAssignmentResponse holding either a task assignment for the
      task at the front of the single-assignment queue, or rejection info when
      that queue is empty.

    Raises:
      http_actions.HttpError: With NOT_FOUND if the request names a population
        other than the one this service was created for.
    """
    if request.population_name != self._population_name:
      raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND)

    # NOTE: A production implementation should consider whether the current task
    # supports `request.client_version` before assigning the client. Given that
    # all clients may not be eligible for all tasks, consider more sophisticated
    # assignment than a FIFO queue.
    task = self._current_task
    if task:
      logging.debug('[%s] StartTaskAssignment: assigned %s', request.session_id,
                    task.task_name)
      response = task_assignments_pb2.StartTaskAssignmentResponse(
          task_assignment=self._create_task_assignment(
              task, request.session_id))
    else:
      # NOTE: Instead of immediately rejecting clients, a production
      # implementation may keep around some number of clients to be assigned to
      # queued tasks or even future rounds of the current task (depending on how
      # quickly rounds complete).
      logging.debug('[%s] StartTaskAssignment: rejected', request.session_id)
      response = task_assignments_pb2.StartTaskAssignmentResponse(
          rejection_info=common_pb2.RejectionInfo())

    # If task assignment took significant time, we return a longrunning
    # Operation; since this implementation makes assignment decisions right
    # away, we can return an already-completed operation.
    op = operations_pb2.Operation(name=f'operations/{uuid.uuid4()}', done=True)
    op.metadata.Pack(task_assignments_pb2.StartTaskAssignmentMetadata())
    op.response.Pack(response)
    return op

  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.TaskAssignments',
      method='PerformMultipleTaskAssignments')
  def perform_multiple_task_assignments(
      self, request: task_assignments_pb2.PerformMultipleTaskAssignmentsRequest
  ) -> task_assignments_pb2.PerformMultipleTaskAssignmentsResponse:
    """Handles a PerformMultipleTaskAssignments request.

    Args:
      request: The client's request; `request.task_names` selects which
        multiple-assignment tasks the client should be assigned.

    Returns:
      A response with one TaskAssignment per registered multiple-assignment
      task whose name appears in `request.task_names`, in registration order.
      Requested names with no registered task are silently skipped.

    Raises:
      http_actions.HttpError: With NOT_FOUND if the request names a population
        other than the one this service was created for.
    """
    if request.population_name != self._population_name:
      raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND)

    # Build the lookup set once instead of rescanning the repeated field for
    # every registered task.
    requested_names = set(request.task_names)

    # NOTE: A production implementation should consider whether the task
    # supports `request.client_version` before assigning the client.
    #
    # The lock is held while the assignments are built, so add_task and
    # remove_task calls block until every matching assignment has been created.
    with self._tasks_lock:
      task_assignments = [
          self._create_task_assignment(task, request.session_id)
          for task in self._multiple_assignment_tasks
          if task.task_name in requested_names
      ]

    return task_assignments_pb2.PerformMultipleTaskAssignmentsResponse(
        task_assignments=task_assignments)

  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.TaskAssignments',
      method='ReportTaskResult')
  def report_task_result(
      self, request: task_assignments_pb2.ReportTaskResultRequest
  ) -> task_assignments_pb2.ReportTaskResultResponse:
    """Handles a ReportTaskResult request.

    The reported result is only logged: at DEBUG level when the computation
    status is OK and at WARN level otherwise.

    Args:
      request: The client's request.

    Returns:
      An empty ReportTaskResultResponse.

    Raises:
      http_actions.HttpError: With NOT_FOUND if the request names a population
        other than the one this service was created for.
    """
    if request.population_name != self._population_name:
      raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND)
    logging.log(
        (logging.DEBUG if request.computation_status_code == code_pb2.OK else
         logging.WARN), '[%s] ReportTaskResult: %s (%s)', request.session_id,
        code_pb2.Code.Name(request.computation_status_code),
        text_format.MessageToString(request.client_stats, as_one_line=True))
    return task_assignments_pb2.ReportTaskResultResponse()
231