xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/graph_helpers.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2022 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""Utilities for manipulating TensorFlow graph logic."""
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard Workerfrom typing import Optional, Union
17*14675a02SAndroid Build Coastguard Worker
18*14675a02SAndroid Build Coastguard Workerimport tensorflow as tf
19*14675a02SAndroid Build Coastguard Workerimport tensorflow_federated as tff
20*14675a02SAndroid Build Coastguard Worker
21*14675a02SAndroid Build Coastguard Workerfrom fcp.artifact_building import data_spec
22*14675a02SAndroid Build Coastguard Workerfrom fcp.artifact_building import tensor_utils
23*14675a02SAndroid Build Coastguard Workerfrom fcp.artifact_building import type_checks
24*14675a02SAndroid Build Coastguard Workerfrom fcp.tensorflow import external_dataset
25*14675a02SAndroid Build Coastguard Workerfrom tensorflow_federated.proto.v0 import computation_pb2
26*14675a02SAndroid Build Coastguard Worker
# A single TensorFlow value: either a variable or a tensor.
TfValue = Union[tf.Variable, tf.Tensor]
# A tensor standing for a dataset (described in `import_tensorflow` as a single
# variant tensor representing a dataset input).
DatasetTensor = tf.Tensor
# One argument to an imported computation: a lone value, a list of values, or a
# dataset tensor.
Argument = Union[TfValue, list[TfValue], DatasetTensor]
# The `args` accepted by `import_tensorflow`: nothing, one argument, or a tuple
# of arguments.
Args = Optional[Union[Argument, tuple[Argument, ...]]]

# Results take the same shapes as arguments.
Result = Argument
# Return type of `import_tensorflow`: one result, or a tuple of results when the
# outputs are split.
MaybeSplitOutputs = Union[Result, tuple[Result, ...]]


# Base name given to example selector placeholders. `embed_data_logic` checks
# the first generated placeholder's name against this prefix.
EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX = 'example_selector'
37*14675a02SAndroid Build Coastguard Worker
38*14675a02SAndroid Build Coastguard Worker
def generate_example_selector_placeholders(
    type_spec: tff.Type,
    name_prefix: str,
):
  """Creates one scalar string placeholder per sequence leaf of a type spec.

  Placeholders are produced in the order in which
  `tff.structure.to_elements()` enumerates the elements of each struct,
  visiting nested structs depth-first.

  A placeholder's name is `name_prefix` followed by `_<index>` for every struct
  level that has to be descended to reach its leaf. A top-level sequence type
  therefore yields a single placeholder named exactly `name_prefix`.

  Args:
    type_spec: The type to derive placeholders from. After conversion with
      `tff.to_type` it must be a `tff.SequenceType` or a `tff.StructType`; a
      struct is expected to be a tree of structs with sequences at the leaves.
      This is expected to reflect the TFF type signature of the input client
      data.
    name_prefix: The name prefix used when naming each placeholder.

  Returns:
    A list of `tf.compat.v1.placeholder` tensors of dtype `tf.string` and
    scalar shape, one per sequence leaf.
  """
  type_spec = tff.to_type(type_spec)
  type_checks.check_type(
      type_spec, (tff.SequenceType, tff.StructType), name='type_spec'
  )
  if not type_spec.is_sequence():
    type_spec.check_struct()
    # Recurse into every child, extending the name with the child's position.
    leaf_placeholders = []
    children = tff.structure.to_elements(type_spec)
    for position, (_, child_type) in enumerate(children):
      leaf_placeholders += generate_example_selector_placeholders(
          child_type, f'{name_prefix}_{position}'
      )
    return leaf_placeholders

  # Each client input is a sequence of serialized `tf.Example`s, which is why
  # the leaves of these TFF type signatures are sequences. Every such input
  # sequence needs exactly one `ExampleSelector` to determine which stream of
  # `tf.Example`s is selected from the data store, hence a single placeholder.
  return [tf.compat.v1.placeholder(tf.string, shape=[], name=name_prefix)]
86*14675a02SAndroid Build Coastguard Worker
87*14675a02SAndroid Build Coastguard Worker
def embed_data_logic(
    client_data_type: tff.Type,
    dataspec: Optional[data_spec.NestedDataSpec] = None,
) -> tuple[tf.Tensor, list[MaybeSplitOutputs], list[tf.Tensor]]:
  """Embeds the data logic into the current TensorFlow graph.

  Creates a scalar string placeholder named `data_token` for the dataset token,
  builds the data source computations for `client_data_type`, and imports each
  of them into the current graph under the prefix `data_<index>`.

  Args:
    client_data_type: The TFF type signature of the input client data.
    dataspec: If provided, either an instance of `data_spec.DataSpec` or a
      nested structure of these that matches the structure of the first element
      of the input to the client work part of the computation. When omitted,
      example selector placeholders are generated instead and passed to each
      data source alongside the token.

  Returns:
    A `tuple` containing the following (in order):
      token_placeholder: A dataset token placeholder tensor.
      data_values: A list of dataset output values, one per data source.
      example_selector_placeholders: A possibly empty list of placeholders used
        for passing example selector information into the client graph. The
        list is empty when `dataspec` is supplied.

  Raises:
    ValueError: If the number of dataset outputs from one data source is not 1.
    ValueError: If the first example selector placeholder did not receive the
      expected name, which indicates that the graph already contains a node
      using the example selector placeholder name.
  """
  # Embeds the token placeholder for the custom ExternalDataset op.
  token_placeholder = tf.compat.v1.placeholder(
      tf.string, shape=[], name='data_token'
  )

  if dataspec is not None:
    example_selector_placeholders = []
    data_sources = make_data_sources_with_dataspec(client_data_type, dataspec)
  else:
    example_selector_placeholders = generate_example_selector_placeholders(
        client_data_type, EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX
    )
    # If the first placeholder does not carry the expected prefix, another node
    # in the graph (likely created from the input tff.Computation) already
    # uses the special name. Rejecting that case guarantees the example
    # selector placeholders can be found by name in the generated artifact.
    if example_selector_placeholders:
      first_name = example_selector_placeholders[0].name
      accepted_prefixes = (
          f'{EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX}:',
          f'{EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX}_0',
      )
      if not first_name.startswith(accepted_prefixes):
        raise ValueError(
            'Graph already contains a placeholder with name '
            f'{EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX}. Please '
            'avoid the use of this special name.'
        )
    data_sources = make_data_sources_without_dataspec(client_data_type)
    assert len(example_selector_placeholders) == len(data_sources)

  # Embeds data source computations into the current graph.
  data_values = []
  for index, data_comp in enumerate(data_sources):
    import_args = [token_placeholder]
    if example_selector_placeholders:
      # Selector placeholders and data sources are aligned by position.
      import_args.append(example_selector_placeholders[index])
    ds_values = import_tensorflow(
        f'data_{index}', data_comp, import_args
    )  # pytype: disable=wrong-arg-types
    num_outputs = len(ds_values)
    if num_outputs != 1:
      raise ValueError(
          f'Expected one dataset output from a data source, found {num_outputs}.'
      )
    data_values.extend(ds_values)

  return token_placeholder, data_values, example_selector_placeholders
170*14675a02SAndroid Build Coastguard Worker
171*14675a02SAndroid Build Coastguard Worker
def import_tensorflow(
    name: str,
    comp: tff.framework.ConcreteComputation,
    args: Args = None,
    split_outputs: bool = False,
    session_token_tensor: Optional[tf.Tensor] = None,
) -> MaybeSplitOutputs:
  """Imports a TensorFlow computation into the current graph.

  Args:
    name: The string name to use as the graph import prefix. It is also
      appended to every `shared_name` attribute in the imported graph.
    comp: An instance of `tff.framework.ConcreteComputation` with just the
      `tensorflow` section.
    args: Either a single argument, a tuple of arguments, or None. An argument
      must be either:
      - a Python `list` containing either tensors or variables, or
      - a single variant tensor representing a dataset input.
      A tuple requires the computation's parameter binding to be a struct; its
      entries are matched to the struct elements by position.
    split_outputs: Whether to unpack the result tuple into a Python tuple. If
      `True`, `import_tensorflow` returns a tuple with multiple result objects,
      corresponding to the return elements in the type signature of `comp`;
      the result binding of `comp` must be a struct in this case. If `False`,
      the entire result is returned in flattened form as a single Python
      result object. Each Python result object, similar to the arguments in
      `args`, will be either a Python `list` of variant tensors or a singleton
      Python list containing only the dataset variant tensor.
    session_token_tensor: A tensor in the current graph containing the "session
      token" of the TensorFlow being imported. This is useful for passing a
      session-global identifier into the graph for use with ops like
      `ServeSlices` and `ExternalDataset` that take in a token which references
      session-global state. It is only wired in when `comp` declares a session
      token tensor name.

  Returns:
    One of:
      - A single result (Python `list` of variable value or variant tensors) if
        `split_outputs` is `False`.
      - A Python `tuple` of such results, if `split_outputs` is `True`.

  Raises:
    TypeError: If the arguments are of the wrong types, if `comp` is not a
      TensorFlow computation, or if a struct binding is required but absent.
  """
  type_checks.check_type(name, str, name='name')
  type_checks.check_type(comp, tff.framework.ConcreteComputation, name='comp')
  type_checks.check_type(split_outputs, bool, name='split_outputs')

  comp_proto = tff.framework.ConcreteComputation.get_proto(comp)
  type_checks.check_type(
      comp_proto, computation_pb2.Computation, name='comp_proto'
  )

  computation_kind = comp_proto.WhichOneof('computation')
  if computation_kind != 'tensorflow':
    raise TypeError(
        f'Expected a TensorFlow computation, found {computation_kind}.'
    )
  tf_proto = comp_proto.tensorflow

  # Map the parameter tensors of the imported graph onto the supplied args.
  if args is None:
    input_map = None
  elif not isinstance(args, tuple):
    input_map = create_tensor_map(tf_proto.parameter, args)
  else:
    parameter_kind = tf_proto.parameter.WhichOneof('binding')
    if parameter_kind != 'struct':
      raise TypeError(
          f'Expected a struct binding with a struct of args, found {parameter_kind}.'
      )
    input_map = {}
    for position, arg in enumerate(args):
      element_binding = tf_proto.parameter.struct.element[position]
      input_map.update(create_tensor_map(element_binding, arg))

  if input_map is None:
    input_map = {}
  else:
    # Add remappings for all potential control dependencies in the graph as
    # well. Since the `tf.graph_util.import_graph_def` input map works on the
    # tensor (not graph node) level, this case must be handled explicitly.
    def as_control_dep(tensor_or_dep_name: str) -> str:
      if tensor_or_dep_name.startswith('^'):
        return tensor_or_dep_name
      # Strip the `:<output index>` part to obtain the node name.
      return '^' + tensor_or_dep_name.partition(':')[0]

    control_remaps = {}
    for source_name, replacement in input_map.items():
      if source_name.startswith('^'):
        continue
      control_remaps[as_control_dep(source_name)] = as_control_dep(
          replacement.name
      )
    input_map.update(control_remaps)

  session_token_name = tf_proto.session_token_tensor_name
  if session_token_tensor is not None and session_token_name:
    input_map[session_token_name] = session_token_tensor

  # Collect the names of the tensors to fetch, remembering how many belong to
  # each result so that the flat import result can be regrouped afterwards.
  if split_outputs:
    result_kind = tf_proto.result.WhichOneof('binding')
    if result_kind != 'struct':
      raise TypeError(
          'If `split_outputs` is `True`, the result of the computation we are '
          f'importing must be a `struct`; found {result_kind}.'
      )
    names_per_element = [
        _list_tensor_names_in_binding(element_binding)
        for element_binding in tf_proto.result.struct.element
    ]
    return_elements = [
        tensor_name
        for element_names in names_per_element
        for tensor_name in element_names
    ]
    subset_sizes = [len(element_names) for element_names in names_per_element]
  else:
    return_elements = _list_tensor_names_in_binding(tf_proto.result)
    subset_sizes = [len(return_elements)]

  graph_def = tensor_utils.import_graph_def_from_any(tf_proto.graph_def)

  # We will be importing multiple GraphDefs into the server or client graphs.
  # These individual graphs may have identical `shared_name` attributes on
  # variable ops, which causes the runtime to reference the same resource,
  # which is highly undesired. The names are made unique before importing by
  # appending the import prefix.
  name_suffix = name.encode('utf-8')
  for node in graph_def.node:
    shared_name = node.attr.get('shared_name')
    if shared_name is None:
      continue
    # An empty shared name is replaced by `None` first, so that the result does
    # not start with an underscore (not allowed by TF).
    shared_name.s = (shared_name.s or b'None') + b'_' + name_suffix

  init_op_name = tf_proto.initialize_op
  if init_op_name:
    graph_def = add_control_deps_for_init_op(graph_def, init_op_name)

  imported = tf.graph_util.import_graph_def(
      graph_def,
      input_map=input_map,
      return_elements=return_elements,
      name=name,
  )

  if not split_outputs:
    return imported[: subset_sizes[0]]

  # Cut the flat list of imported tensors back into one slice per result.
  grouped = []
  start = 0
  for size in subset_sizes:
    grouped.append(imported[start : start + size])
    start += size
  return tuple(grouped)
335*14675a02SAndroid Build Coastguard Worker
336*14675a02SAndroid Build Coastguard Worker
def _get_deps_for_graph_node(
    graph_def: tf.compat.v1.GraphDef, node_name: str
) -> set[str]:
  """Computes the transitive set of nodes that `node_name` depends on.

  Only the nodes listed in `graph_def.node` are considered, so this function
  does not work for nodes in the function library.

  Args:
    graph_def: The input graph, an instance of `tf.compat.v1.GraphDef`.
    node_name: The node name, a string.

  Returns:
    A `set` with the string names of all nodes reachable from `node_name` by
    following node inputs, excluding `node_name` itself.

  Raises:
    TypeError: If either argument is of the wrong type.
  """
  type_checks.check_type(graph_def, tf.compat.v1.GraphDef, name='graph_def')
  type_checks.check_type(node_name, str, name='node_name')

  # Direct inputs of every node, reduced to bare node names.
  inputs_by_node = {
      node.name: {tensor_utils.bare_name(x) for x in node.input}
      for node in graph_def.node
  }

  # Breadth-first expansion: each round visits the inputs of the previous
  # round's nodes that have not been seen yet.
  visited = set()
  frontier = {node_name}
  while frontier:
    visited |= frontier
    reached = set()
    for current in frontier:
      reached |= inputs_by_node[current]
    frontier = reached - visited

  # The start node is not a dependency of itself.
  visited.discard(node_name)
  return visited
369*14675a02SAndroid Build Coastguard Worker
370*14675a02SAndroid Build Coastguard Worker
def add_control_deps_for_init_op(
    graph_def: tf.compat.v1.GraphDef, init_op: str
) -> tf.compat.v1.GraphDef:
  """Returns a copy of the GraphDef whose nodes control-depend on `init_op`.

  The control dependency is not added to `init_op` itself nor to any of the
  nodes it depends on (which would result in a control dep cycle), and it is
  not added to any nodes in the function library of the GraphDef. Nodes that
  already list the control dependency are left as they are.

  Args:
    graph_def: The input graph, an instance of `tf.compat.v1.GraphDef`. It is
      copied, not modified.
    init_op: The init op name, a string.

  Returns:
    The updated graph, a new instance of `tf.compat.v1.GraphDef`.

  Raises:
    TypeError: If either argument is of the wrong type.
  """
  type_checks.check_type(graph_def, tf.compat.v1.GraphDef, name='graph_def')
  type_checks.check_type(init_op, str, name='init_op')

  init_node_name = tensor_utils.bare_name(init_op)
  control_input = f'^{init_node_name}'
  # Nodes that must stay free of the new dependency: the init op and everything
  # it transitively depends on.
  excluded = _get_deps_for_graph_node(graph_def, init_node_name) | {
      init_node_name
  }

  updated_graph_def = tf.compat.v1.GraphDef()
  updated_graph_def.CopyFrom(graph_def)
  for node in updated_graph_def.node:
    if node.name in excluded or control_input in node.input:
      continue
    node.input.append(control_input)
  return updated_graph_def
405*14675a02SAndroid Build Coastguard Worker
406*14675a02SAndroid Build Coastguard Worker
def create_tensor_map(
    binding: computation_pb2.TensorFlow.Binding,
    arg: list[Union[tf.Tensor, tf.Variable]],
) -> dict[str, tf.Tensor]:
  """Pairs the tensor names found in `binding` with the values in `arg`.

  For a sequence binding, the result has exactly one entry that maps the
  variant tensor name to the single variant tensor in `arg`. For any other
  binding, the names listed by `_list_tensor_names_in_binding()` are matched
  positionally with the elements of `arg`; elements that expose a `read_value`
  attribute (e.g., resource variables) are replaced by the result of calling
  it.

  Args:
    binding: An instance of `computation_pb2.TensorFlow.Binding`.
    arg: Either a singleton Python `list` with variant tensor in case of a
      sequence binding, or a Python `list` of tensors or resource variables
      otherwise for a tuple binding.

  Returns:
    A Python `dict` keyed by tensor name, as described above.

  Raises:
    TypeError: If the argument types are incorrect.
    ValueError: If the arguments are malformed (e.g., multiple variant tensors).
  """
  type_checks.check_type(
      binding, computation_pb2.TensorFlow.Binding, name='binding'
  )
  type_checks.check_type(arg, list, name='arg')
  names = _list_tensor_names_in_binding(binding)

  if binding.WhichOneof('binding') != 'sequence':
    # NOTE(review): `zip` stops at the shorter of its inputs, so a length
    # mismatch between the binding and `arg` is dropped silently here --
    # confirm that callers always pass matching lengths.
    tensor_map = {}
    for tensor_name, value in zip(names, arg):
      if hasattr(value, 'read_value'):
        value = value.read_value()
      tensor_map[tensor_name] = value
    return tensor_map

  # Sequence binding: exactly one name and exactly one variant tensor.
  if len(names) != 1 or len(arg) != 1:
    raise ValueError('Multiple variant tensors found.')
  variant = arg[0]
  if not tf.is_tensor(variant):
    raise TypeError('Expected a tensor, found {!r}.'.format(type(variant)))
  if variant.dtype != tf.variant:
    raise TypeError('Expected `tf.variant`, found {!r}.'.format(variant.dtype))
  return {names[0]: variant}
447*14675a02SAndroid Build Coastguard Worker
448*14675a02SAndroid Build Coastguard Worker
def _validate_data_comp(data_comp: tff.Computation, type_spec: tff.Type):
  """Checks that `data_comp` is a function whose result fits `type_spec`.

  Args:
    data_comp: The data source computation to validate.
    type_spec: The TFF type that the result of `data_comp` must be assignable
      to.

  Raises:
    TypeError: If the result type of `data_comp` is not assignable to
      `type_spec`. A type signature that is not a `tff.FunctionType` is
      rejected by `type_checks.check_type()`.
  """
  signature = data_comp.type_signature
  type_checks.check_type(signature, tff.FunctionType)
  produced_type = signature.result
  if type_spec.is_assignable_from(produced_type):
    return
  mismatch_details = tff.types.type_mismatch_error_message(
      type_spec,
      produced_type,
      tff.types.TypeRelation.ASSIGNABLE,
  )
  raise TypeError(
      'The data source constructed with the supplied dataspec returns data '
      'which does not match type of request. Details of the mismatch:\n'
      + mismatch_details
  )
462*14675a02SAndroid Build Coastguard Worker
463*14675a02SAndroid Build Coastguard Worker
def make_data_sources_with_dataspec(
    type_spec: tff.Type, ds: data_spec.NestedDataSpec
) -> list[tff.Computation]:
  """Builds data source computations with hard-coded example selectors.

  The computations use the custom ExternalDataset op to feed in example data.
  Each computation expects one input:
    -- A token specifying where the data store is on the device.
  The example selector that describes what data to take from the on-device
  data store is serialized from the matching `data_spec.DataSpec` and embedded
  into the computation. If the `DataSpec` carries a `preprocessing_fn`, it is
  applied to the dataset produced by the ExternalDataset op.

  Args:
    type_spec: The TFF type signature of the output, which must be either a
      sequence, or a named tuple of sequences.
    ds: Either a single `data_spec.DataSpec`, or a nested structure of these,
      made up of Python containers, that exactly matches the structure of the
      `type_spec`.

  Returns:
    A list of `tff.Computation`s, each of which accepts a single `string`-typed
    tensor as input (the token for the ExternalDataset op) and returns a
    sequence as output (with the result that matches the corresponding part of
    `type_spec`). The computations appear on the list in a depth-first order
    (matching exactly the convention used in the
    `_list_tensor_names_in_binding()` method below).

  Raises:
    TypeError: If the arguments are of the wrong types, if the element names of
      `type_spec` and `ds` differ, or if a data source does not produce data
      assignable to the requested type.
  """
  assert ds
  type_spec = tff.to_type(type_spec)
  type_checks.check_type(
      type_spec, (tff.SequenceType, tff.StructType), name='type_spec'
  )

  if type_spec.is_sequence():
    # Leaf case: a single `DataSpec` yields a single computation.
    type_checks.check_type(ds, data_spec.DataSpec)
    assert isinstance(ds, data_spec.DataSpec)
    assert ds.example_selector_proto is not None
    selector_bytes = ds.example_selector_proto.SerializeToString()

    @tff.tf_computation(tf.string)
    def data_comp(token):
      """The data source computation.

      Args:
        token: The token placeholder tensor (`tf.string`).

      Returns:
        An instance of `tf.data.Dataset`.
      """
      dataset = external_dataset.ExternalDataset(
          token=token, selector=selector_bytes
      )
      if ds.preprocessing_fn is not None:
        dataset = ds.preprocessing_fn(dataset)

      # Name-based check: accepts any class whose name contains 'Dataset'.
      if 'Dataset' not in type(dataset).__name__:
        raise TypeError(
            'The preprocessing function returned an unrecognized non-dataset '
            'type {!r}.'.format(type(dataset))
        )
      return dataset

    _validate_data_comp(data_comp, type_spec)
    return [data_comp]

  # Struct case: recurse into every element and concatenate the results.
  type_spec.check_struct()
  if isinstance(ds, data_spec.DataSpec):
    raise TypeError(
        'Expected nested structure of `DataSpec`s conforming to '
        f'the structure of the type {type_spec}. '
        'Found single `DataSpec` instead.'
    )
  ds = tff.structure.from_container(ds)
  assert isinstance(ds, tff.structure.Struct)
  type_elements = tff.structure.to_elements(type_spec)
  spec_elements = tff.structure.to_elements(ds)
  type_names = [str(name) for name, _ in type_elements]
  spec_names = [str(name) for name, _ in spec_elements]
  if type_names != spec_names:
    raise TypeError(
        'Type vs. data spec elements names mismatch: {} vs. {}.'.format(
            type_names, spec_names
        )
    )
  return [
      source
      for index, (_, element_type) in enumerate(type_elements)
      for source in make_data_sources_with_dataspec(element_type, ds[index])
  ]
558*14675a02SAndroid Build Coastguard Worker
559*14675a02SAndroid Build Coastguard Worker
def make_data_sources_without_dataspec(type_spec) -> list[tff.Computation]:
  """Creates a list of computations that feed data into the graph.

  The computations use the custom ExternalDataset op to feed in example data.
  The computations will expect two inputs:
    -- A token specifying where the data store is on the device.
    -- An example selector that describes what data to take from the on-device
      data store.

  Args:
    type_spec: The TFF type signature of the output, which must be either a
      sequence, or a named tuple of sequences.

  Returns:
    A list of `tff.Computation`s, each of which accepts two `string`-typed
    tensors as input (the token and the example selector for the
    ExternalDataset op) and returns a sequence as output (with the result that
    matches the corresponding part of `type_spec`). The computations appear on
    the list in a depth-first order (matching exactly the convention used in
    the `_list_tensor_names_in_binding()` method below).

  Raises:
    TypeError: If the arguments are of the wrong types, or if a data source
      does not produce data assignable to the requested type.
  """
  type_spec = tff.to_type(type_spec)
  type_checks.check_type(
      type_spec, (tff.SequenceType, tff.StructType), name='type_spec'
  )
  if type_spec.is_sequence():

    @tff.tf_computation(tf.string, tf.string)
    def data_comp(token, example_selector):
      """The data source computation.

      Args:
        token: The token placeholder tensor (`tf.string`).
        example_selector: The example selector placeholder tensor (`tf.string`).

      Returns:
        An instance of `tf.data.Dataset`.
      """
      dataset = external_dataset.ExternalDataset(
          token=token, selector=example_selector
      )

      # No preprocessing function is involved here, so the message names the
      # op that actually produced the value. The check is name-based: any
      # class whose name contains 'Dataset' is accepted.
      if 'Dataset' not in type(dataset).__name__:
        raise TypeError(
            'ExternalDataset returned an unrecognized non-dataset '
            'type {!r}.'.format(type(dataset))
        )
      return dataset

    _validate_data_comp(data_comp, type_spec)
    return [data_comp]

  # `type_spec` is a struct: recurse into each element, depth-first.
  type_spec.check_struct()
  return [
      source
      for _, element_type in tff.structure.to_elements(type_spec)
      for source in make_data_sources_without_dataspec(element_type)
  ]
621*14675a02SAndroid Build Coastguard Worker
622*14675a02SAndroid Build Coastguard Worker
def _list_tensor_names_in_binding(
    binding: computation_pb2.TensorFlow.Binding,
) -> list[str]:
  """Flattens `binding` into the list of tensor names it refers to.

  Args:
    binding: An instance of `computation_pb2.TensorFlow.Binding` in which any
      sequence bindings must contain variant tensors.

  Returns:
    A list of tensor names in the order in which they are encountered in a
    depth-first traversal of the (potentially nested) binding: one name for a
    tensor binding, one variant tensor name for a sequence binding, and the
    concatenation of the elements' names for a struct binding.

  Raises:
    TypeError: If the arguments are of the wrong types, if a sequence binding
      is not backed by a variant tensor, or if the kind of binding is not
      recognized.
  """
  type_checks.check_type(binding, computation_pb2.TensorFlow.Binding)
  kind = binding.WhichOneof('binding')

  if kind == 'tensor':
    return [str(binding.tensor.tensor_name)]

  if kind == 'struct':
    # Recurse into each element, preserving element order.
    return [
        name
        for element in binding.struct.element
        for name in _list_tensor_names_in_binding(element)
    ]

  if kind != 'sequence':
    raise TypeError('Unexpected type of binding {}.'.format(kind))

  sequence_kind = binding.sequence.WhichOneof('binding')
  if sequence_kind == 'variant_tensor_name':
    return [binding.sequence.variant_tensor_name]
  raise TypeError(
      'Expected a variant tensor in sequence binding, found {}.'.format(
          sequence_kind
      )
  )
660