# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFF FederatedDataSource for the demo Federated Computation platform.

Provides `FederatedDataSource`, which names a population and the CLIENTS-placed
data to run over, plus the iterator whose `select` packages that description
(and a client count) into a `DataSelectionConfig`.
"""

import dataclasses
import functools
import re
from typing import Optional, Union

import tensorflow as tf
import tensorflow_federated as tff

from fcp.protos import plan_pb2
from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2

# Population names accepted by `FederatedDataSource`: word-character segments
# joined by single '/' characters, e.g. 'foo' or 'foo/bar_1'.
POPULATION_NAME_REGEX = re.compile(r'\w+(/\w+)*')

# Either a single ExampleSelector or a str-keyed dict whose values are
# themselves (recursively) nested example selectors.
_NestedExampleSelector = Union[
    plan_pb2.ExampleSelector, dict[str, '_NestedExampleSelector']
]
# Short alias for the `PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode`
# proto enum.
_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
)


@dataclasses.dataclass
class DataSelectionConfig:
  """Selection parameters produced by `_FederatedDataSourceIterator.select`.

  Attributes:
    population_name: Population name copied from the data source.
    example_selector: The (possibly nested) example selector(s) copied from the
      data source.
    task_assignment_mode: TaskAssignmentMode copied from the data source.
    num_clients: Requested number of clients; `select` only produces positive
      values.
  """

  population_name: str
  example_selector: _NestedExampleSelector
  task_assignment_mode: _TaskAssignmentMode
  num_clients: int


class FederatedDataSource(tff.program.FederatedDataSource):
  """`tff.program.FederatedDataSource` implementation for the demo platform.

  Describes a population of client devices together with the on-device data,
  chosen via `plan_pb2.ExampleSelector`s, that computations are invoked over.
  """

  # NOTE(review): nothing in this module reads this attribute; the effective
  # type is computed by `federated_type` below. Kept for compatibility.
  _FEDERATED_TYPE = tff.FederatedType(tff.SequenceType(tf.string), tff.CLIENTS)

  def __init__(
      self,
      population_name: str,
      example_selector: _NestedExampleSelector,
      task_assignment_mode: _TaskAssignmentMode = (
          _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE
      ),
  ):
    """Creates a data source for one population.

    Args:
      population_name: Population to run computations on; must fully match
        `POPULATION_NAME_REGEX`.
      example_selector: A `plan_pb2.ExampleSelector`, or an arbitrarily nested
        str-keyed dict of them, describing the CLIENTS-placed data.
      task_assignment_mode: TaskAssignmentMode for the computation. Defaults to
        `TASK_ASSIGNMENT_MODE_SINGLE`.

    Raises:
      ValueError: If `population_name` does not fully match
        `POPULATION_NAME_REGEX`.
    """
    # `fullmatch` returns None on failure; match objects are always truthy.
    if POPULATION_NAME_REGEX.fullmatch(population_name) is None:
      raise ValueError(
          f'population_name must match "{POPULATION_NAME_REGEX.pattern}".')
    self._population_name = population_name
    self._example_selector = example_selector
    self._task_assignment_mode = task_assignment_mode

  @property
  def population_name(self) -> str:
    """Population name supplied at construction."""
    return self._population_name

  @property
  def example_selector(self) -> _NestedExampleSelector:
    """(Possibly nested) example selector(s) supplied at construction."""
    return self._example_selector

  @property
  def task_assignment_mode(self) -> _TaskAssignmentMode:
    """TaskAssignmentMode supplied at construction."""
    return self._task_assignment_mode

  @functools.cached_property
  def federated_type(self) -> tff.FederatedType:
    """CLIENTS-placed type whose member mirrors `example_selector`'s structure.

    Every dict level becomes a `tff.StructType` with the same keys in the same
    iteration order; every non-dict leaf becomes `tff.SequenceType(tf.string)`.
    Computed on first access and cached on the instance.
    """

    def to_member_type(node):
      if not isinstance(node, dict):
        # ExternalDataset always returns a sequence of tf.strings, which should
        # be serialized `tf.train.Example` protos.
        return tff.SequenceType(tf.string)
      named_members = [
          (key, to_member_type(child)) for key, child in node.items()
      ]
      return tff.StructType(named_members)

    member_type = to_member_type(self._example_selector)
    return tff.FederatedType(member_type, tff.CLIENTS)

  @functools.cached_property
  def capabilities(self) -> list[tff.program.Capability]:
    """Capabilities advertised by this data source: only `SUPPORTS_REUSE`."""
    supported = [tff.program.Capability.SUPPORTS_REUSE]
    return supported

  def iterator(self) -> tff.program.FederatedDataSourceIterator:
    """Returns a fresh iterator bound to this data source."""
    return _FederatedDataSourceIterator(self)


class _FederatedDataSourceIterator(tff.program.FederatedDataSourceIterator):
  """Demo-platform iterator that turns selections into `DataSelectionConfig`s."""

  def __init__(self, data_source: FederatedDataSource):
    # Every selection reuses this source's population, selector and mode.
    self._data_source = data_source

  @classmethod
  def from_bytes(cls, data: bytes) -> '_FederatedDataSourceIterator':
    """Deserialization is not supported.

    Raises:
      NotImplementedError: Always.
    """
    raise NotImplementedError

  def to_bytes(self) -> bytes:
    """Serialization is not supported.

    Raises:
      NotImplementedError: Always.
    """
    raise NotImplementedError

  @property
  def federated_type(self):
    """Delegates to the underlying data source's `federated_type`."""
    return self._data_source.federated_type

  def select(self, num_clients: Optional[int] = None) -> DataSelectionConfig:
    """Builds a `DataSelectionConfig` for `num_clients` clients.

    Args:
      num_clients: Number of clients to select. Although the parameter
        defaults to None, None is rejected; a positive value is required.

    Returns:
      A `DataSelectionConfig` combining the data source's population name,
      example selector and task assignment mode with `num_clients`.

    Raises:
      ValueError: If `num_clients` is None or not positive.
    """
    # Check None first: `None <= 0` would raise TypeError instead.
    if num_clients is None or num_clients <= 0:
      raise ValueError('num_clients must be positive.')
    source = self._data_source
    return DataSelectionConfig(
        population_name=source.population_name,
        example_selector=source.example_selector,
        task_assignment_mode=source.task_assignment_mode,
        num_clients=num_clients,
    )