xref: /aosp_15_r20/external/federated-compute/fcp/demo/federated_data_source.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""TFF FederatedDataSource for the demo Federated Computation platform."""
15
16import dataclasses
17import functools
18import re
19from typing import Optional, Union
20
21import tensorflow as tf
22import tensorflow_federated as tff
23
24from fcp.protos import plan_pb2
25from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2
26
# Pattern for population names: one or more runs of word characters (`\w+`),
# separated by single '/' characters. FederatedDataSource.__init__ applies it
# with `fullmatch`, so the entire name must conform.
POPULATION_NAME_REGEX = re.compile(r'\w+(/\w+)*')

# Either a single `plan_pb2.ExampleSelector` or an arbitrarily nested dict,
# keyed by string, whose leaves are of this same alias type.
_NestedExampleSelector = Union[plan_pb2.ExampleSelector,
                               dict[str, '_NestedExampleSelector']]
# Short alias for the deeply nested TaskAssignmentMode proto enum.
_TaskAssignmentMode = (
    eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode
)
34
35
@dataclasses.dataclass
class DataSelectionConfig:
  """Plain record describing one data selection request.

  Instances are built by `_FederatedDataSourceIterator.select`, which copies the
  first three fields from its `FederatedDataSource` and fills `num_clients`
  from the caller's argument.
  """
  # Name of the population the data is selected from.
  population_name: str
  # A single ExampleSelector or a nested dict of them.
  example_selector: _NestedExampleSelector
  # The TaskAssignmentMode configured on the data source.
  task_assignment_mode: _TaskAssignmentMode
  # Number of clients requested; `select` only passes positive values.
  num_clients: int
42
43
class FederatedDataSource(tff.program.FederatedDataSource):
  """Data source describing client data on the demo platform.

  An instance pairs a population of client devices with the selector(s) naming
  the on-device examples over which computations should be invoked.
  """

  _FEDERATED_TYPE = tff.FederatedType(tff.SequenceType(tf.string), tff.CLIENTS)

  def __init__(
      self,
      population_name: str,
      example_selector: _NestedExampleSelector,
      task_assignment_mode: _TaskAssignmentMode = (
          _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE
      ),
  ):
    """Initializes the data source after validating the population name.

    Args:
      population_name: Name of the population on which computations run. It
        must fully match `POPULATION_NAME_REGEX`.
      example_selector: Either one `plan_pb2.ExampleSelector` or a nested dict
        of ExampleSelectors identifying the CLIENTS-placed data to use.
      task_assignment_mode: The TaskAssignmentMode for this computation.

    Raises:
      ValueError: If `population_name` does not match `POPULATION_NAME_REGEX`.
    """
    name_match = POPULATION_NAME_REGEX.fullmatch(population_name)
    if name_match is None:
      raise ValueError(
          f'population_name must match "{POPULATION_NAME_REGEX.pattern}".')
    self._task_assignment_mode = task_assignment_mode
    self._example_selector = example_selector
    self._population_name = population_name

  @property
  def population_name(self) -> str:
    """Name of the population that supplies the examples."""
    return self._population_name

  @property
  def example_selector(self) -> _NestedExampleSelector:
    """The (possibly nested) ExampleSelector given at construction."""
    return self._example_selector

  @property
  def task_assignment_mode(self) -> _TaskAssignmentMode:
    """The TaskAssignmentMode given at construction."""
    return self._task_assignment_mode

  @functools.cached_property
  def federated_type(self) -> tff.FederatedType:
    """The CLIENTS-placed type mirroring the example selector's structure.

    Every dict in the selector becomes a `tff.StructType` with one element per
    key; every other value becomes a sequence of strings.
    """

    def to_member_type(selector):
      if not isinstance(selector, dict):
        # ExternalDataset always yields a sequence of tf.strings, which should
        # be serialized `tf.train.Example` protos.
        return tff.SequenceType(tf.string)
      elements = [
          (name, to_member_type(child)) for name, child in selector.items()
      ]
      return tff.StructType(elements)

    member_type = to_member_type(self._example_selector)
    return tff.FederatedType(member_type, tff.CLIENTS)

  @functools.cached_property
  def capabilities(self) -> list[tff.program.Capability]:
    """The capabilities of this data source: only SUPPORTS_REUSE."""
    supported = [tff.program.Capability.SUPPORTS_REUSE]
    return supported

  def iterator(self) -> tff.program.FederatedDataSourceIterator:
    """Returns a new iterator backed by this data source."""
    return _FederatedDataSourceIterator(self)
112
113
class _FederatedDataSourceIterator(tff.program.FederatedDataSourceIterator):
  """Iterator over a demo-platform `FederatedDataSource`.

  It holds a reference to its data source and turns each `select` call into a
  `DataSelectionConfig`. Serialization is not supported.
  """

  def __init__(self, data_source: FederatedDataSource):
    """Stores the data source whose settings `select` will report."""
    self._data_source = data_source

  @classmethod
  def from_bytes(cls, data: bytes) -> '_FederatedDataSourceIterator':
    """Not supported; always raises NotImplementedError."""
    raise NotImplementedError

  def to_bytes(self) -> bytes:
    """Not supported; always raises NotImplementedError."""
    raise NotImplementedError

  @property
  def federated_type(self):
    """The federated type reported by the underlying data source."""
    source = self._data_source
    return source.federated_type

  def select(self, num_clients: Optional[int] = None) -> DataSelectionConfig:
    """Describes a selection of `num_clients` clients from the data source.

    Args:
      num_clients: How many clients to select; must be a positive number.

    Returns:
      A `DataSelectionConfig` combining the data source's population name,
      example selector and task assignment mode with `num_clients`.

    Raises:
      ValueError: If `num_clients` is None or not positive.
    """
    if num_clients is None or num_clients <= 0:
      raise ValueError('num_clients must be positive.')
    source = self._data_source
    return DataSelectionConfig(
        population_name=source.population_name,
        example_selector=source.example_selector,
        task_assignment_mode=source.task_assignment_mode,
        num_clients=num_clients,
    )
142