# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for manipulating TensorFlow graph logic."""

from typing import Optional, Union

import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import data_spec
from fcp.artifact_building import tensor_utils
from fcp.artifact_building import type_checks
from fcp.tensorflow import external_dataset
from tensorflow_federated.proto.v0 import computation_pb2

TfValue = Union[tf.Variable, tf.Tensor]
DatasetTensor = tf.Tensor
Argument = Union[TfValue, list[TfValue], DatasetTensor]
Args = Optional[Union[Argument, tuple[Argument, ...]]]

Result = Argument
MaybeSplitOutputs = Union[Result, tuple[Result, ...]]


EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX = 'example_selector'


def generate_example_selector_placeholders(
    type_spec: tff.Type,
    name_prefix: str,
):
  """Generates a list of `tf.compat.v1.placeholder`s, one per type spec leaf.

  The order of the placeholders aligns with the order given by
  `tff.structure.to_elements()` at each level of nesting (depth-first).

  Each placeholder is named by appending to `name_prefix` the index taken at
  each level of the struct on the path to the placeholder's leaf in the
  `tff.Type`, separated by underscores (e.g. `<name_prefix>_1_0`).

  Args:
    type_spec: A type spec to infer the list of placeholders from. This is
      expected to be a `tff.SequenceType` or a `tff.StructType`; if it is a
      `tff.StructType`, it is expected to be a tree of `tff.StructType`s with
      `tff.SequenceType`s at the leaves. It should reflect the TFF type
      signature of the input client data.
    name_prefix: The name prefix to use when naming each placeholder.

  Returns:
    A list of scalar `tf.string` placeholders created with
    `tf.compat.v1.placeholder`.
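
  For illustration, the naming scheme can be sketched in plain Python, assuming
  nested lists stand in for `tff.StructType`s and any other value stands in for
  a `tff.SequenceType` leaf. `placeholder_names` is a hypothetical helper, not
  part of this module. The real function also creates the placeholders, whose
  tensor names gain a `:0` suffix, and TensorFlow may rename a placeholder whose
  name is already taken in the graph.

```python
def placeholder_names(structure, prefix):
  # Hypothetical helper: anything that is not a list/tuple is treated as a
  # sequence leaf, which gets exactly one placeholder named `prefix`.
  if not isinstance(structure, (list, tuple)):
    return [prefix]
  names = []
  for index, element in enumerate(structure):
    # Recurse depth-first, appending this level's index to the prefix.
    names.extend(placeholder_names(element, f'{prefix}_{index}'))
  return names


print(placeholder_names(['seq', ['seq', 'seq']], 'example_selector'))
# ['example_selector_0', 'example_selector_1_0', 'example_selector_1_1']
```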
  """
  type_spec = tff.to_type(type_spec)
  type_checks.check_type(
      type_spec, (tff.SequenceType, tff.StructType), name='type_spec'
  )
  if type_spec.is_sequence():
    # Each client input is a sequence of serialized `tf.Example`s, which is why
    # the leaves of these TFF type signatures are sequences. Each input sequence
    # of `tf.Example`s requires a single `ExampleSelector` that determines which
    # stream of `tf.Example`s is selected from the data store; hence there is
    # only a single placeholder for the `ExampleSelector`.
    return [tf.compat.v1.placeholder(tf.string, shape=[], name=name_prefix)]
  else:
    type_spec.check_struct()
    type_spec_elements = tff.structure.to_elements(type_spec)
    placeholders = []
    for element_index, (_, element_type) in enumerate(type_spec_elements):
      placeholders.extend(
          generate_example_selector_placeholders(
              element_type, f'{name_prefix}_{element_index}'
          )
      )
    return placeholders


def embed_data_logic(
    client_data_type: tff.Type,
    dataspec: Optional[data_spec.NestedDataSpec] = None,
) -> tuple[tf.Tensor, list[MaybeSplitOutputs], list[tf.Tensor]]:
  """Embeds the data logic into the current TensorFlow graph.

  Adds dataset ops built on the custom `ExternalDataset` op to the current
  graph, along with a placeholder for the token that `ExternalDataset`
  consumes. The token placeholder, the dataset output values, and any example
  selector placeholders are returned.

  Args:
    client_data_type: The TFF type signature of the input client data.
    dataspec: If provided, either an instance of `data_spec.DataSpec` or a
      nested structure of these that matches the structure of the first element
      of the input to the client work part of the computation.

  Returns:
    A `tuple` containing the following (in order):
      token_placeholder: A dataset token placeholder tensor.
      data_values: A list of dataset output values.
      example_selector_placeholders: A possibly empty list of placeholders used
        to pass example selector information into the client graph. This list
        is empty iff `dataspec` is supplied.

  Raises:
    ValueError: If a data source does not produce exactly one dataset output.
    ValueError: If the graph already contains a node with the same name as the
      example selector placeholders.
  """
  data_values = []
  # Embeds the token placeholder for the custom ExternalDataset op.
  token_placeholder = tf.compat.v1.placeholder(
      tf.string, shape=[], name='data_token'
  )

  example_selector_placeholders = []
  if dataspec is None:
    example_selector_placeholders = generate_example_selector_placeholders(
        client_data_type, EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX
    )
    # If the first placeholder does not have the expected prefix, then it is due
    # to other variables in the graph, likely created from the input
    # tff.Computation, having the special name. This check ensures that no other
    # variables use this special example selector placeholder name and we can
    # easily extract example selector placeholders in the generated artifact.
    if example_selector_placeholders and (
        not (
            example_selector_placeholders[0].name.startswith(
                f'{EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX}:'
            )
            or example_selector_placeholders[0].name.startswith(
                f'{EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX}_0'
            )
        )
    ):
      raise ValueError(
          'Graph already contains a placeholder with name '
          f'{EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX}. Please '
          'avoid the use of this special name.'
      )
    data_sources = make_data_sources_without_dataspec(client_data_type)
    assert len(example_selector_placeholders) == len(data_sources)
  else:
    data_sources = make_data_sources_with_dataspec(client_data_type, dataspec)

  # Embeds data source computations into the current graph.
  for index, data_comp in enumerate(data_sources):
    data_comp_import_args = [token_placeholder]
    if example_selector_placeholders:
      data_comp_import_args.append(example_selector_placeholders[index])
    ds_values = import_tensorflow(
        'data_{}'.format(index), data_comp, data_comp_import_args
    )  # pytype: disable=wrong-arg-types
    if len(ds_values) != 1:
      raise ValueError(
          'Expected one dataset output from a data source, found {}.'.format(
              str(len(ds_values))
          )
      )
    data_values.extend(ds_values)

  return token_placeholder, data_values, example_selector_placeholders


def import_tensorflow(
    name: str,
    comp: tff.framework.ConcreteComputation,
    args: Args = None,
    split_outputs: bool = False,
    session_token_tensor: Optional[tf.Tensor] = None,
) -> MaybeSplitOutputs:
  """Imports a TensorFlow computation into the current graph.

  Args:
    name: The string name to use as the graph import prefix.
    comp: An instance of `tff.framework.ConcreteComputation` with just the
      `tensorflow` section.
    args: Either a single argument, a tuple of arguments, or None. Each
      argument must be either:
      - a Python `list` containing tensors or variables, or
      - a singleton Python `list` containing a variant tensor that represents a
        dataset input.
    split_outputs: Whether to unpack the result tuple into a Python tuple. If
      `True`, `import_tensorflow` will return a tuple with multiple result
      objects, corresponding to the return elements in the type signature of
      `comp`. Note that the return type signature of `comp` must be a tuple in
      this case. If `False`, `import_tensorflow` will return the entire result
      in a flattened form as a single Python result object. Each Python result
      object, similar to the arguments in `args`, will be either a Python
      `list` of tensors or a singleton Python `list` containing only the
      dataset variant tensor.
    session_token_tensor: A tensor in the current graph containing the "session
      token" of the TensorFlow computation being imported. This is useful for
      passing a session-global identifier into the graph for use with ops like
      `ServeSlices` and `ExternalDataset` that take a token referencing
      session-global state.

  Returns:
    One of:
      - A single result (a Python `list` of tensors) if `split_outputs` is
        `False`.
      - A Python `tuple` of such results, if `split_outputs` is `True`.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  type_checks.check_type(name, str, name='name')
  type_checks.check_type(comp, tff.framework.ConcreteComputation, name='comp')
  type_checks.check_type(split_outputs, bool, name='split_outputs')

  comp_proto = tff.framework.ConcreteComputation.get_proto(comp)
  type_checks.check_type(
      comp_proto, computation_pb2.Computation, name='comp_proto'
  )

  which_comp = comp_proto.WhichOneof('computation')
  if which_comp != 'tensorflow':
    raise TypeError(
        'Expected a TensorFlow computation, found {}.'.format(which_comp)
    )
  if args is None:
    input_map = None
  elif isinstance(args, tuple):
    which_binding = comp_proto.tensorflow.parameter.WhichOneof('binding')
    if which_binding != 'struct':
      raise TypeError(
          'Expected a struct binding with a struct of args, found {}.'.format(
              which_binding
          )
      )
    input_map = {}
    for index, arg in enumerate(args):
      input_map.update(
          create_tensor_map(
              comp_proto.tensorflow.parameter.struct.element[index], arg
          )
      )
  else:
    input_map = create_tensor_map(comp_proto.tensorflow.parameter, args)
  if input_map is not None:
    # Also add remappings for all potential control dependencies in the graph.
    # The `input_map` of `tf.graph_util.import_graph_def` works at the tensor
    # (not graph node) level, so control inputs (`^node_name`) must be remapped
    # separately.
    def control_dep_name(name: str) -> str:
      if name.startswith('^'):
        return name
      node_name = name.split(':', maxsplit=1)[0]
      return f'^{node_name}'

    input_map.update(
        {
            control_dep_name(k): control_dep_name(v.name)
            for k, v in input_map.items()
            if not k.startswith('^')
        }
    )
  input_map = {} if input_map is None else input_map
  if (
      session_token_tensor is not None
      and comp_proto.tensorflow.session_token_tensor_name
  ):
    input_map[comp_proto.tensorflow.session_token_tensor_name] = (
        session_token_tensor
    )
  if split_outputs:
    return_elements = []
    subset_sizes = []
    which_binding = comp_proto.tensorflow.result.WhichOneof('binding')
    if which_binding != 'struct':
      raise TypeError(
          'If `split_outputs` is `True`, the result of the computation we are '
          'importing must be a `struct`; found {}.'.format(which_binding)
      )
    for binding in comp_proto.tensorflow.result.struct.element:
      tensor_names = _list_tensor_names_in_binding(binding)
      return_elements.extend(tensor_names)
      subset_sizes.append(len(tensor_names))
  else:
    return_elements = _list_tensor_names_in_binding(
        comp_proto.tensorflow.result
    )
    subset_sizes = [len(return_elements)]

  graph_def = tensor_utils.import_graph_def_from_any(
      comp_proto.tensorflow.graph_def
  )

  # We will be importing multiple GraphDefs into the server or client graphs.
  # These individual graphs may have identical `shared_name` attributes on
  # variable ops, causing the runtime to reference the same resource, which is
  # undesirable. We must uniquify the names before importing.
  def uniquify_shared_names(
      graph_def: tf.compat.v1.GraphDef, suffix: bytes
  ) -> tf.compat.v1.GraphDef:
    for x in graph_def.node:
      shared_name = x.attr.get('shared_name')
      if shared_name is not None:
        if not shared_name.s:
          # The shared name is empty; use a placeholder value so that the
          # suffixed name does not start with an underscore (not allowed by TF).
          shared_name.s = b'None'
        shared_name.s += b'_' + suffix
    return graph_def

  uniquified_graph_def = uniquify_shared_names(
      graph_def, suffix=name.encode('utf-8')
  )
  if comp_proto.tensorflow.initialize_op:
    uniquified_graph_def = add_control_deps_for_init_op(
        uniquified_graph_def, comp_proto.tensorflow.initialize_op
    )
  import_result = tf.graph_util.import_graph_def(
      uniquified_graph_def,
      input_map=input_map,
      return_elements=return_elements,
      name=name,
  )

  if split_outputs:
    subsets = []
    offset = 0
    for subset_size in subset_sizes:
      next_offset = offset + subset_size
      subsets.append(import_result[offset:next_offset])
      offset = next_offset
    results = tuple(subsets)
  else:
    results = import_result[: subset_sizes[0]]
  return results


def _get_deps_for_graph_node(
    graph_def: tf.compat.v1.GraphDef, node_name: str
) -> set[str]:
  """Returns the set of node names that a node named `node_name` depends on.

  Note that this function does not work for nodes in the function library.

  Args:
    graph_def: The input graph, an instance of `tf.compat.v1.GraphDef`.
    node_name: The node name, a string.

  Returns:
    A `set` containing the names of all nodes in `graph_def` that `node_name`
    depends on, directly or transitively (excluding `node_name` itself).
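
  For illustration, the traversal can be sketched in plain Python on a dict that
  maps each node name to the input names listed on that node. `bare` and
  `deps_for_node` are hypothetical helpers; `bare` is assumed to mirror what
  `tensor_utils.bare_name` does, stripping a leading `^` (control input) and any
  `:N` output index.

```python
def bare(name):
  # Assumed behavior: '^a' -> 'a', 'a:1' -> 'a'.
  return name.lstrip('^').split(':', 1)[0]


def deps_for_node(node_inputs, node_name):
  # Map each node to the bare names of its data and control inputs.
  input_map = {n: {bare(x) for x in ins} for n, ins in node_inputs.items()}
  dependencies = set()
  frontier = {node_name}
  while frontier:
    dependencies.update(frontier)
    # Expand one level and drop visited nodes, so the loop always terminates.
    frontier = set().union(*(input_map[n] for n in frontier)) - dependencies
  return dependencies - {node_name}


graph = {'a': [], 'b': ['a:0'], 'c': ['b', '^a'], 'd': ['c:1']}
print(sorted(deps_for_node(graph, 'd')))  # ['a', 'b', 'c']
```

  Dropping already-visited names from the frontier is what lets the loop
  terminate on graphs that contain cycles.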

  Raises:
    TypeError: If either argument is of the wrong type.
  """
  type_checks.check_type(graph_def, tf.compat.v1.GraphDef, name='graph_def')
  type_checks.check_type(node_name, str, name='node_name')
  input_map = {}
  for node in graph_def.node:
    input_map[node.name] = set(tensor_utils.bare_name(x) for x in node.input)
  dependencies = set()
  initial_singleton = set([node_name])
  nodes_to_process = initial_singleton
  while nodes_to_process:
    dependencies.update(nodes_to_process)
    nodes_to_process = set.union(
        *[input_map[name] for name in nodes_to_process]
    ).difference(dependencies)
  return dependencies.difference(initial_singleton)


def add_control_deps_for_init_op(
    graph_def: tf.compat.v1.GraphDef, init_op: str
) -> tf.compat.v1.GraphDef:
  """Adds control deps on `init_op` to nodes in GraphDef.

  Note that control deps are not added to `init_op` itself or to any of its
  ancestors (which would create a control dep cycle), nor to any nodes in the
  function library of the GraphDef.

  Args:
    graph_def: The input graph, an instance of `tf.compat.v1.GraphDef`.
    init_op: The init op name, a string.

  Returns:
    A new `tf.compat.v1.GraphDef` with the control deps added; `graph_def`
    itself is not modified.

  Raises:
    TypeError: If either argument is of the wrong type.
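
  For illustration, the rewrite can be sketched in plain Python on a dict that
  maps each node name to its list of input names. `bare`, `ancestors`, and
  `add_init_control_deps` are hypothetical helpers, and `bare` is assumed to
  mirror `tensor_utils.bare_name`. The sketch shows why ancestors are skipped:
  the variable that the init op assigns must not depend on the init op.

```python
import copy


def bare(name):
  # Assumed behavior: '^a' -> 'a', 'a:1' -> 'a'.
  return name.lstrip('^').split(':', 1)[0]


def ancestors(node_inputs, node_name):
  # Iterative depth-first search over data and control inputs.
  seen, stack = set(), [node_name]
  while stack:
    for inp in node_inputs[stack.pop()]:
      parent = bare(inp)
      if parent not in seen:
        seen.add(parent)
        stack.append(parent)
  return seen


def add_init_control_deps(node_inputs, init_op):
  init_name = bare(init_op)
  control_dep = '^' + init_name
  skip = ancestors(node_inputs, init_name) | {init_name}
  result = copy.deepcopy(node_inputs)  # Leave the input graph untouched.
  for name, inputs in result.items():
    if name not in skip and control_dep not in inputs:
      inputs.append(control_dep)
  return result


graph = {'v': [], 'assign': ['v'], 'init': ['^assign'], 'read': ['v']}
new_graph = add_init_control_deps(graph, 'init')
print(new_graph['read'])  # ['v', '^init']
print(new_graph['v'])  # []
```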
  """
  type_checks.check_type(graph_def, tf.compat.v1.GraphDef, name='graph_def')
  type_checks.check_type(init_op, str, name='init_op')
  init_op_str = tensor_utils.bare_name(init_op)
  init_op_control_dep = '^{}'.format(init_op_str)
  deps = _get_deps_for_graph_node(graph_def, init_op_str).union(
      set([init_op_str])
  )
  new_graph_def = tf.compat.v1.GraphDef()
  new_graph_def.CopyFrom(graph_def)
  for new_node in new_graph_def.node:
    if new_node.name not in deps:
      node_inputs = new_node.input
      if init_op_control_dep not in node_inputs:
        new_node.input.append(init_op_control_dep)
  return new_graph_def


def create_tensor_map(
    binding: computation_pb2.TensorFlow.Binding,
    arg: list[Union[tf.Tensor, tf.Variable]],
) -> dict[str, tf.Tensor]:
  """Creates a `dict` mapping tensor names in the binding to tensors in `arg`.

  Args:
    binding: An instance of `computation_pb2.TensorFlow.Binding`.
    arg: Either a singleton Python `list` containing a variant tensor, for a
      sequence binding, or otherwise a Python `list` of tensors or resource
      variables, for a tuple binding.

  Returns:
    A Python `dict` mapping each tensor name in `binding` to the corresponding
    element of `arg` (resource variables are read via `read_value()`).

  Raises:
    TypeError: If the argument types are incorrect.
    ValueError: If the arguments are malformed (e.g., a sequence binding that
      does not have exactly one variant tensor).
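
  For intuition, a minimal pure-Python sketch of the mapping follows. It
  assumes the tensor names have already been flattened out of the binding;
  `ToyVariable` and `toy_tensor_map` are hypothetical stand-ins for a resource
  variable and for this function.

  ```python
  class ToyVariable:
    # Hypothetical stand-in for a resource variable exposing `read_value()`.

    def __init__(self, value):
      self._value = value

    def read_value(self):
      return self._value


  def toy_tensor_map(names, args, is_sequence=False):
    if is_sequence:
      # A sequence binding must name exactly one variant tensor.
      if len(names) != 1 or len(args) != 1:
        raise ValueError('Expected exactly one variant tensor.')
      return {names[0]: args[0]}
    # Otherwise, pair names and values positionally, reading variables so the
    # map holds values rather than variable handles.
    return {
        k: v.read_value() if hasattr(v, 'read_value') else v
        for k, v in zip(names, args)
    }
  ```

  For example, `toy_tensor_map(['a:0', 'b:0'], [1, ToyVariable(2)])` returns
  `{'a:0': 1, 'b:0': 2}`. Because `zip` stops at the shorter input, surplus
  names or values are silently dropped rather than reported.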
  """
  type_checks.check_type(
      binding, computation_pb2.TensorFlow.Binding, name='binding'
  )
  type_checks.check_type(arg, list, name='arg')
  tensor_names_in_binding = _list_tensor_names_in_binding(binding)
  which_binding = binding.WhichOneof('binding')
  if which_binding == 'sequence':
    if (len(tensor_names_in_binding) != 1) or (len(arg) != 1):
      raise ValueError(
          'Expected exactly one variant tensor for a sequence binding.'
      )
    variant_tensor_name = tensor_names_in_binding[0]
    arg = arg[0]
    if not tf.is_tensor(arg):
      raise TypeError('Expected a tensor, found {!r}.'.format(type(arg)))
    if arg.dtype != tf.variant:
      raise TypeError('Expected `tf.variant`, found {!r}.'.format(arg.dtype))
    return {variant_tensor_name: arg}
  else:
    return {
        k: v.read_value() if hasattr(v, 'read_value') else v
        for k, v in zip(tensor_names_in_binding, arg)
    }


def _validate_data_comp(data_comp: tff.Computation, type_spec: tff.Type):
  type_checks.check_type(data_comp.type_signature, tff.FunctionType)
  if not type_spec.is_assignable_from(data_comp.type_signature.result):
    type_mismatch_string = tff.types.type_mismatch_error_message(
        type_spec,
        data_comp.type_signature.result,
        tff.types.TypeRelation.ASSIGNABLE,
    )
    raise TypeError(
        'The data source returns data whose type does not match the type of '
        'the request. Details of the mismatch:\n'
        + type_mismatch_string
    )


def make_data_sources_with_dataspec(
    type_spec: tff.Type, ds: data_spec.NestedDataSpec
) -> list[tff.Computation]:
  """Creates computations that feed data using hard-coded example selectors.

  The computations use the custom ExternalDataset op to feed in example data.
  Each computation expects one input:
  -- A token specifying where the data store is on the device.
  Example selectors describing what data to take from the on-device data store
  are hard-coded into the computations.

  Args:
    type_spec: The TFF type signature of the output, which must be either a
      sequence or a (possibly nested) named tuple of sequences.
    ds: Either a single `data_spec.DataSpec`, or a nested structure of these,
      made up of Python containers, that exactly matches the structure of
      `type_spec`.

  Returns:
    A list of `tff.Computation`s, each of which accepts a single `string`-typed
    tensor as input (the token for the ExternalDataset op) and returns a
    sequence as output (whose type matches the corresponding part of
    `type_spec`). The computations appear in the list in depth-first order
    (matching exactly the convention used by the
    `_list_tensor_names_in_binding()` function below).

  Raises:
    TypeError: If the arguments are of the wrong types.
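
  The depth-first pairing of `type_spec` with `ds` can be sketched without TFF.
  The sketch assumes a sequence type is represented by the string `'sequence'`
  and a struct (of types or of specs) by a list of `(name, element)` pairs.
  `flatten_with_spec` is a hypothetical helper that returns the leaf specs in
  the order in which this function returns its computations.

  ```python
  def flatten_with_spec(type_spec, spec):
    if type_spec == 'sequence':
      # A sequence type must be paired with a single (leaf) spec.
      if isinstance(spec, list):
        raise TypeError('Found a nested structure for a sequence type.')
      return [spec]
    if not isinstance(spec, list):
      raise TypeError('Found a single spec for a struct type.')
    type_names = [str(name) for name, _ in type_spec]
    spec_names = [str(name) for name, _ in spec]
    if type_names != spec_names:
      raise TypeError(
          'Element names mismatch: {} vs. {}.'.format(type_names, spec_names)
      )
    leaves = []
    # Recursing element by element yields the leaves in depth-first order.
    for (_, element_type), (_, element_spec) in zip(type_spec, spec):
      leaves.extend(flatten_with_spec(element_type, element_spec))
    return leaves
  ```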
  """
  assert ds
  type_spec = tff.to_type(type_spec)
  type_checks.check_type(
      type_spec, (tff.SequenceType, tff.StructType), name='type_spec'
  )
  if type_spec.is_sequence():
    type_checks.check_type(ds, data_spec.DataSpec)
    assert isinstance(ds, data_spec.DataSpec)
    assert ds.example_selector_proto is not None
    sel_bytes = ds.example_selector_proto.SerializeToString()

    @tff.tf_computation(tf.string)
    def data_comp(token):
      """The data source computation.

      Args:
        token: The token placeholder tensor (`tf.string`).

      Returns:
        An instance of `tf.data.Dataset`.
      """
      if ds.preprocessing_fn is not None:
        processed_ds = ds.preprocessing_fn(
            external_dataset.ExternalDataset(token=token, selector=sel_bytes)
        )
      else:
        processed_ds = external_dataset.ExternalDataset(
            token=token, selector=sel_bytes
        )

      if 'Dataset' not in type(processed_ds).__name__:
        raise TypeError(
            'The preprocessing function returned an unrecognized non-dataset '
            'type {!r}.'.format(type(processed_ds))
        )
      return processed_ds

    _validate_data_comp(data_comp, type_spec)
    return [data_comp]
  else:
    type_spec.check_struct()
    if isinstance(ds, data_spec.DataSpec):
      raise TypeError(
          'Expected nested structure of `DataSpec`s conforming to '
          f'the structure of the type {type_spec}. '
          'Found single `DataSpec` instead.'
      )
    ds = tff.structure.from_container(ds)
    assert isinstance(ds, tff.structure.Struct)
    type_spec_elements = tff.structure.to_elements(type_spec)
    data_spec_elements = tff.structure.to_elements(ds)
    type_spec_element_names = [str(k) for k, _ in type_spec_elements]
    data_spec_element_names = [str(k) for k, _ in data_spec_elements]
    if type_spec_element_names != data_spec_element_names:
      raise TypeError(
          'Type vs. data spec element names mismatch: {} vs. {}.'.format(
              str(type_spec_element_names), str(data_spec_element_names)
          )
      )
    elements = []
    for element_index, (_, element_type) in enumerate(type_spec_elements):
      elements.extend(
          make_data_sources_with_dataspec(element_type, ds[element_index])
      )
    return elements


def make_data_sources_without_dataspec(type_spec) -> list[tff.Computation]:
  """Creates a list of computations that feed data into the graph.

  The computations use the custom ExternalDataset op to feed in example data.
  Each computation expects two inputs:
  -- A token specifying where the data store is on the device.
  -- An example selector that describes what data to take from the on-device
     data store.

  Args:
    type_spec: The TFF type signature of the output, which must be either a
      sequence or a (possibly nested) named tuple of sequences.

  Returns:
    A list of `tff.Computation`s, each of which accepts two `string`-typed
    tensors as input (the token and the example selector for the
    ExternalDataset op) and returns a sequence as output (whose type matches
    the corresponding part of `type_spec`). The computations appear in the list
    in depth-first order (matching exactly the convention used by the
    `_list_tensor_names_in_binding()` function below).

  Raises:
    TypeError: If the arguments are of the wrong types.
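
  The key difference from `make_data_sources_with_dataspec()` is when the
  example selector gets bound (and that no `preprocessing_fn` is applied). The
  toy sketch below contrasts the two; `fake_external_dataset` is a
  hypothetical stand-in for the ExternalDataset op that simply records its
  inputs.

  ```python
  def fake_external_dataset(token, selector):
    # Hypothetical stand-in for the ExternalDataset op.
    return ('ExternalDataset', token, selector)


  def make_source_with_selector(selector_bytes):
    # The selector is fixed when the source is built (the dataspec case).
    def source(token):
      return fake_external_dataset(token, selector_bytes)

    return source


  def make_source_without_selector():
    # The selector is supplied at run time, alongside the token.
    def source(token, selector):
      return fake_external_dataset(token, selector)

    return source
  ```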
  """
  type_spec = tff.to_type(type_spec)
  type_checks.check_type(
      type_spec, (tff.SequenceType, tff.StructType), name='type_spec'
  )
  if type_spec.is_sequence():

    @tff.tf_computation(tf.string, tf.string)
    def data_comp(token, example_selector):
      """The data source computation.

      Args:
        token: The token placeholder tensor (`tf.string`).
        example_selector: The example selector placeholder tensor (`tf.string`).

      Returns:
        An instance of `tf.data.Dataset`.
      """
      processed_ds = external_dataset.ExternalDataset(
          token=token, selector=example_selector
      )

      if 'Dataset' not in type(processed_ds).__name__:
        raise TypeError(
            'The ExternalDataset op returned an unrecognized non-dataset '
            'type {!r}.'.format(type(processed_ds))
        )
      return processed_ds

    _validate_data_comp(data_comp, type_spec)
    return [data_comp]
  else:  # type_spec is a struct.
    type_spec.check_struct()
    type_spec_elements = tff.structure.to_elements(type_spec)
    elements = []
    for _, element_type in type_spec_elements:
      elements.extend(make_data_sources_without_dataspec(element_type))
    return elements


def _list_tensor_names_in_binding(
    binding: computation_pb2.TensorFlow.Binding,
) -> list[str]:
  """Returns a flat Python list of tensor names that appear in the `binding`.
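
  For instance, with hypothetical nested `dict`s standing in for the `Binding`
  proto (each holding exactly one of the keys `'tensor'`, `'struct'`, or
  `'sequence'`, mimicking the `binding` oneof), the traversal can be sketched
  as:

  ```python
  def list_tensor_names(binding):
    # Each toy binding has exactly one key, like the proto's `binding` oneof.
    kind, value = next(iter(binding.items()))
    if kind in ('tensor', 'sequence'):
      # Leaves: a tensor name, or the variant tensor name of a sequence.
      return [value]
    if kind == 'struct':
      names = []
      for element in value:
        # Recursing in element order yields a depth-first traversal.
        names.extend(list_tensor_names(element))
      return names
    raise TypeError('Unexpected type of binding {}.'.format(kind))
  ```

  Here `{'struct': [{'tensor': 'a:0'}, {'struct': [{'sequence': 'ds:0'}]}]}`
  flattens to `['a:0', 'ds:0']`.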

  Args:
    binding: An instance of `computation_pb2.TensorFlow.Binding` in which any
      sequence bindings must contain variant tensors.

  Returns:
    A list of `str` instances with tensor names that appear in `binding` in the
    order in which they appear in the depth-first traversal of the potentially
    nested binding structure.

  Raises:
    TypeError: If `binding` is of the wrong type or of an unrecognized kind, or
      if a sequence binding does not name a variant tensor.
  """
  type_checks.check_type(binding, computation_pb2.TensorFlow.Binding)
  which_binding = binding.WhichOneof('binding')
  if which_binding == 'tensor':
    return [str(binding.tensor.tensor_name)]
  elif which_binding == 'struct':
    result = []
    for element in binding.struct.element:
      result.extend(_list_tensor_names_in_binding(element))
    return result
  elif which_binding == 'sequence':
    which_sequence = binding.sequence.WhichOneof('binding')
    if which_sequence != 'variant_tensor_name':
      raise TypeError(
          'Expected a variant tensor in sequence binding, found {}.'.format(
              which_sequence
          )
      )
    return [binding.sequence.variant_tensor_name]
  else:
    raise TypeError('Unexpected type of binding {}.'.format(which_binding))