# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Action handlers for the TaskAssignments service."""

import collections
import dataclasses
import http
import threading
from typing import Callable, Optional
import uuid

from absl import logging

from google.longrunning import operations_pb2
from google.rpc import code_pb2
from google.protobuf import text_format
from fcp.demo import aggregations
from fcp.demo import http_actions
from fcp.protos.federatedcompute import common_pb2
from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
from fcp.protos.federatedcompute import task_assignments_pb2

_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
)


@dataclasses.dataclass(frozen=True)
class _Task:
  """Immutable record of the data handed to a client assigned to a task."""

  # Name reported to the client as `TaskAssignment.task_name`.
  task_name: str
  # Aggregation session the task's clients are pre-authorized against; also
  # the key used by `Service.remove_task`.
  aggregation_session_id: str
  init_checkpoint: common_pb2.Resource
  plan: common_pb2.Resource
  # Sent to the client as `FederatedSelectUriInfo.uri_template`.
  federated_select_uri_template: str


class Service:
  """Implements the TaskAssignments service."""

  def __init__(self, population_name: str,
               forwarding_info: Callable[[], common_pb2.ForwardingInfo],
               aggregations_service: aggregations.Service):
    """Creates a service with no registered tasks.

    Args:
      population_name: The only population this service answers for; requests
        naming any other population are rejected with HTTP 404.
      forwarding_info: Zero-argument callable invoked once per task assignment
        to obtain the `aggregation_data_forwarding_info` sent to the client.
      aggregations_service: Service used to pre-authorize clients for a task's
        aggregation session.
    """
    self._population_name = population_name
    self._forwarding_info = forwarding_info
    self._aggregations_service = aggregations_service
    # Single-assignment tasks are served FIFO: only the head is handed out.
    self._single_assignment_tasks: collections.deque[_Task] = (
        collections.deque()
    )
    self._multiple_assignment_tasks: list[_Task] = []
    # Guards both task containers above.
    self._tasks_lock = threading.Lock()

  def add_task(
      self,
      task_name: str,
      task_assignment_mode: _TaskAssignmentMode,
      aggregation_session_id: str,
      plan: common_pb2.Resource,
      init_checkpoint: common_pb2.Resource,
      federated_select_uri_template: str,
  ):
    """Adds a new task to the service.

    Args:
      task_name: Name reported to clients assigned to the task.
      task_assignment_mode: Selects the container the task is stored in;
        must be TASK_ASSIGNMENT_MODE_SINGLE or TASK_ASSIGNMENT_MODE_MULTIPLE.
      aggregation_session_id: Aggregation session clients are authorized for.
      plan: Plan resource sent to assigned clients.
      init_checkpoint: Initial checkpoint resource sent to assigned clients.
      federated_select_uri_template: URI template sent to assigned clients.

    Raises:
      ValueError: If `task_assignment_mode` is neither SINGLE nor MULTIPLE.
    """
    task = _Task(
        task_name=task_name,
        aggregation_session_id=aggregation_session_id,
        init_checkpoint=init_checkpoint,
        plan=plan,
        federated_select_uri_template=federated_select_uri_template,
    )
    if task_assignment_mode == _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE:
      with self._tasks_lock:
        self._single_assignment_tasks.append(task)
    elif (
        task_assignment_mode
        == _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE
    ):
      with self._tasks_lock:
        self._multiple_assignment_tasks.append(task)
    else:
      raise ValueError(
          f'Unsupported TaskAssignmentMode {task_assignment_mode}.'
      )

  def remove_task(self, aggregation_session_id: str):
    """Removes a task from the service.

    Single-assignment tasks are searched before multiple-assignment tasks, and
    only the first task with a matching session id is removed.

    Args:
      aggregation_session_id: Session id of the task to remove.

    Raises:
      KeyError: If no registered task has the given session id.
    """
    with self._tasks_lock:
      for tasks in (
          self._single_assignment_tasks,
          self._multiple_assignment_tasks,
      ):
        for task in tasks:
          if task.aggregation_session_id == aggregation_session_id:
            # Returning immediately keeps the mutation-during-iteration safe.
            tasks.remove(task)
            return
      raise KeyError(aggregation_session_id)

  @property
  def _current_task(self) -> Optional[_Task]:
    """The head of the single-assignment queue, or None if it is empty."""
    with self._tasks_lock:
      return (
          self._single_assignment_tasks[0]
          if self._single_assignment_tasks
          else None
      )

  def _build_task_assignment(
      self, task: _Task, session_id: str
  ) -> task_assignments_pb2.TaskAssignment:
    """Pre-authorizes one client for `task` and builds its TaskAssignment.

    This helper must not acquire `_tasks_lock`: it is called while that
    non-reentrant lock is held by `perform_multiple_task_assignments`.

    Args:
      task: The task being assigned.
      session_id: The requesting client's session id, echoed in the result.

    Returns:
      A TaskAssignment carrying a freshly issued authorization token.
    """
    # NOTE: If a production implementation of the Aggregations service cannot
    # always pre-authorize clients (e.g., due to rate-limiting incoming
    # clients), this code should either retry the operation or return a
    # non-permanent error to the client (e.g., UNAVAILABLE).
    authorization_token = self._aggregations_service.pre_authorize_clients(
        task.aggregation_session_id, num_tokens=1)[0]
    return task_assignments_pb2.TaskAssignment(
        aggregation_data_forwarding_info=self._forwarding_info(),
        aggregation_info=(
            task_assignments_pb2.TaskAssignment.AggregationInfo()
        ),
        session_id=session_id,
        aggregation_id=task.aggregation_session_id,
        authorization_token=authorization_token,
        task_name=task.task_name,
        init_checkpoint=task.init_checkpoint,
        plan=task.plan,
        federated_select_uri_info=(
            task_assignments_pb2.FederatedSelectUriInfo(
                uri_template=task.federated_select_uri_template
            )
        ),
    )

  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.TaskAssignments',
      method='StartTaskAssignment')
  def start_task_assignment(
      self, request: task_assignments_pb2.StartTaskAssignmentRequest
  ) -> operations_pb2.Operation:
    """Handles a StartTaskAssignment request.

    Args:
      request: The client's request.

    Returns:
      An already-completed Operation whose response holds either a task
      assignment for the head single-assignment task, or rejection info when
      no such task is registered.

    Raises:
      http_actions.HttpError: NOT_FOUND if the request names another
        population.
    """
    if request.population_name != self._population_name:
      raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND)

    # NOTE: A production implementation should consider whether the current task
    # supports `request.client_version` before assigning the client. Given that
    # all clients may not be eligible for all tasks, consider more sophisticated
    # assignment than a FIFO queue.
    task = self._current_task
    if task:
      logging.debug('[%s] StartTaskAssignment: assigned %s', request.session_id,
                    task.task_name)
      response = task_assignments_pb2.StartTaskAssignmentResponse(
          task_assignment=self._build_task_assignment(
              task, request.session_id
          )
      )
    else:
      # NOTE: Instead of immediately rejecting clients, a production
      # implementation may keep around some number of clients to be assigned to
      # queued tasks or even future rounds of the current task (depending on how
      # quickly rounds complete).
      logging.debug('[%s] StartTaskAssignment: rejected', request.session_id)
      response = task_assignments_pb2.StartTaskAssignmentResponse(
          rejection_info=common_pb2.RejectionInfo())

    # If task assignment took significant time, we return a longrunning
    # Operation; since this implementation makes assignment decisions right
    # away, we can return an already-completed operation.
    op = operations_pb2.Operation(name=f'operations/{uuid.uuid4()}', done=True)
    op.metadata.Pack(task_assignments_pb2.StartTaskAssignmentMetadata())
    op.response.Pack(response)
    return op

  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.TaskAssignments',
      method='PerformMultipleTaskAssignments')
  def perform_multiple_task_assignments(
      self, request: task_assignments_pb2.PerformMultipleTaskAssignmentsRequest
  ) -> task_assignments_pb2.PerformMultipleTaskAssignmentsResponse:
    """Handles a PerformMultipleTaskAssignments request.

    Args:
      request: The client's request; `task_names` selects which registered
        multiple-assignment tasks to assign.

    Returns:
      A response with one assignment per registered multiple-assignment task
      whose name was requested, in registration order.

    Raises:
      http_actions.HttpError: NOT_FOUND if the request names another
        population.
    """
    if request.population_name != self._population_name:
      raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND)

    # Build the lookup set once instead of scanning the repeated field for
    # every registered task.
    requested_names = frozenset(request.task_names)

    task_assignments = []
    # The lock stays held across pre-authorization so tasks cannot be added or
    # removed through this service while the response is being assembled.
    with self._tasks_lock:
      for task in self._multiple_assignment_tasks:
        if task.task_name not in requested_names:
          continue

        # NOTE: A production implementation should consider whether the task
        # supports `request.client_version` before assigning the client.
        task_assignments.append(
            self._build_task_assignment(task, request.session_id)
        )

    return task_assignments_pb2.PerformMultipleTaskAssignmentsResponse(
        task_assignments=task_assignments)

  @http_actions.proto_action(
      service='google.internal.federatedcompute.v1.TaskAssignments',
      method='ReportTaskResult')
  def report_task_result(
      self, request: task_assignments_pb2.ReportTaskResultRequest
  ) -> task_assignments_pb2.ReportTaskResultResponse:
    """Handles a ReportTaskResult request.

    The reported status and client stats are only logged: at DEBUG level when
    the computation status is OK, at WARN level otherwise.

    Args:
      request: The client's report.

    Returns:
      An empty ReportTaskResultResponse.

    Raises:
      http_actions.HttpError: NOT_FOUND if the request names another
        population.
    """
    if request.population_name != self._population_name:
      raise http_actions.HttpError(http.HTTPStatus.NOT_FOUND)
    logging.log(
        (logging.DEBUG if request.computation_status_code == code_pb2.OK else
         logging.WARN), '[%s] ReportTaskResult: %s (%s)', request.session_id,
        code_pb2.Code.Name(request.computation_status_code),
        text_format.MessageToString(request.client_stats, as_one_line=True))
    return task_assignments_pb2.ReportTaskResultResponse()