# Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Action handlers for the TaskAssignments service.""" import collections import dataclasses import http import threading from typing import Callable, Optional import uuid from absl import logging from google.longrunning import operations_pb2 from google.rpc import code_pb2 from google.protobuf import text_format from fcp.demo import aggregations from fcp.demo import http_actions from fcp.protos.federatedcompute import common_pb2 from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2 from fcp.protos.federatedcompute import task_assignments_pb2 _TaskAssignmentMode = ( eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode ) @dataclasses.dataclass(frozen=True) class _Task: task_name: str aggregation_session_id: str init_checkpoint: common_pb2.Resource plan: common_pb2.Resource federated_select_uri_template: str class Service: """Implements the TaskAssignments service.""" def __init__(self, population_name: str, forwarding_info: Callable[[], common_pb2.ForwardingInfo], aggregations_service: aggregations.Service): self._population_name = population_name self._forwarding_info = forwarding_info self._aggregations_service = aggregations_service self._single_assignment_tasks = collections.deque() self._multiple_assignment_tasks: list[_Task] = [] self._tasks_lock = threading.Lock() def add_task( self, task_name: str, task_assignment_mode: _TaskAssignmentMode, aggregation_session_id: str, plan: common_pb2.Resource, init_checkpoint: common_pb2.Resource, federated_select_uri_template: str, ): """Adds a new task to the service.""" task = _Task( task_name=task_name, aggregation_session_id=aggregation_session_id, init_checkpoint=init_checkpoint, plan=plan, federated_select_uri_template=federated_select_uri_template, ) if task_assignment_mode == _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE: with self._tasks_lock: self._single_assignment_tasks.append(task) elif ( task_assignment_mode == _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE ): with self._tasks_lock: self._multiple_assignment_tasks.append(task) else: raise ValueError(f'Unsupport TaskAssignmentMode {task_assignment_mode}.') def remove_task(self, aggregation_session_id: str): """Removes a task from the service.""" with self._tasks_lock: for task in self._single_assignment_tasks: if task.aggregation_session_id == aggregation_session_id: self._single_assignment_tasks.remove(task) return for task in self._multiple_assignment_tasks: if task.aggregation_session_id == aggregation_session_id: self._multiple_assignment_tasks.remove(task) return raise KeyError(aggregation_session_id) @property def _current_task(self) -> Optional[_Task]: with self._tasks_lock: return ( self._single_assignment_tasks[0] if self._single_assignment_tasks else None ) @http_actions.proto_action( service='google.internal.federatedcompute.v1.TaskAssignments', method='StartTaskAssignment') def start_task_assignment( self, request: task_assignments_pb2.StartTaskAssignmentRequest ) -> operations_pb2.Operation: """Handles a StartTaskAssignment request.""" if request.population_name != self._population_name: raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) # NOTE: A production implementation should consider whether the current task # supports `request.client_version` before assigning the client. Given that # all clients may not be eligible for all tasks, consider more sophisticated # assignment than a FIFO queue. task = self._current_task if task: logging.debug('[%s] StartTaskAssignment: assigned %s', request.session_id, task.task_name) # NOTE: If a production implementation of the Aggregations service cannot # always pre-authorize clients (e.g., due to rate-limiting incoming # clients), this code should either retry the operation or return a # non-permanent error to the client (e.g., UNAVAILABLE). authorization_token = self._aggregations_service.pre_authorize_clients( task.aggregation_session_id, num_tokens=1)[0] response = task_assignments_pb2.StartTaskAssignmentResponse( task_assignment=task_assignments_pb2.TaskAssignment( aggregation_data_forwarding_info=self._forwarding_info(), aggregation_info=( task_assignments_pb2.TaskAssignment.AggregationInfo() ), session_id=request.session_id, aggregation_id=task.aggregation_session_id, authorization_token=authorization_token, task_name=task.task_name, init_checkpoint=task.init_checkpoint, plan=task.plan, federated_select_uri_info=( task_assignments_pb2.FederatedSelectUriInfo( uri_template=task.federated_select_uri_template ) ), ) ) else: # NOTE: Instead of immediately rejecting clients, a production # implementation may keep around some number of clients to be assigned to # queued tasks or even future rounds of the current task (depending on how # quickly rounds complete). logging.debug('[%s] StartTaskAssignment: rejected', request.session_id) response = task_assignments_pb2.StartTaskAssignmentResponse( rejection_info=common_pb2.RejectionInfo()) # If task assignment took significant time, we return a longrunning # Operation; since this implementation makes assignment decisions right # away, we can return an already-completed operation. op = operations_pb2.Operation(name=f'operations/{uuid.uuid4()}', done=True) op.metadata.Pack(task_assignments_pb2.StartTaskAssignmentMetadata()) op.response.Pack(response) return op @http_actions.proto_action( service='google.internal.federatedcompute.v1.TaskAssignments', method='PerformMultipleTaskAssignments') def perform_multiple_task_assignments( self, request: task_assignments_pb2.PerformMultipleTaskAssignmentsRequest ) -> task_assignments_pb2.PerformMultipleTaskAssignmentsResponse: """Handles a PerformMultipleTaskAssignments request.""" if request.population_name != self._population_name: raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) task_assignments = [] with self._tasks_lock: for task in self._multiple_assignment_tasks: if task.task_name not in request.task_names: continue # NOTE: A production implementation should consider whether the task # supports `request.client_version` before assigning the client. authorization_token = self._aggregations_service.pre_authorize_clients( task.aggregation_session_id, num_tokens=1)[0] task_assignments.append( task_assignments_pb2.TaskAssignment( aggregation_data_forwarding_info=self._forwarding_info(), aggregation_info=( task_assignments_pb2.TaskAssignment.AggregationInfo() ), session_id=request.session_id, aggregation_id=task.aggregation_session_id, authorization_token=authorization_token, task_name=task.task_name, init_checkpoint=task.init_checkpoint, plan=task.plan, federated_select_uri_info=( task_assignments_pb2.FederatedSelectUriInfo( uri_template=task.federated_select_uri_template ) ), ) ) return task_assignments_pb2.PerformMultipleTaskAssignmentsResponse( task_assignments=task_assignments) @http_actions.proto_action( service='google.internal.federatedcompute.v1.TaskAssignments', method='ReportTaskResult') def report_task_result( self, request: task_assignments_pb2.ReportTaskResultRequest ) -> task_assignments_pb2.ReportTaskResultResponse: """Handles a ReportTaskResult request.""" if request.population_name != self._population_name: raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND) logging.log( (logging.DEBUG if request.computation_status_code == code_pb2.OK else logging.WARN), '[%s] ReportTaskResult: %s (%s)', request.session_id, code_pb2.Code.Name(request.computation_status_code), text_format.MessageToString(request.client_stats, as_one_line=True)) return task_assignments_pb2.ReportTaskResultResponse()