xref: /aosp_15_r20/external/federated-compute/fcp/demo/task_assignments.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2022 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""Action handlers for the TaskAssignments service."""
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard Workerimport collections
17*14675a02SAndroid Build Coastguard Workerimport dataclasses
18*14675a02SAndroid Build Coastguard Workerimport http
19*14675a02SAndroid Build Coastguard Workerimport threading
20*14675a02SAndroid Build Coastguard Workerfrom typing import Callable, Optional
21*14675a02SAndroid Build Coastguard Workerimport uuid
22*14675a02SAndroid Build Coastguard Worker
23*14675a02SAndroid Build Coastguard Workerfrom absl import logging
24*14675a02SAndroid Build Coastguard Worker
25*14675a02SAndroid Build Coastguard Workerfrom google.longrunning import operations_pb2
26*14675a02SAndroid Build Coastguard Workerfrom google.rpc import code_pb2
27*14675a02SAndroid Build Coastguard Workerfrom google.protobuf import text_format
28*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import aggregations
29*14675a02SAndroid Build Coastguard Workerfrom fcp.demo import http_actions
30*14675a02SAndroid Build Coastguard Workerfrom fcp.protos.federatedcompute import common_pb2
31*14675a02SAndroid Build Coastguard Workerfrom fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
32*14675a02SAndroid Build Coastguard Workerfrom fcp.protos.federatedcompute import task_assignments_pb2
33*14675a02SAndroid Build Coastguard Worker
# Short alias for the nested TaskAssignmentMode enum; `Service.add_task` uses it
# to decide which queue a newly added task is placed in.
_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo
    .TaskAssignmentMode)
37*14675a02SAndroid Build Coastguard Worker
38*14675a02SAndroid Build Coastguard Worker
@dataclasses.dataclass(eq=True, frozen=True)
class _Task:
  """Immutable record describing a task that clients can be assigned to."""

  # Copied into TaskAssignment.task_name; also matched against the names listed
  # in a PerformMultipleTaskAssignments request.
  task_name: str
  # Session that clients are pre-authorized for; copied into
  # TaskAssignment.aggregation_id and used as the key by Service.remove_task.
  aggregation_session_id: str
  # Copied into TaskAssignment.init_checkpoint.
  init_checkpoint: common_pb2.Resource
  # Copied into TaskAssignment.plan.
  plan: common_pb2.Resource
  # Copied into TaskAssignment.federated_select_uri_info.uri_template.
  federated_select_uri_template: str
46*14675a02SAndroid Build Coastguard Worker
47*14675a02SAndroid Build Coastguard Worker
class Service:
  """Implements the TaskAssignments service.

  Tasks are registered with `add_task` under one of two assignment modes:

  * TASK_ASSIGNMENT_MODE_SINGLE tasks are kept in a FIFO queue. Each
    StartTaskAssignment request is assigned to the task at the head of that
    queue, or rejected if the queue is empty.
  * TASK_ASSIGNMENT_MODE_MULTIPLE tasks are kept in a list. A
    PerformMultipleTaskAssignments request is assigned to every such task whose
    name appears in the request.

  Requests whose `population_name` differs from the one given at construction
  are rejected with HTTP 404 (NOT_FOUND).
  """

  def __init__(self, population_name: str,
               forwarding_info: Callable[[], common_pb2.ForwardingInfo],
               aggregations_service: aggregations.Service):
    """Initializes the service.

    Args:
      population_name: The only population this service accepts requests for.
      forwarding_info: Called each time a TaskAssignment is built to obtain its
        `aggregation_data_forwarding_info`.
      aggregations_service: Used to pre-authorize clients for the aggregation
        session of the task they are assigned to.
    """
    self._population_name = population_name
    self._forwarding_info = forwarding_info
    self._aggregations_service = aggregations_service
    # Both task containers are guarded by `_tasks_lock`.
    self._single_assignment_tasks: collections.deque[_Task] = (
        collections.deque())
    self._multiple_assignment_tasks: list[_Task] = []
    self._tasks_lock = threading.Lock()

  def add_task(
      self,
      task_name: str,
      task_assignment_mode: _TaskAssignmentMode,
      aggregation_session_id: str,
      plan: common_pb2.Resource,
      init_checkpoint: common_pb2.Resource,
      federated_select_uri_template: str,
  ):
    """Adds a new task to the service.

    Args:
      task_name: The task's name.
      task_assignment_mode: Selects whether the task is handed out via
        StartTaskAssignment (SINGLE) or PerformMultipleTaskAssignments
        (MULTIPLE).
      aggregation_session_id: The aggregation session clients are
        pre-authorized for; also the key accepted by `remove_task`.
      plan: The plan resource sent to assigned clients.
      init_checkpoint: The initial checkpoint resource sent to assigned clients.
      federated_select_uri_template: URI template sent to assigned clients.

    Raises:
      ValueError: If `task_assignment_mode` is neither SINGLE nor MULTIPLE.
    """
    if task_assignment_mode == _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE:
      tasks = self._single_assignment_tasks
    elif (
        task_assignment_mode
        == _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE
    ):
      tasks = self._multiple_assignment_tasks
    else:
      raise ValueError(
          f'Unsupported TaskAssignmentMode {task_assignment_mode}.')

    task = _Task(
        task_name=task_name,
        aggregation_session_id=aggregation_session_id,
        init_checkpoint=init_checkpoint,
        plan=plan,
        federated_select_uri_template=federated_select_uri_template,
    )
    with self._tasks_lock:
      tasks.append(task)

  def remove_task(self, aggregation_session_id: str):
    """Removes the task associated with an aggregation session.

    Single-assignment tasks are searched before multiple-assignment tasks, and
    only the first match is removed.

    Args:
      aggregation_session_id: The session id the task was added with.

    Raises:
      KeyError: If no task has the given `aggregation_session_id`.
    """
    with self._tasks_lock:
      for tasks in (self._single_assignment_tasks,
                    self._multiple_assignment_tasks):
        for task in tasks:
          if task.aggregation_session_id == aggregation_session_id:
            # Mutating the container mid-iteration is safe only because we
            # return immediately afterwards.
            tasks.remove(task)
            return
      raise KeyError(aggregation_session_id)

  @property
  def _current_task(self) -> Optional[_Task]:
    """The head of the single-assignment queue (not dequeued), or None."""
    with self._tasks_lock:
      return (
          self._single_assignment_tasks[0]
          if self._single_assignment_tasks
          else None
      )

  def _create_task_assignment(
      self, task: _Task, session_id: str
  ) -> task_assignments_pb2.TaskAssignment:
    """Pre-authorizes one client for `task` and builds its TaskAssignment.

    Args:
      task: The task the client is being assigned to.
      session_id: The client's session id, echoed back in the assignment.

    Returns:
      A TaskAssignment carrying a freshly issued authorization token.
    """
    # NOTE: If a production implementation of the Aggregations service cannot
    # always pre-authorize clients (e.g., due to rate-limiting incoming
    # clients), this code should either retry the operation or return a
    # non-permanent error to the client (e.g., UNAVAILABLE).
    authorization_token = self._aggregations_service.pre_authorize_clients(
        task.aggregation_session_id, num_tokens=1)[0]
    return task_assignments_pb2.TaskAssignment(
        aggregation_data_forwarding_info=self._forwarding_info(),
        aggregation_info=task_assignments_pb2.TaskAssignment.AggregationInfo(),
        session_id=session_id,
        aggregation_id=task.aggregation_session_id,
        authorization_token=authorization_token,
        task_name=task.task_name,
        init_checkpoint=task.init_checkpoint,
        plan=task.plan,
        federated_select_uri_info=task_assignments_pb2.FederatedSelectUriInfo(
            uri_template=task.federated_select_uri_template
        ),
    )

  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.TaskAssignments',
      method='StartTaskAssignment')
  def start_task_assignment(
      self, request: task_assignments_pb2.StartTaskAssignmentRequest
  ) -> operations_pb2.Operation:
    """Handles a StartTaskAssignment request.

    Assigns the client to the task at the head of the single-assignment queue,
    or rejects it if that queue is empty. The queue itself is not modified.

    Args:
      request: The client's request.

    Returns:
      An already-completed Operation whose `response` is a packed
      StartTaskAssignmentResponse and whose `metadata` is a packed
      StartTaskAssignmentMetadata.

    Raises:
      http_actions.HttpError: NOT_FOUND if `request.population_name` does not
        match this service's population.
    """
    if request.population_name != self._population_name:
      raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND)

    # NOTE: A production implementation should consider whether the current task
    # supports `request.client_version` before assigning the client. Given that
    # all clients may not be eligible for all tasks, consider more sophisticated
    # assignment than a FIFO queue.
    task = self._current_task
    if task is not None:
      logging.debug('[%s] StartTaskAssignment: assigned %s', request.session_id,
                    task.task_name)
      task_assignment = self._create_task_assignment(task, request.session_id)
      response = task_assignments_pb2.StartTaskAssignmentResponse(
          task_assignment=task_assignment)
    else:
      # NOTE: Instead of immediately rejecting clients, a production
      # implementation may keep around some number of clients to be assigned to
      # queued tasks or even future rounds of the current task (depending on how
      # quickly rounds complete).
      logging.debug('[%s] StartTaskAssignment: rejected', request.session_id)
      response = task_assignments_pb2.StartTaskAssignmentResponse(
          rejection_info=common_pb2.RejectionInfo())

    # If task assignment took significant time, we return a longrunning
    # Operation; since this implementation makes assignment decisions right
    # away, we can return an already-completed operation.
    op = operations_pb2.Operation(name=f'operations/{uuid.uuid4()}', done=True)
    op.metadata.Pack(task_assignments_pb2.StartTaskAssignmentMetadata())
    op.response.Pack(response)
    return op

  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.TaskAssignments',
      method='PerformMultipleTaskAssignments')
  def perform_multiple_task_assignments(
      self, request: task_assignments_pb2.PerformMultipleTaskAssignmentsRequest
  ) -> task_assignments_pb2.PerformMultipleTaskAssignmentsResponse:
    """Handles a PerformMultipleTaskAssignments request.

    Assigns the client to every multiple-assignment task whose name appears in
    `request.task_names`, in the order the tasks were added. Requested names
    with no matching task are ignored.

    Args:
      request: The client's request.

    Returns:
      A response listing one TaskAssignment per matched task (possibly none).

    Raises:
      http_actions.HttpError: NOT_FOUND if `request.population_name` does not
        match this service's population.
    """
    if request.population_name != self._population_name:
      raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND)

    # Build the lookup set once rather than scanning the repeated field per task.
    requested_task_names = set(request.task_names)
    with self._tasks_lock:
      # NOTE: A production implementation should consider whether the task
      # supports `request.client_version` before assigning the client.
      task_assignments = [
          self._create_task_assignment(task, request.session_id)
          for task in self._multiple_assignment_tasks
          if task.task_name in requested_task_names
      ]

    return task_assignments_pb2.PerformMultipleTaskAssignmentsResponse(
        task_assignments=task_assignments)

  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.TaskAssignments',
      method='ReportTaskResult')
  def report_task_result(
      self, request: task_assignments_pb2.ReportTaskResultRequest
  ) -> task_assignments_pb2.ReportTaskResultResponse:
    """Handles a ReportTaskResult request by logging the reported outcome.

    Args:
      request: The client's request.

    Returns:
      An empty ReportTaskResultResponse.

    Raises:
      http_actions.HttpError: NOT_FOUND if `request.population_name` does not
        match this service's population.
    """
    if request.population_name != self._population_name:
      raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND)
    try:
      status_name = code_pb2.Code.Name(request.computation_status_code)
    except ValueError:
      # The status code comes from the client and may not be a known
      # google.rpc.Code; don't fail the request just because it can't be named.
      status_name = str(request.computation_status_code)
    logging.log(
        (logging.DEBUG if request.computation_status_code == code_pb2.OK else
         logging.WARNING), '[%s] ReportTaskResult: %s (%s)', request.session_id,
        status_name,
        text_format.MessageToString(request.client_stats, as_one_line=True))
    return task_assignments_pb2.ReportTaskResultResponse()
231