# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A class to specify on-device dataset inputs."""

from collections.abc import Callable
from typing import Any, Optional, Union

import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import type_checks
from fcp.protos import plan_pb2


class DataSpec:
  """A specification of a single dataset input."""

  __slots__ = (
      '_example_selector_proto',
      '_preprocessing_fn',
      '_preprocessing_comp',
      '_fingerprint',
  )

  def __init__(
      self,
      example_selector_proto: plan_pb2.ExampleSelector,
      preprocessing_fn: Optional[
          Callable[[tf.data.Dataset], tf.data.Dataset]
      ] = None,
  ):
    """Constructs a specification of a dataset input.

    Args:
      example_selector_proto: An instance of the `plan_pb2.ExampleSelector`
        proto.
      preprocessing_fn: A callable that accepts as an argument the raw input
        `tf.data.Dataset` with `string`-serialized items, performs any desired
        preprocessing such as deserialization, filtering, batching, and
        formatting, and returns the transformed `tf.data.Dataset` as a result.
        If `preprocessing_fn` is `None`, any client data preprocessing is
        expected to have already been incorporated into the `tff.Computation`
        that this `DataSpec` is associated with.

    Raises:
      TypeError: If the types of the arguments are invalid.
    """
    type_checks.check_type(
        example_selector_proto,
        plan_pb2.ExampleSelector,
        name='example_selector_proto',
    )
    if preprocessing_fn is not None:
      type_checks.check_callable(preprocessing_fn, name='preprocessing_fn')
    self._example_selector_proto = example_selector_proto
    self._preprocessing_fn = preprocessing_fn
    # Built lazily on first access of self.preprocessing_comp, as we can't
    # call tff.tf_computation in __init__.
    self._preprocessing_comp = None

  @property
  def example_selector_proto(self) -> plan_pb2.ExampleSelector:
    return self._example_selector_proto

  @property
  def preprocessing_fn(
      self,
  ) -> Optional[Callable[[tf.data.Dataset], tf.data.Dataset]]:
    return self._preprocessing_fn

  @property
  def preprocessing_comp(self) -> tff.Computation:
    """Returns the preprocessing computation for the input dataset."""
    if self._preprocessing_comp is None:
      if self.preprocessing_fn is None:
        raise ValueError(
            "DataSpec's preprocessing_fn is None so a "
            'preprocessing tff.Computation cannot be generated.'
        )
      self._preprocessing_comp = tff.tf_computation(
          self.preprocessing_fn, tff.SequenceType(tf.string)
      )
    return self._preprocessing_comp

  @property
  def type_signature(self) -> tff.Type:
    """Returns the type signature of the result of the preprocessing_comp.

    Effectively the type or 'spec' of the parsed example from the example store
    pointed at by `example_selector_proto`.
    """
    return self.preprocessing_comp.type_signature.result


def is_data_spec_or_structure(x: Any) -> bool:
  """Returns True iff `x` is either a `DataSpec` or a nested structure of it."""
  if x is None:
    return False
  if isinstance(x, DataSpec):
    return True
  try:
    x = tff.structure.from_container(x)
    return all(
        is_data_spec_or_structure(y) for _, y in tff.structure.to_elements(x)
    )
  except TypeError:
    return False


def check_data_spec_or_structure(x: Any, name: str):
  """Raises a `TypeError` iff `x` is not a `DataSpec` or a nested structure of it."""
  if not is_data_spec_or_structure(x):
    raise TypeError(
        f'Expected `{name}` to be a `DataSpec` or a nested '
        f'structure of it, found {str(x)}.'
    )


NestedDataSpec = Union[DataSpec, dict[str, 'NestedDataSpec']]


def generate_example_selector_bytes_list(ds: NestedDataSpec) -> list[bytes]:
  """Returns an ordered list of the bytes of each DataSpec's example selector.

  The order matches the element order given by `tff.structure.to_elements()`,
  applied recursively, so nested structures are flattened depth-first.

  Args:
    ds: A `NestedDataSpec`.
  """
  if isinstance(ds, DataSpec):
    return [ds.example_selector_proto.SerializeToString()]
  else:
    ds = tff.structure.from_container(ds)
    assert isinstance(ds, tff.structure.Struct)
    data_spec_elements = tff.structure.to_elements(ds)
    selector_bytes_list = []
    for _, element in data_spec_elements:
      selector_bytes_list.extend(generate_example_selector_bytes_list(element))
    return selector_bytes_list
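

# Illustrative sketch (hypothetical, not part of the original API): the
# depth-first flattening done by `generate_example_selector_bytes_list`,
# re-expressed in plain Python so it runs without TensorFlow, TFF, or FCP.
# Assumptions: `_LeafSpecSketch` stands in for `DataSpec`, its `selector`
# bytes stand in for `example_selector_proto.SerializeToString()`, and dict
# children are visited in dict iteration order. The real function delegates
# ordering to `tff.structure.from_container` / `tff.structure.to_elements`,
# which may order plain-dict keys differently.
class _LeafSpecSketch:
  """Hypothetical stand-in for `DataSpec` holding pre-serialized bytes."""

  __slots__ = ('selector',)

  def __init__(self, selector: bytes):
    self.selector = selector


def _flatten_selectors_sketch(spec) -> list[bytes]:
  """Hypothetical helper: collects leaf selector bytes depth-first."""
  if isinstance(spec, _LeafSpecSketch):
    return [spec.selector]
  if not isinstance(spec, dict):
    raise TypeError(f'Expected a leaf or a dict, found {spec!r}.')
  selectors = []
  for _, child in spec.items():
    # Recurse so a nested dict contributes its leaves at its own position.
    selectors.extend(_flatten_selectors_sketch(child))
  return selectors


# Usage: {'a': A, 'b': {'c': C, 'd': D}} flattens to the selectors of
# A, C, D in that order, mirroring how a nested `NestedDataSpec` is
# flattened into one list of example selector bytes.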