xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/data_spec.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2022 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""A class to specify on-device dataset inputs."""
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard Workerfrom collections.abc import Callable
17*14675a02SAndroid Build Coastguard Workerfrom typing import Any, Optional, Union
18*14675a02SAndroid Build Coastguard Worker
19*14675a02SAndroid Build Coastguard Workerimport tensorflow as tf
20*14675a02SAndroid Build Coastguard Workerimport tensorflow_federated as tff
21*14675a02SAndroid Build Coastguard Worker
22*14675a02SAndroid Build Coastguard Workerfrom fcp.artifact_building import type_checks
23*14675a02SAndroid Build Coastguard Workerfrom fcp.protos import plan_pb2
24*14675a02SAndroid Build Coastguard Worker
25*14675a02SAndroid Build Coastguard Worker
class DataSpec:
  """Specification of one on-device dataset input.

  Pairs an `ExampleSelector` proto, which points at the example store to read
  from, with an optional preprocessing callable. The matching
  `tff.Computation` is built lazily on the first access of
  `preprocessing_comp` and cached on the instance afterwards.
  """

  # NOTE(review): '_fingerprint' is declared but never assigned in this class;
  # kept so that any external code relying on the slot keeps working.
  __slots__ = (
      '_example_selector_proto',
      '_preprocessing_fn',
      '_preprocessing_comp',
      '_fingerprint',
  )

  def __init__(
      self,
      example_selector_proto: plan_pb2.ExampleSelector,
      preprocessing_fn: Optional[
          Callable[[tf.data.Dataset], tf.data.Dataset]
      ] = None,
  ):
    """Builds a `DataSpec`.

    Args:
      example_selector_proto: A `plan_pb2.ExampleSelector` proto instance.
      preprocessing_fn: Optional callable that receives the raw
        `tf.data.Dataset` of `string`-serialized items and returns a
        transformed `tf.data.Dataset` (for example after deserialization,
        filtering, batching and formatting). When None, any client data
        preprocessing is expected to already be part of the
        `tff.Computation` this `DataSpec` is associated with.

    Raises:
      TypeError: If an argument has an invalid type.
    """
    type_checks.check_type(
        example_selector_proto,
        plan_pb2.ExampleSelector,
        name='example_selector_proto',
    )
    has_preprocessing_fn = preprocessing_fn is not None
    if has_preprocessing_fn:
      type_checks.check_callable(preprocessing_fn, name='preprocessing_fn')
    self._example_selector_proto = example_selector_proto
    self._preprocessing_fn = preprocessing_fn
    # Filled in lazily by the `preprocessing_comp` property, because
    # `tff.tf_computation` cannot be invoked from the constructor.
    self._preprocessing_comp = None

  @property
  def example_selector_proto(self) -> plan_pb2.ExampleSelector:
    """The `ExampleSelector` proto given to the constructor."""
    return self._example_selector_proto

  @property
  def preprocessing_fn(
      self,
  ) -> Optional[Callable[[tf.data.Dataset], tf.data.Dataset]]:
    """The preprocessing callable given to the constructor, or None."""
    return self._preprocessing_fn

  @property
  def preprocessing_comp(self) -> tff.Computation:
    """Returns the preprocessing `tff.Computation` for the input dataset.

    On first access, wraps `preprocessing_fn` with `tff.tf_computation` over a
    `tff.SequenceType(tf.string)` parameter and caches the result. Later
    accesses return the cached computation.

    Raises:
      ValueError: If `preprocessing_fn` is None.
    """
    cached_comp = self._preprocessing_comp
    if cached_comp is not None:
      return cached_comp

    fn = self.preprocessing_fn
    if fn is None:
      raise ValueError(
          "DataSpec's preprocessing_fn is None so a "
          'preprocessing tff.Computation cannot be generated.'
      )
    cached_comp = tff.tf_computation(fn, tff.SequenceType(tf.string))
    self._preprocessing_comp = cached_comp
    return cached_comp

  @property
  def type_signature(self) -> tff.Type:
    """Returns the result type of `preprocessing_comp`.

    This is effectively the type ('spec') of the parsed examples read from the
    example store that `example_selector_proto` points at.

    Raises:
      ValueError: Propagated from `preprocessing_comp` when `preprocessing_fn`
        is None.
    """
    comp_type = self.preprocessing_comp.type_signature
    return comp_type.result
103*14675a02SAndroid Build Coastguard Worker
104*14675a02SAndroid Build Coastguard Worker
def is_data_spec_or_structure(x: Any) -> bool:
  """Tells whether `x` is a `DataSpec` or a nested structure of `DataSpec`s.

  Containers are converted with `tff.structure.from_container` and every
  element is checked recursively. Any `TypeError` raised during conversion or
  the recursive check makes the result False.

  Args:
    x: The value to inspect.

  Returns:
    True iff `x` is a `DataSpec` or a structure whose elements all are.
  """
  if isinstance(x, DataSpec):
    return True
  if x is None:
    return False
  try:
    as_struct = tff.structure.from_container(x)
    members = tff.structure.to_elements(as_struct)
    return all(is_data_spec_or_structure(member) for _, member in members)
  except TypeError:
    return False
118*14675a02SAndroid Build Coastguard Worker
119*14675a02SAndroid Build Coastguard Worker
def check_data_spec_or_structure(x: Any, name: str):
  """Validates that `x` is a `DataSpec` or a nested structure of them.

  Args:
    x: The value to validate.
    name: Argument name to mention in the error message.

  Raises:
    TypeError: If `is_data_spec_or_structure(x)` is False.
  """
  if is_data_spec_or_structure(x):
    return
  raise TypeError(
      f'Expected `{name}` to be a `DataSpec` or a nested '
      f'structure of it, found {str(x)}.'
  )
126*14675a02SAndroid Build Coastguard Worker    )
127*14675a02SAndroid Build Coastguard Worker
128*14675a02SAndroid Build Coastguard Worker
# Recursive alias: a single `DataSpec`, or a str-keyed dict of further
# `NestedDataSpec` values.
NestedDataSpec = Union[DataSpec, dict[str, 'NestedDataSpec']]


def generate_example_selector_bytes_list(ds: NestedDataSpec):
  """Flattens `ds` into the serialized bytes of each `ExampleSelector`.

  Each leaf `DataSpec` contributes `example_selector_proto.SerializeToString()`.
  Elements are visited depth-first, in the order that
  `tff.structure.to_elements()` yields them for the converted struct.

  Args:
    ds: A `NestedDataSpec`.

  Returns:
    A list of `bytes`, one entry per leaf `DataSpec`.
  """
  if not isinstance(ds, DataSpec):
    struct = tff.structure.from_container(ds)
    assert isinstance(struct, tff.structure.Struct)
    return [
        selector_bytes
        for _, child in tff.structure.to_elements(struct)
        for selector_bytes in generate_example_selector_bytes_list(child)
    ]
  return [ds.example_selector_proto.SerializeToString()]
150*14675a02SAndroid Build Coastguard Worker    return selector_bytes_list
151