# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action handlers for the TaskAssignments service."""

import collections
import dataclasses
import http
import threading
from typing import Callable, Optional
import uuid

from absl import logging

from google.longrunning import operations_pb2
from google.rpc import code_pb2
from google.protobuf import text_format
from fcp.demo import aggregations
from fcp.demo import http_actions
from fcp.protos.federatedcompute import common_pb2
from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
from fcp.protos.federatedcompute import task_assignments_pb2

_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
)


def _status_code_name(code: int) -> str:
  """Returns the `google.rpc.Code` name for `code`, tolerating unknown values.

  The status code is supplied by the client, so it is not guaranteed to be a
  value defined by `google.rpc.Code`. `Code.Name` raises `ValueError` for such
  values; falling back to a descriptive string keeps a malformed status code
  from failing the whole ReportTaskResult request.

  Args:
    code: The numeric status code to name.

  Returns:
    The enum value name (e.g. 'OK'), or 'UNKNOWN(<code>)' if `code` is not a
    defined `google.rpc.Code` value.
  """
  try:
    return code_pb2.Code.Name(code)
  except ValueError:
    return f'UNKNOWN({code})'


@dataclasses.dataclass(frozen=True)
class _Task:
  """An immutable description of a task that clients can be assigned to."""

  task_name: str
  # Passed to the Aggregations service when pre-authorizing clients, reported
  # to clients as `aggregation_id`, and used as the key for
  # `Service.remove_task`.
  aggregation_session_id: str
  init_checkpoint: common_pb2.Resource
  plan: common_pb2.Resource
  # Sent to clients as `FederatedSelectUriInfo.uri_template`.
  federated_select_uri_template: str


class Service:
  """Implements the TaskAssignments service.

  Tasks are kept in one of two collections depending on their
  TaskAssignmentMode:

  * single-assignment tasks form a FIFO queue; every StartTaskAssignment
    request is assigned to the task at the head of that queue;
  * multiple-assignment tasks form a list; each PerformMultipleTaskAssignments
    request is assigned to every listed task whose name it requests.

  Both collections are guarded by `_tasks_lock`.
  """

  def __init__(self, population_name: str,
               forwarding_info: Callable[[], common_pb2.ForwardingInfo],
               aggregations_service: aggregations.Service):
    """Initializes the service.

    Args:
      population_name: The only population this service handles; requests
        naming any other population are rejected with HTTP 404.
      forwarding_info: Called once per task assignment; its result is returned
        to the client as `aggregation_data_forwarding_info`.
      aggregations_service: Used to pre-authorize each assigned client for the
        task's aggregation session.
    """
    self._population_name = population_name
    self._forwarding_info = forwarding_info
    self._aggregations_service = aggregations_service
    self._single_assignment_tasks: collections.deque[_Task] = (
        collections.deque())
    self._multiple_assignment_tasks: list[_Task] = []
    self._tasks_lock = threading.Lock()

  def add_task(
      self,
      task_name: str,
      task_assignment_mode: _TaskAssignmentMode,
      aggregation_session_id: str,
      plan: common_pb2.Resource,
      init_checkpoint: common_pb2.Resource,
      federated_select_uri_template: str,
  ):
    """Adds a new task to the service.

    Args:
      task_name: The task's name, matched against the names requested in
        PerformMultipleTaskAssignments.
      task_assignment_mode: Selects whether the task is appended to the
        single-assignment queue or the multiple-assignment list.
      aggregation_session_id: The Aggregations session for this task; also the
        key used by `remove_task`.
      plan: The plan resource sent to assigned clients.
      init_checkpoint: The initial checkpoint resource sent to assigned
        clients.
      federated_select_uri_template: Sent to assigned clients as the
        federated select URI template.

    Raises:
      ValueError: If `task_assignment_mode` is neither SINGLE nor MULTIPLE.
    """
    task = _Task(
        task_name=task_name,
        aggregation_session_id=aggregation_session_id,
        init_checkpoint=init_checkpoint,
        plan=plan,
        federated_select_uri_template=federated_select_uri_template,
    )
    if task_assignment_mode == _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE:
      tasks = self._single_assignment_tasks
    elif (
        task_assignment_mode
        == _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE
    ):
      tasks = self._multiple_assignment_tasks
    else:
      raise ValueError(
          f'Unsupported TaskAssignmentMode {task_assignment_mode}.')
    with self._tasks_lock:
      tasks.append(task)

  def remove_task(self, aggregation_session_id: str):
    """Removes a task from the service.

    Single-assignment tasks are searched before multiple-assignment tasks, and
    only the first task with a matching session id is removed.

    Args:
      aggregation_session_id: The session id the task was added with.

    Raises:
      KeyError: If no task has that aggregation session id.
    """
    with self._tasks_lock:
      for tasks in (self._single_assignment_tasks,
                    self._multiple_assignment_tasks):
        for task in tasks:
          if task.aggregation_session_id == aggregation_session_id:
            # Returning immediately means the collection is never iterated
            # again after being mutated.
            tasks.remove(task)
            return
    raise KeyError(aggregation_session_id)

  @property
  def _current_task(self) -> Optional[_Task]:
    """The head of the single-assignment queue (not removed), or None."""
    with self._tasks_lock:
      return (
          self._single_assignment_tasks[0]
          if self._single_assignment_tasks
          else None
      )

  def _build_task_assignment(
      self, task: _Task, session_id: str
  ) -> task_assignments_pb2.TaskAssignment:
    """Pre-authorizes one client for `task` and builds its TaskAssignment.

    This may be called while `_tasks_lock` is held, so it must not acquire
    that (non-reentrant) lock itself.

    Args:
      task: The task the client is being assigned to.
      session_id: The client's session id, echoed back in the assignment.

    Returns:
      The TaskAssignment to send to the client.
    """
    # Exactly one token is requested, so the first element is the client's.
    authorization_token = self._aggregations_service.pre_authorize_clients(
        task.aggregation_session_id, num_tokens=1)[0]
    return task_assignments_pb2.TaskAssignment(
        aggregation_data_forwarding_info=self._forwarding_info(),
        aggregation_info=(
            task_assignments_pb2.TaskAssignment.AggregationInfo()
        ),
        session_id=session_id,
        aggregation_id=task.aggregation_session_id,
        authorization_token=authorization_token,
        task_name=task.task_name,
        init_checkpoint=task.init_checkpoint,
        plan=task.plan,
        federated_select_uri_info=(
            task_assignments_pb2.FederatedSelectUriInfo(
                uri_template=task.federated_select_uri_template
            )
        ),
    )

  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.TaskAssignments',
      method='StartTaskAssignment')
  def start_task_assignment(
      self, request: task_assignments_pb2.StartTaskAssignmentRequest
  ) -> operations_pb2.Operation:
    """Handles a StartTaskAssignment request.

    Assigns the client to the task at the head of the single-assignment queue,
    or rejects the client when that queue is empty.

    Args:
      request: The client's StartTaskAssignment request.

    Returns:
      An already-completed Operation whose `response` holds a packed
      StartTaskAssignmentResponse.

    Raises:
      http_actions.HttpError: NOT_FOUND if the request names a different
        population.
    """
    if request.population_name != self._population_name:
      raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND)

    # NOTE: A production implementation should consider whether the current task
    # supports `request.client_version` before assigning the client. Given that
    # all clients may not be eligible for all tasks, consider more sophisticated
    # assignment than a FIFO queue.
    task = self._current_task
    if task is not None:
      logging.debug('[%s] StartTaskAssignment: assigned %s', request.session_id,
                    task.task_name)
      # NOTE: If a production implementation of the Aggregations service cannot
      # always pre-authorize clients (e.g., due to rate-limiting incoming
      # clients), this code should either retry the operation or return a
      # non-permanent error to the client (e.g., UNAVAILABLE).
      # NOTE(review): `_tasks_lock` is not held here, so `task` could be
      # removed concurrently before it is pre-authorized.
      response = task_assignments_pb2.StartTaskAssignmentResponse(
          task_assignment=self._build_task_assignment(task, request.session_id)
      )
    else:
      # NOTE: Instead of immediately rejecting clients, a production
      # implementation may keep around some number of clients to be assigned to
      # queued tasks or even future rounds of the current task (depending on how
      # quickly rounds complete).
      logging.debug('[%s] StartTaskAssignment: rejected', request.session_id)
      response = task_assignments_pb2.StartTaskAssignmentResponse(
          rejection_info=common_pb2.RejectionInfo())

    # A longrunning Operation lets slow task assignment complete
    # asynchronously; since this implementation decides immediately, it
    # returns an already-completed operation.
    op = operations_pb2.Operation(name=f'operations/{uuid.uuid4()}', done=True)
    op.metadata.Pack(task_assignments_pb2.StartTaskAssignmentMetadata())
    op.response.Pack(response)
    return op

  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.TaskAssignments',
      method='PerformMultipleTaskAssignments')
  def perform_multiple_task_assignments(
      self, request: task_assignments_pb2.PerformMultipleTaskAssignmentsRequest
  ) -> task_assignments_pb2.PerformMultipleTaskAssignmentsResponse:
    """Handles a PerformMultipleTaskAssignments request.

    Assigns the client to every multiple-assignment task whose name appears in
    `request.task_names`, in the order the tasks were added.

    Args:
      request: The client's PerformMultipleTaskAssignments request.

    Returns:
      A response containing one TaskAssignment per matching task (possibly
      none).

    Raises:
      http_actions.HttpError: NOT_FOUND if the request names a different
        population.
    """
    if request.population_name != self._population_name:
      raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND)

    # Build the lookup set once rather than scanning the repeated field for
    # every task.
    requested_task_names = set(request.task_names)
    task_assignments = []
    with self._tasks_lock:
      for task in self._multiple_assignment_tasks:
        if task.task_name not in requested_task_names:
          continue

        # NOTE: A production implementation should consider whether the task
        # supports `request.client_version` before assigning the client.
        task_assignments.append(
            self._build_task_assignment(task, request.session_id))

    return task_assignments_pb2.PerformMultipleTaskAssignmentsResponse(
        task_assignments=task_assignments)

  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.TaskAssignments',
      method='ReportTaskResult')
  def report_task_result(
      self, request: task_assignments_pb2.ReportTaskResultRequest
  ) -> task_assignments_pb2.ReportTaskResultResponse:
    """Handles a ReportTaskResult request by logging the reported outcome.

    Results with status OK are logged at DEBUG; all others at WARNING.

    Args:
      request: The client's ReportTaskResult request.

    Returns:
      An empty ReportTaskResultResponse.

    Raises:
      http_actions.HttpError: NOT_FOUND if the request names a different
        population.
    """
    if request.population_name != self._population_name:
      raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND)
    logging.log(
        (logging.DEBUG if request.computation_status_code == code_pb2.OK else
         logging.WARNING), '[%s] ReportTaskResult: %s (%s)', request.session_id,
        _status_code_name(request.computation_status_code),
        text_format.MessageToString(request.client_stats, as_one_line=True))
    return task_assignments_pb2.ReportTaskResultResponse()