xref: /aosp_15_r20/external/tensorflow/tensorflow/python/util/nest.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Functions that work with structures.
17
18A structure is either:
19
20* one of the recognized Python collections, holding _nested structures_;
* a value of any other type, typically a TensorFlow data type like Tensor,
  Variable, or of compatible types such as int, float, ndarray, etc. These are
  commonly referred to as _atoms_ of the structure.
24
25A structure of type `T` is a structure whose atomic items are of type `T`.
26For example, a structure of `tf.Tensor` only contains `tf.Tensor` as its atoms.
27
28Historically a _nested structure_ was called a _nested sequence_ in TensorFlow.
29A nested structure is sometimes called a _nest_ or a _tree_, but the formal
30name _nested structure_ is preferred.
31
Refer to
[Nesting Data Structures](https://en.wikipedia.org/wiki/Nesting_(computing)#Data_structures).
34
35The following collection types are recognized by `tf.nest` as nested
36structures:
37
* `collections.abc.Sequence` (except `str` and `bytes`).
39  This includes `list`, `tuple`, and `namedtuple`.
40* `collections.abc.Mapping` (with sortable keys).
41  This includes `dict` and `collections.OrderedDict`.
42* `collections.abc.MappingView` (with sortable keys).
43* [`attr.s` classes](https://www.attrs.org/).
44
45Any other values are considered **atoms**.  Not all collection types are
46considered nested structures.  For example, the following types are
47considered atoms:
48
49* `set`; `{"a", "b"}` is an atom, while `["a", "b"]` is a nested structure.
50* [`dataclass` classes](https://docs.python.org/library/dataclasses.html)
51* `tf.Tensor`
52* `numpy.array`
53
54`tf.nest.is_nested` checks whether an object is a nested structure or an atom.
55For example:
56
57  >>> tf.nest.is_nested("1234")
58  False
59  >>> tf.nest.is_nested([1, 3, [4, 5]])
60  True
61  >>> tf.nest.is_nested(((7, 8), (5, 6)))
62  True
63  >>> tf.nest.is_nested([])
64  True
65  >>> tf.nest.is_nested({"a": 1, "b": 2})
66  True
67  >>> tf.nest.is_nested({"a": 1, "b": 2}.keys())
68  True
69  >>> tf.nest.is_nested({"a": 1, "b": 2}.values())
70  True
71  >>> tf.nest.is_nested({"a": 1, "b": 2}.items())
72  True
73  >>> tf.nest.is_nested(set([1, 2]))
74  False
75  >>> ones = tf.ones([2, 3])
76  >>> tf.nest.is_nested(ones)
77  False
78
Note: A proper structure shall form a tree. The user shall ensure there are no
cyclic references among the items in the structure, i.e., no reference within
the input structure of these functions may be recursive. The behavior is
undefined if there is a cycle.
83
84"""
85
86import collections as _collections
87
88import six as _six
89import wrapt as _wrapt
90
91from tensorflow.python.platform import tf_logging
92from tensorflow.python.util import _pywrap_nest
93from tensorflow.python.util import _pywrap_utils
94from tensorflow.python.util.compat import collections_abc as _collections_abc
95from tensorflow.python.util.tf_export import tf_export
96
97
98_SHALLOW_TREE_HAS_INVALID_KEYS = (
99    "The shallow_tree's keys are not a subset of the input_tree's keys. The "
100    "shallow_tree has the following keys that are not in the input_tree: {}.")
101
102_STRUCTURES_HAVE_MISMATCHING_TYPES = (
103    "The two structures don't have the same sequence type. Input structure has "
104    "type {input_type}, while shallow structure has type {shallow_type}.")
105
106_STRUCTURES_HAVE_MISMATCHING_LENGTHS = (
107    "The two structures don't have the same sequence length. Input "
108    "structure has length {input_length}, while shallow structure has length "
109    "{shallow_length}."
110)
111
112_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE = (
113    "The input_tree has fewer items than the shallow_tree. Input structure "
114    "has length {input_size}, while shallow structure has length "
115    "{shallow_size}.")
116
117_IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ = (
118    "If shallow structure is a sequence, input must also be a sequence. "
119    "Input has type: {}.")
120
121
122def _get_attrs_items(obj):
123  """Returns a list of (name, value) pairs from an attrs instance.
124
125  The list will be sorted by name.
126
127  Args:
128    obj: an object.
129
130  Returns:
131    A list of (attr_name, attr_value) pairs, sorted by attr_name.
132  """
133  attrs = getattr(obj.__class__, "__attrs_attrs__")
134  attr_names = (a.name for a in attrs)
135  return [(attr_name, getattr(obj, attr_name)) for attr_name in attr_names]
136
137
138def _sorted(dict_):
139  """Returns a sorted list of the dict keys, with error if keys not sortable."""
140  try:
141    return sorted(dict_.keys())
142  except TypeError:
143    raise TypeError("nest only supports dicts with sortable keys.")
144
145
146# TODO(b/225045380): Move to a "leaf" library to use in trace_type.
def is_namedtuple(instance, strict=False):
  """Checks whether `instance` is a `namedtuple`.

  The check is delegated to the native `_pywrap_utils.IsNamedtuple`.

  Args:
    instance: Any Python object.
    strict: When False (the default), instances of classes that inherit from
      a `namedtuple` class also count as namedtuples. When True, only "plain"
      namedtuples qualify.

  Returns:
    True if `instance` is a `namedtuple` under the chosen strictness.
  """
  matches = _pywrap_utils.IsNamedtuple(instance, strict)
  return matches
161
# Back-compat private alias: this function was private up to TF2.5.
_is_namedtuple = is_namedtuple

# Short private aliases for the native type predicates in `_pywrap_utils`.
_is_mapping = _pywrap_utils.IsMapping
_is_mutable_mapping = _pywrap_utils.IsMutableMapping
_is_mapping_view = _pywrap_utils.IsMappingView
_is_attrs = _pywrap_utils.IsAttrs
_is_composite_tensor = _pywrap_utils.IsCompositeTensor
_is_type_spec = _pywrap_utils.IsTypeSpec
170
171
172# TODO(b/225045380): Move to a "leaf" library to use in trace_type.
@tf_export("__internal__.nest.is_attrs", v1=[])
def is_attrs(obj):
  """Returns True if `obj` is an instance of an `attr.s`-decorated class."""
  is_attrs_instance = _is_attrs(obj)
  return is_attrs_instance
177
178
@tf_export("__internal__.nest.is_mapping", v1=[])
def is_mapping(obj):
  """Returns True if `obj` is a `collections.abc.Mapping`."""
  is_mapping_instance = _is_mapping(obj)
  return is_mapping_instance
183
184
@tf_export("__internal__.nest.sequence_like", v1=[])
def _sequence_like(instance, args):
  """Converts the sequence `args` to the same type as `instance`.

  `args` is expected in the order in which `_yield_sorted_items` visits the
  children of `instance` (e.g. sorted key order for mappings), which is the
  order `_packed_nest_with_indices` produces.

  The checks below are order-sensitive: mutable mappings are tested before
  general mappings, and namedtuple/attrs before the generic `ObjectProxy`
  branch.

  Args:
    instance: an instance of `tuple`, `list`, `namedtuple`, `dict`,
        `collections.OrderedDict`, or `composite_tensor.CompositeTensor`
        or `type_spec.TypeSpec`. Mapping views, attrs instances, `range`
        objects and `wrapt.ObjectProxy` wrappers are also handled.
    args: items to be converted to the `instance` type. For composite
        tensors and TypeSpecs this must be a single-element sequence holding
        the components.

  Returns:
    `args` with the type of `instance`. Mapping views and `range` objects
    cannot be rebuilt from arbitrary values, so a `list` is returned for them.

  Raises:
    TypeError: if `instance` is a mapping with non-sortable keys, or a
      non-mutable Mapping whose type cannot be constructed from an iterable of
      (key, value) pairs.
    AssertionError: if `instance` is a composite tensor or TypeSpec and
      `args` does not have exactly one element (when assertions are enabled).
  """
  if _is_mutable_mapping(instance):
    # Pack dictionaries in a deterministic order by sorting the keys.
    # Notice this means that we ignore the original order of `OrderedDict`
    # instances. This is intentional, to avoid potential bugs caused by mixing
    # ordered and plain dicts (e.g., flattening a dict but using a
    # corresponding `OrderedDict` to pack it back).
    result = dict(zip(_sorted(instance), args))
    instance_type = type(instance)
    if instance_type == _collections.defaultdict:
      # A defaultdict must be rebuilt with the same default_factory.
      d = _collections.defaultdict(instance.default_factory)
    else:
      d = instance_type()
    # Insert in `instance`'s own iteration order so the new mapping keeps it.
    for key in instance:
      d[key] = result[key]
    return d
  elif _is_mapping(instance):
    result = dict(zip(_sorted(instance), args))
    instance_type = type(instance)
    # Types can silence this warning by setting `__supported_by_tf_nest__`.
    if not getattr(instance_type, "__supported_by_tf_nest__", False):
      tf_logging.log_first_n(
          tf_logging.WARN, "Mapping types may not work well with tf.nest. "
          "Prefer using MutableMapping for {}".format(instance_type), 1)
    try:
      # Non-mutable mappings must be built in one call from (key, value) pairs.
      return instance_type((key, result[key]) for key in instance)
    except TypeError as err:
      raise TypeError("Error creating an object of type {} like {}. Note that "
                      "it must accept a single positional argument "
                      "representing an iterable of key-value pairs, in "
                      "addition to self. Cause: {}".format(
                          type(instance), instance, err))
  elif _is_mapping_view(instance):
    # We can't directly construct mapping views, so we create a list instead
    return list(args)
  elif is_namedtuple(instance) or _is_attrs(instance):
    # For a proxied namedtuple/attrs instance, build the wrapped type. Unlike
    # the generic ObjectProxy branch below, the result is not re-wrapped.
    if isinstance(instance, _wrapt.ObjectProxy):
      instance_type = type(instance.__wrapped__)
    else:
      instance_type = type(instance)
    # namedtuples and attrs classes receive their fields positionally.
    return instance_type(*args)
  elif _is_composite_tensor(instance):
    # `args` holds exactly one item: the components for this composite.
    assert len(args) == 1
    spec = instance._type_spec  # pylint: disable=protected-access
    return spec._from_components(args[0])  # pylint: disable=protected-access
  elif _is_type_spec(instance):
    # Pack a CompositeTensor's components according to a TypeSpec.
    assert len(args) == 1
    return instance._from_components(args[0])  # pylint: disable=protected-access
  elif isinstance(instance, _six.moves.range):
    # A range cannot hold arbitrary values; recurse with a list, which ends up
    # in the generic branch below.
    return _sequence_like(list(instance), args)
  elif isinstance(instance, _wrapt.ObjectProxy):
    # For object proxies, first create the underlying type and then re-wrap it
    # in the proxy type.
    return type(instance)(_sequence_like(instance.__wrapped__, args))
  else:
    # Not a namedtuple: generic sequences (e.g. list, tuple) are constructed
    # from the iterable `args`.
    return type(instance)(args)
254
255
def _yield_value(iterable):
  """Yields only the values of `_yield_sorted_items(iterable)`, in order."""
  yield from (value for _key, value in _yield_sorted_items(iterable))
259
260
def _yield_sorted_items(iterable):
  """Yield (key, value) pairs for `iterable` in a deterministic order.

  What the key is, and the order of the pairs, depend on the kind of
  `iterable`:

  * Lists, exact `tuple`s, and any iterable not matched by a case below: the
    key is the integer index, yielded in index order (via `enumerate`).
  * Mappings (including `dict`): the key is the dictionary key, yielded in
    sorted key order. Keys must be sortable (see `_sorted`).
  * attrs instances: the key is the attribute name, in the order returned by
    `_get_attrs_items` (the `__attrs_attrs__` order, not sorted by name).
  * namedtuples: the key is the field name, in `_fields` order.
  * Composite tensors: a single pair whose key is the name of the TypeSpec's
    `value_type` and whose value is the result of `_to_components`.
  * TypeSpecs: a single pair with the same key string as the matching
    composite tensor, and `_component_specs` as the value.

  Args:
    iterable: an iterable, composite tensor, or TypeSpec.

  Yields:
    The iterable's (key, value) pairs, in the order described above.

  Raises:
    TypeError: if `iterable` is a mapping with non-sortable keys.
  """
  # Ordered to check common structure types (list, tuple, dict) first.
  if isinstance(iterable, list):
    for item in enumerate(iterable):
      yield item
  # namedtuples handled separately to avoid expensive namedtuple check.
  # (Exact `tuple` only: tuple subclasses fall through to the checks below.)
  elif type(iterable) == tuple:  # pylint: disable=unidiomatic-typecheck
    for item in enumerate(iterable):
      yield item
  elif isinstance(iterable, (dict, _collections_abc.Mapping)):
    # Iterate through dictionaries in a deterministic order by sorting the
    # keys. Notice this means that we ignore the original order of `OrderedDict`
    # instances. This is intentional, to avoid potential bugs caused by mixing
    # ordered and plain dicts (e.g., flattening a dict but using a
    # corresponding `OrderedDict` to pack it back).
    for key in _sorted(iterable):
      yield key, iterable[key]
  elif _is_attrs(iterable):
    for item in _get_attrs_items(iterable):
      yield item
  elif is_namedtuple(iterable):
    for field in iterable._fields:
      yield field, getattr(iterable, field)
  elif _is_composite_tensor(iterable):
    type_spec = iterable._type_spec  # pylint: disable=protected-access
    yield type_spec.value_type.__name__, type_spec._to_components(iterable)  # pylint: disable=protected-access
  elif _is_type_spec(iterable):
    # Note: to allow CompositeTensors and their TypeSpecs to have matching
    # structures, we need to use the same key string here.
    yield iterable.value_type.__name__, iterable._component_specs  # pylint: disable=protected-access
  else:
    # Fallback (e.g. mapping views, other sequences): index-keyed pairs.
    for item in enumerate(iterable):
      yield item
308
309
# Private aliases for the native nested-structure predicates.
_is_nested_or_composite = _pywrap_utils.IsNestedOrComposite
_is_nested = _pywrap_utils.IsNested
313
314
@tf_export("nest.is_nested")
def is_nested(seq):
  """Checks whether `seq` is a nested structure rather than an atom.

  See [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest) for the
  definition of a nested structure.

  Args:
    seq: the value to check.

  Returns:
    True if `seq` is a nested structure.
  """
  nested = _is_nested(seq)
  return nested
329
330
def is_nested_or_composite(seq):
  """Checks whether `seq` is a nested structure or a composite.

  See [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest) for the
  definition of a nested structure.

  Args:
    seq: the value to check.

  Returns:
    True if `seq` is a nested structure or a composite.
  """
  answer = _is_nested_or_composite(seq)
  return answer
344
345
346# FIXME(feyu): Remove the back-compat names before closing b/201685523, after
347# all users of is_sequence are moved to the new names. (cl/405503918)
def is_sequence(seq):
  """Back-compat name for `is_nested`; slated for removal (see FIXME above)."""
  result = _is_nested(seq)
  return result
350
351
def is_sequence_or_composite(seq):
  """Back-compat name for `is_nested_or_composite`; slated for removal."""
  result = _is_nested_or_composite(seq)
  return result
354
355
@tf_export("nest.flatten")
def flatten(structure, expand_composites=False):
  """Returns a flat list from a given structure.

  Refer to [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest)
  for the definition of a structure.

  If the structure is an atom, then returns a single-item list: [structure].

  This is the inverse of the `nest.pack_sequence_as` method that takes in a
  flattened list and re-packs it into the nested structure.

  In the case of dict instances, the sequence consists of the values, sorted by
  key to ensure deterministic behavior. This is true also for OrderedDict
  instances: their sequence order is ignored, the sorting order of keys is used
  instead. The same convention is followed in `nest.pack_sequence_as`. This
  correctly repacks dicts and OrderedDicts after they have been flattened, and
  also allows flattening an OrderedDict and then repacking it back using a
  corresponding plain dict, or vice-versa. Dictionaries with non-sortable keys
  cannot be flattened.

  Users must not modify any collections used in nest while this function is
  running.

  Examples:

  1. Python dict (ordered by key):

    >>> d = { "key3": "value3", "key1": "value1", "key2": "value2" }
    >>> tf.nest.flatten(d)
    ['value1', 'value2', 'value3']

  2. For a nested python tuple:

    >>> t = ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0)
    >>> tf.nest.flatten(t)
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

  3. For a nested dictionary of dictionaries:

    >>> d = { "key3": {"c": (1.0, 2.0), "a": (3.0)},
    ... "key1": {"m": "val1", "g": "val2"} }
    >>> tf.nest.flatten(d)
    ['val2', 'val1', 3.0, 1.0, 2.0]

  4. Numpy array (will not flatten):

    >>> array = np.array([[1, 2], [3, 4]])
    >>> tf.nest.flatten(array)
        [array([[1, 2],
                [3, 4]])]

  5. `tf.Tensor` (will not flatten):

    >>> tensor = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
    >>> tf.nest.flatten(tensor)
        [<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
          array([[1., 2., 3.],
                 [4., 5., 6.],
                 [7., 8., 9.]], dtype=float32)>]

  6. `tf.RaggedTensor`: This is a composite tensor whose representation consists
  of a flattened list of 'values' and a list of 'row_splits' which indicate how
  to chop up the flattened list into different rows. For more details on
  `tf.RaggedTensor`, please visit
  https://www.tensorflow.org/api_docs/python/tf/RaggedTensor.

  With `expand_composites=False`, we just return the RaggedTensor as is.

    >>> tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]])
    >>> tf.nest.flatten(tensor, expand_composites=False)
    [<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2]]>]

  With `expand_composites=True`, we return the component Tensors that make up
  the RaggedTensor representation (the values and row_splits tensors).

    >>> tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]])
    >>> tf.nest.flatten(tensor, expand_composites=True)
    [<tf.Tensor: shape=(7,), dtype=int32, numpy=array([3, 1, 4, 1, 5, 9, 2],
                                                      dtype=int32)>,
     <tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 4, 4, 7])>]

  Args:
    structure: an atom or a nested structure. Note, numpy arrays are considered
      atoms and are not flattened.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    A Python list, the flattened version of the input.

  Raises:
    TypeError: The nest is or contains a dict with non-sortable keys.
  """
  # `None` is an atom, so it flattens to a single-item list.
  if structure is None:
    return [None]
  # Pass a real bool to the pybind wrapper (compare the note in
  # `assert_same_structure` about pybind and non-bool flags).
  expand_composites = bool(expand_composites)
  return _pywrap_utils.Flatten(structure, expand_composites)
455
456
# Native implementation provided by the `_pywrap_utils` pybind extension
# (TODO(review): point to the wrapper source for its documentation).
same_namedtuples = _pywrap_utils.SameNamedtuples
# Back-compat private alias: this function was private up to TF2.5.
_same_namedtuples = same_namedtuples
460
461
462class _DotString(object):
463
464  __slots__ = []
465
466  def __str__(self):
467    return "."
468
469  def __repr__(self):
470    return "."
471
472
473_DOT = _DotString()
474
475
@tf_export("nest.assert_same_structure")
def assert_same_structure(nest1, nest2, check_types=True,
                          expand_composites=False):
  """Asserts that two structures are nested in the same way.

  Refer to [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest)
  for the definition of a structure.

  Note the method does not check the types of atoms inside the structures.

  Examples:

  * These atom vs. atom comparisons will pass:

    >>> tf.nest.assert_same_structure(1.5, tf.Variable(1, tf.uint32))
    >>> tf.nest.assert_same_structure("abc", np.array([1, 2]))

  * These nested structure vs. nested structure comparisons will pass:

    >>> structure1 = (((1, 2), 3), 4, (5, 6))
    >>> structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
    >>> structure3 = [(("a", "b"), "c"), "d", ["e", "f"]]
    >>> tf.nest.assert_same_structure(structure1, structure2)
    >>> tf.nest.assert_same_structure(structure1, structure3, check_types=False)

    >>> import collections
    >>> tf.nest.assert_same_structure(
    ...     collections.namedtuple("bar", "a b")(1, 2),
    ...     collections.namedtuple("foo", "a b")(2, 3),
    ...     check_types=False)

    >>> tf.nest.assert_same_structure(
    ...     collections.namedtuple("bar", "a b")(1, 2),
    ...     { "a": 1, "b": 2 },
    ...     check_types=False)

    >>> tf.nest.assert_same_structure(
    ...     { "a": 1, "b": 2, "c": 3 },
    ...     { "c": 6, "b": 5, "a": 4 })

    >>> ragged_tensor1 = tf.RaggedTensor.from_row_splits(
    ...       values=[3, 1, 4, 1, 5, 9, 2, 6],
    ...       row_splits=[0, 4, 4, 7, 8, 8])
    >>> ragged_tensor2 = tf.RaggedTensor.from_row_splits(
    ...       values=[3, 1, 4],
    ...       row_splits=[0, 3])
    >>> tf.nest.assert_same_structure(
    ...       ragged_tensor1,
    ...       ragged_tensor2,
    ...       expand_composites=True)

  * These examples will raise exceptions:

    >>> tf.nest.assert_same_structure([0, 1], np.array([0, 1]))
    Traceback (most recent call last):
    ...
    ValueError: The two structures don't have the same nested structure

    >>> tf.nest.assert_same_structure(
    ...       collections.namedtuple('bar', 'a b')(1, 2),
    ...       collections.namedtuple('foo', 'a b')(2, 3))
    Traceback (most recent call last):
    ...
    TypeError: The two structures don't have the same nested structure

  Args:
    nest1: an atom or a nested structure.
    nest2: an atom or a nested structure.
    check_types: if `True` (default) types of structures are checked as well,
      including the keys of dictionaries. If set to `False`, for example a list
      and a tuple of objects will look the same if they have the same size. Note
      that namedtuples with identical name and fields are always considered to
      have the same shallow structure. Two types will also be considered the
      same if they are both list subtypes (which allows "list" and
      "_ListWrapper" from trackable dependency tracking to compare equal).
      `check_types=True` only checks type of sub-structures. The types of atoms
      are not checked.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Raises:
    ValueError: If the two structures do not have the same number of atoms or
      if the two structures are not nested in the same way.
    TypeError: If the two structures differ in the type of sequence in any of
      their substructures. Only possible if `check_types` is `True`.
  """
  # Convert to bool explicitly as otherwise pybind will not be able to handle
  # type mismatch message correctly. See GitHub issue 42329 for details.
  check_types = bool(check_types)
  expand_composites = bool(expand_composites)
  try:
    _pywrap_utils.AssertSameStructure(nest1, nest2, check_types,
                                      expand_composites)
  except (ValueError, TypeError) as e:
    # Re-raise the same exception type, appending both structures rendered
    # with every atom replaced by "." so the mismatch is easy to locate. The
    # native error is chained explicitly as the cause.
    str1 = str(map_structure(lambda _: _DOT, nest1))
    str2 = str(map_structure(lambda _: _DOT, nest2))
    raise type(e)("%s\n"
                  "Entire first structure:\n%s\n"
                  "Entire second structure:\n%s"
                  % (str(e), str1, str2)) from e
577
578
def flatten_dict_items(dictionary):
  """Flattens a dict whose keys and values are parallel nested structures.

  Each key may be an arbitrarily nested structure, and its value must have
  exactly the same nested structure. The result maps every atom of every key
  to the atom at the same position in the corresponding value:

  ```python
  example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}
  result = {4: "a", 5: "b", 6: "c", 8: "d"}
  flatten_dict_items(example_dictionary) == result
  ```

  Requirements on the input:

  1. Every key and its value have exactly the same nested structure.
  2. No flattened key appears more than once across the whole dictionary.

  The work is done natively by `_pywrap_nest.FlattenDictItems`.

  Args:
    dictionary: the dictionary to flatten.

  Returns:
    A dictionary mapping each flattened key atom to its matching value atom.

  Raises:
    TypeError: If the input is not a dictionary.
    ValueError: If any key and value do not have the same structure layout, or
      if keys are not unique.
  """
  flat_items = _pywrap_nest.FlattenDictItems(dictionary)
  return flat_items
610
611
def _packed_nest_with_indices(structure,
                              flat,
                              index,
                              is_nested_fn,
                              sequence_fn=None):
  """Helper function for pack_sequence_as.

  Walks the children of `structure` (via `_yield_value`). Each atom consumes
  the next item of `flat`. Each nested child is packed recursively and rebuilt
  with `sequence_fn`.

  Args:
    structure: structure to mimic.
    flat: Flattened values to output substructure for.
    index: Index at which to start reading from flat.
    is_nested_fn: Function used to test if a value should be treated as a
      nested structure.
    sequence_fn: Function used to generate a new structure instance. Defaults
      to `_sequence_like`.

  Returns:
    The tuple (new_index, packed), where:
      * new_index - the updated index into `flat` having processed `structure`.
      * packed - a list holding the subset of `flat` corresponding to the
                 children of `structure`, having started at `index`, with
                 each nested child packed into the same nested format. The
                 caller converts this list to `structure`'s own type.

  Raises:
    IndexError: if `structure` contains more atoms than `flat` provides
      (assuming indexing starts from `index`). `_pack_sequence_as` catches
      this and reports a descriptive `ValueError` when the counts differ.
  """
  packed = []
  sequence_fn = sequence_fn or _sequence_like
  for s in _yield_value(structure):
    if is_nested_fn(s):
      # Recurse into the sub-structure, then rebuild it with its own type.
      new_index, child = _packed_nest_with_indices(s, flat, index, is_nested_fn,
                                                   sequence_fn)
      packed.append(sequence_fn(s, child))
      index = new_index
    else:
      # Each atom consumes exactly one item from `flat`.
      packed.append(flat[index])
      index += 1
  return index, packed
650
651
def _pack_sequence_as(structure, flat_sequence, expand_composites,
                      sequence_fn=None):
  """Implements sequence packing, with the option to alter the structure.

  Args:
    structure: an atom or nested structure whose layout the result mimics.
    flat_sequence: a flat sequence (e.g. a list) of values to pack. It must
      itself be nested according to the chosen `is_nested_fn`.
    expand_composites: if true, composite tensors count as nested structures
      (their components are expected in `flat_sequence`).
    sequence_fn: optional callable `(instance, args)` used to rebuild each
      nested level. Defaults to `_sequence_like`.

  Returns:
    `flat_sequence` packed into the same nested layout as `structure`.

  Raises:
    TypeError: if `flat_sequence` is not a nested structure.
    ValueError: if the number of atoms in `structure` does not match
      `len(flat_sequence)`.
    IndexError: if packing raised an `IndexError` that is not explained by an
      atom-count mismatch (it is re-raised unchanged).
  """
  is_nested_fn = _is_nested_or_composite if expand_composites else _is_nested
  sequence_fn = sequence_fn or _sequence_like

  def truncate(value, length):
    # Shortens `str(value)` to `length` chars, appending "..." if it was cut.
    value_str = str(value)
    return value_str[:length] + (value_str[length:] and "...")

  if not is_nested_fn(flat_sequence):
    raise TypeError(
        "Attempted to pack value:\n  {}\ninto a structure, but found "
        "incompatible type `{}` instead.".format(
            truncate(flat_sequence, 100), type(flat_sequence)))

  if not is_nested_fn(structure):
    # An atom structure accepts exactly one flat value.
    if len(flat_sequence) != 1:
      raise ValueError(
          "The target structure is of type `{}`\n  {}\nHowever the input "
          "is a sequence ({}) of length {}.\n  {}\nnest cannot "
          "guarantee that it is safe to map one to the other.".format(
              type(structure), truncate(structure, 100), type(flat_sequence),
              len(flat_sequence), truncate(flat_sequence, 100)))
    return flat_sequence[0]

  try:
    final_index, packed = _packed_nest_with_indices(structure, flat_sequence, 0,
                                                    is_nested_fn, sequence_fn)
    if final_index < len(flat_sequence):
      # Leftover items: let the handler below report the count mismatch.
      raise IndexError
  except IndexError:
    flat_structure = flatten(structure, expand_composites=expand_composites)
    if len(flat_structure) != len(flat_sequence):
      raise ValueError(
          "Could not pack sequence. Structure had %d atoms, but "
          "flat_sequence had %d items.  Structure: %s, flat_sequence: %s." %
          (len(flat_structure), len(flat_sequence), structure, flat_sequence))
    # The atom counts agree, so the IndexError is not a size mismatch (e.g.
    # it came from `sequence_fn`). Re-raise it rather than falling through to
    # the return below, where `packed` could be unbound or incomplete.
    raise
  return sequence_fn(structure, packed)
690
691
@tf_export("nest.pack_sequence_as")
def pack_sequence_as(structure, flat_sequence, expand_composites=False):
  """Returns a given flattened sequence packed into a given structure.

  Refer to [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest)
  for the definition of a structure.

  If `structure` is an atom, `flat_sequence` must be a single-item list;
  in this case the return value is `flat_sequence[0]`.

  If `structure` is or contains a dict instance, the keys will be sorted to
  pack the flat sequence in deterministic order. This is true also for
  `OrderedDict` instances: their sequence order is ignored, the sorting order of
  keys is used instead. The same convention is followed in `flatten`.
  This correctly repacks dicts and `OrderedDict`s after they have been
  flattened, and also allows flattening an `OrderedDict` and then repacking it
  back using a corresponding plain dict, or vice-versa.
  Dictionaries with non-sortable keys cannot be flattened.

  Examples:

  1. Python dict:

    >>> structure = { "key3": "", "key1": "", "key2": "" }
    >>> flat_sequence = ["value1", "value2", "value3"]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    {'key3': 'value3', 'key1': 'value1', 'key2': 'value2'}

  2. For a nested python tuple:

    >>> structure = (('a','b'), ('c','d','e'), 'f')
    >>> flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0)

  3. For a nested dictionary of dictionaries:

    >>> structure = { "key3": {"c": ('alpha', 'beta'), "a": ('gamma')},
    ...               "key1": {"e": "val1", "d": "val2"} }
    >>> flat_sequence = ['val2', 'val1', 3.0, 1.0, 2.0]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    {'key3': {'c': (1.0, 2.0), 'a': 3.0}, 'key1': {'e': 'val1', 'd': 'val2'}}

  4. Numpy array (considered an atom):

    >>> structure = ['a']
    >>> flat_sequence = [np.array([[1, 2], [3, 4]])]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    [array([[1, 2],
           [3, 4]])]

  5. tf.Tensor (considered an atom):

    >>> structure = ['a']
    >>> flat_sequence = [tf.constant([[1., 2., 3.], [4., 5., 6.]])]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    [<tf.Tensor: shape=(2, 3), dtype=float32,
     numpy= array([[1., 2., 3.], [4., 5., 6.]], dtype=float32)>]

  6. `tf.RaggedTensor`: This is a composite tensor whose representation consists
  of a flattened list of 'values' and a list of 'row_splits' which indicate how
  to chop up the flattened list into different rows. For more details on
  `tf.RaggedTensor`, please visit
  https://www.tensorflow.org/api_docs/python/tf/RaggedTensor.

  With `expand_composites=False`, we treat RaggedTensor as an atom.

    >>> structure = { "foo": tf.ragged.constant([[1, 2], [3]]),
    ...               "bar": tf.constant([[5]]) }
    >>> flat_sequence = [ "one", "two" ]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence,
    ... expand_composites=False)
    {'foo': 'two', 'bar': 'one'}

  With `expand_composites=True`, we expect that the flattened input contains
  the tensors making up the ragged tensor i.e. the values and row_splits
  tensors.

    >>> structure = { "foo": tf.ragged.constant([[1., 2.], [3.]]),
    ...               "bar": tf.constant([[5.]]) }
    >>> tensors = tf.nest.flatten(structure, expand_composites=True)
    >>> print(tensors)
    [<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]],
     dtype=float32)>,
     <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 2., 3.],
     dtype=float32)>,
     <tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 2, 3])>]
    >>> verified_tensors = [tf.debugging.check_numerics(t, 'invalid tensor: ')
    ...                     if t.dtype==tf.float32 else t
    ...                     for t in tensors]
    >>> tf.nest.pack_sequence_as(structure, verified_tensors,
    ...                          expand_composites=True)
    {'foo': <tf.RaggedTensor [[1.0, 2.0], [3.0]]>,
     'bar': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]],
     dtype=float32)>}

  Args:
    structure: Nested structure, whose structure is given by nested lists,
      tuples, and dicts. Note: numpy arrays and strings are considered
      atoms.
    flat_sequence: flat sequence to pack.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    packed: `flat_sequence` converted to have the same recursive structure as
      `structure`.

  Raises:
    ValueError: If `flat_sequence` and `structure` have different
      atom counts.
    TypeError: `structure` is or contains a dict with non-sortable keys.
  """
  return _pack_sequence_as(structure, flat_sequence, expand_composites)
807
808
@tf_export("nest.map_structure")
def map_structure(func, *structure, **kwargs):
  """Creates a new structure by applying `func` to each atom in `structure`.

  Refer to [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest)
  for the definition of a structure.

  For every atom position, `func(x[0], x[1], ...)` is evaluated, where `x[i]`
  is the atom found at that position in `structure[i]`. All entries of
  `structure` must have the same arity, and the results are arranged in the
  layout of `structure[0]`.

  Examples:

  * A single Python dict:

  >>> a = {"hello": 24, "world": 76}
  >>> tf.nest.map_structure(lambda p: p * 2, a)
  {'hello': 48, 'world': 152}

  * Multiple Python dictionaries:

  >>> d1 = {"hello": 24, "world": 76}
  >>> d2 = {"hello": 36, "world": 14}
  >>> tf.nest.map_structure(lambda p1, p2: p1 + p2, d1, d2)
  {'hello': 60, 'world': 90}

  * A single Python list:

  >>> a = [24, 76, "ab"]
  >>> tf.nest.map_structure(lambda p: p * 2, a)
  [48, 152, 'abab']

  * Scalars:

  >>> tf.nest.map_structure(lambda x, y: x + y, 3, 4)
  7

  * Empty structures:

  >>> tf.nest.map_structure(lambda x: x + 1, ())
  ()

  * Check the types of iterables:

  >>> s1 = (((1, 2), 3), 4, (5, 6))
  >>> s1_list = [[[1, 2], 3], 4, [5, 6]]
  >>> tf.nest.map_structure(lambda x, y: None, s1, s1_list)
  Traceback (most recent call last):
  ...
  TypeError: The two structures don't have the same nested structure

  * Type check is set to False:

  >>> s1 = (((1, 2), 3), 4, (5, 6))
  >>> s1_list = [[[1, 2], 3], 4, [5, 6]]
  >>> tf.nest.map_structure(lambda x, y: None, s1, s1_list, check_types=False)
  (((None, None), None), None, (None, None))

  Args:
    func: A callable taking one positional argument per structure.
    *structure: One or more atoms or nested structures.
    **kwargs: Only the following keyword arguments are accepted:
      * `check_types`: If `True` (the default), the types of iterables within
        the structures must be the same (e.g. `map_structure(func, [1], (1,))`
        raises a `TypeError`). Pass `False` to relax this. Namedtuples with
        identical name and fields are always considered to have the same
        shallow structure.
      * `expand_composites`: If `True`, composite tensors such as
        `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
        component tensors. If `False` (the default), they are treated as
        atoms.

  Returns:
    A new structure with the same arity as `structure[0]`, whose atoms are
    `func(x[0], x[1], ...)` with `x[i]` the atom at the corresponding location
    in `structure[i]`. When the structure types differ and `check_types` is
    `False`, the structure types of `structure[0]` are used.

  Raises:
    TypeError: If `func` is not callable or if the structures do not match
      each other by depth tree.
    ValueError: If no structure is provided or if the structures do not match
      each other by type.
    ValueError: If wrong keyword arguments are provided.
  """
  # Reject malformed calls before any structural work is done.
  if not callable(func):
    raise TypeError("func must be callable, got: %s" % func)
  if not structure:
    raise ValueError("Must provide at least one structure")

  check_types = kwargs.pop("check_types", True)
  expand_composites = kwargs.pop("expand_composites", False)
  if kwargs:
    unknown = "`, `".join(kwargs.keys())
    raise ValueError(
        "Only valid keyword arguments are `check_types` and "
        "`expand_composites`, not: `%s`" % unknown)

  # Every other structure must match the layout of `structure[0]`.
  reference, *others = structure
  for other in others:
    assert_same_structure(reference, other, check_types=check_types,
                          expand_composites=expand_composites)

  # One flat list of atoms per structure. zip() lines up the atoms that share
  # a position, so `func` receives one atom from each structure.
  flat_lists = [flatten(s, expand_composites) for s in structure]
  mapped = [func(*atoms) for atoms in zip(*flat_lists)]
  return pack_sequence_as(
      reference, mapped, expand_composites=expand_composites)
919
920
def map_structure_with_paths(func, *structure, **kwargs):
  """Applies `func` to each entry in `structure` and returns a new structure.

  Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in
  `structure[i]` and `path` is the common path to x[i] in the structures. The
  path is formatted as a string by joining `str()` of each path element with
  "/" (e.g. the tuple path `('a', 0)` becomes `"a/0"`). All structures in
  `structure` must have the same arity, and the return value will contain the
  results with the same structure layout. The special kwargs `check_types` and
  `expand_composites` control how the structures are compared; see **kwargs
  below.

  Args:
    func: A callable with the signature func(path, *values, **kwargs) that is
      evaluated on the leaves of the structure.
    *structure: A variable number of compatible structures to process. The
      other structures are flattened up to the atoms of `structure[0]`, and
      the result has the layout of `structure[0]`.
    **kwargs: Optional kwargs to be passed through to func. Special kwarg
      `check_types` is not passed to func, but instead determines whether the
      types of iterables within the structures have to be same (e.g.,
      `map_structure(func, [1], (1,))` raises a `TypeError` exception). By
      default, the types must match. To allow iteration over structures of
      different types (but common arity), set this kwarg to `False`. Special
      kwarg `expand_composites` is not passed to func either; if `True`,
      composite tensors are expanded into their component tensors.

  Returns:
    A structure of the same form as the input structures whose leaves are the
    result of evaluating func on corresponding leaves of the input structures.

  Raises:
    TypeError: If `func` is not callable or if the structures do not match
      each other by depth tree.
    TypeError: If `check_types` is not `False` and the two structures differ in
      the type of sequence in any of their substructures.
    ValueError: If no structures are provided.
  """
  # Validate up front so the documented errors are raised. Without the first
  # check, an empty `structure` would fail with IndexError on `structure[0]`.
  # Without the second, a non-callable `func` would go unnoticed whenever the
  # structures have no leaves.
  if not callable(func):
    raise TypeError("func must be callable, got: %s" % func)
  if not structure:
    raise ValueError("Must provide at least one structure")

  def wrapper_func(tuple_path, *inputs, **kwargs):
    # Render the tuple path as a "/"-separated string before calling `func`.
    string_path = "/".join(str(s) for s in tuple_path)
    return func(string_path, *inputs, **kwargs)

  return map_structure_with_tuple_paths_up_to(structure[0],
                                              wrapper_func,
                                              *structure,
                                              **kwargs)
961
962
def map_structure_with_tuple_paths(func, *structure, **kwargs):
  """Applies `func` to each entry in `structure` and returns a new structure.

  Applies `func(tuple_path, x[0], x[1], ..., **kwargs)` where `x[i]` is an entry
  in `structure[i]` and `tuple_path` is a tuple of indices and/or dictionary
  keys (as returned by `nest.yield_flat_paths`), which uniquely specifies the
  common path to x[i] in the structures. All structures in `structure` must have
  the same arity, and the return value will contain the results in the same
  structure. The special kwargs `check_types` and `expand_composites` control
  how the structures are compared; see **kwargs below.

  Args:
    func: A callable with the signature `func(tuple_path, *values, **kwargs)`
      that is evaluated on the leaves of the structure.
    *structure: A variable number of compatible structures to process. The
      other structures are flattened up to the atoms of `structure[0]`, and
      the result has the layout of `structure[0]`.
    **kwargs: Optional kwargs to be passed through to func. Special kwarg
      `check_types` is not passed to func, but instead determines whether the
      types of iterables within the structures have to be same (e.g.
      `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow
      this set this argument to `False`. Special kwarg `expand_composites` is
      not passed to func either; if `True`, composite tensors are expanded
      into their component tensors.

  Returns:
    A structure of the same form as the input structures whose leaves are the
    result of evaluating func on corresponding leaves of the input structures.

  Raises:
    TypeError: If `func` is not callable or if the structures do not match
      each other by depth tree.
    TypeError: If `check_types` is not `False` and the two structures differ in
      the type of sequence in any of their substructures.
    ValueError: If no structures are provided.
  """
  # Validate up front so the documented errors are raised. Without the first
  # check, an empty `structure` would fail with IndexError on `structure[0]`.
  # Without the second, a non-callable `func` would go unnoticed whenever the
  # structures have no leaves.
  if not callable(func):
    raise TypeError("func must be callable, got: %s" % func)
  if not structure:
    raise ValueError("Must provide at least one structure")

  return map_structure_with_tuple_paths_up_to(structure[0],
                                              func,
                                              *structure,
                                              **kwargs)
1000
1001
1002def _yield_flat_up_to(shallow_tree, input_tree, is_nested_fn, path=()):
1003  """Yields (path, value) pairs of input_tree flattened up to shallow_tree.
1004
1005  Args:
1006    shallow_tree: Nested structure. Traverse no further than its leaf nodes.
1007    input_tree: Nested structure. Return the paths and values from this tree.
1008      Must have the same upper structure as shallow_tree.
1009    is_nested_fn: Function used to test if a value should be treated as a
1010      nested structure.
1011    path: Tuple. Optional argument, only used when recursing. The path from the
1012      root of the original shallow_tree, down to the root of the shallow_tree
1013      arg of this recursive call.
1014
1015  Yields:
1016    Pairs of (path, value), where path the tuple path of a leaf node in
1017    shallow_tree, and value is the value of the corresponding node in
1018    input_tree.
1019  """
1020  if not is_nested_fn(shallow_tree):
1021    yield (path, input_tree)
1022  else:
1023    input_tree = dict(_yield_sorted_items(input_tree))
1024    for shallow_key, shallow_subtree in _yield_sorted_items(shallow_tree):
1025      subpath = path + (shallow_key,)
1026      input_subtree = input_tree[shallow_key]
1027      for leaf_path, leaf_value in _yield_flat_up_to(
1028          shallow_subtree, input_subtree, is_nested_fn, path=subpath):
1029        yield (leaf_path, leaf_value)
1030
1031
def assert_shallow_structure(shallow_tree,
                             input_tree,
                             check_types=True,
                             expand_composites=False):
  """Asserts that `shallow_tree` is a shallow structure of `input_tree`.

  That is, this function tests if the `input_tree` structure can be created from
  the `shallow_tree` structure by replacing its leaf nodes with deeper
  tree structures.

  Examples:

  The following code will raise an exception:
  ```python
    shallow_tree = {"a": "A", "b": "B"}
    input_tree = {"a": 1, "c": 2}
    assert_shallow_structure(shallow_tree, input_tree)
  ```

  The following code will raise an exception:
  ```python
    shallow_tree = ["a", "b"]
    input_tree = ["c", ["d", "e"], "f"]
    assert_shallow_structure(shallow_tree, input_tree)
  ```

  A nested node of `shallow_tree` that is neither a composite tensor nor a
  `TypeSpec` must have exactly as many entries as the matching node of
  `input_tree`. This also applies to mappings: a shallow mapping cannot name
  only a subset of the input mapping's keys.

  Args:
    shallow_tree: an arbitrarily nested structure.
    input_tree: an arbitrarily nested structure.
    check_types: if `True` (default) the sequence types of `shallow_tree` and
      `input_tree` have to be the same. Note that even with check_types==True,
      this function will consider two different namedtuple classes with the same
      name and _fields attribute to be the same class. List subclasses are
      considered the same as lists, and any two mappings are considered
      compatible.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`. Only raised if `check_types` is `True`.
    TypeError: If one of a pair of corresponding nodes is a composite tensor
      and the other is neither a composite tensor nor a `TypeSpec`.
    TypeError: If a node of `shallow_tree` is a `TypeSpec` but the matching
      node of `input_tree` is not.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
    ValueError: If a mapping in `shallow_tree` has keys that are absent from
      the matching mapping in `input_tree`.
    ValueError: If matching composite tensors / `TypeSpec`s have no common
      supertype.
  """
  is_nested_fn = _is_nested_or_composite if expand_composites else _is_nested
  if is_nested_fn(shallow_tree):
    if not is_nested_fn(input_tree):
      raise TypeError(
          "If shallow structure is a sequence, input must also be a sequence. "
          "Input has type: %s." % type(input_tree))

    # For proxied structures, compare against the type of the wrapped object.
    if isinstance(shallow_tree, _wrapt.ObjectProxy):
      shallow_type = type(shallow_tree.__wrapped__)
    else:
      shallow_type = type(shallow_tree)

    if check_types and not isinstance(input_tree, shallow_type):
      # Duck-typing means that nest should be fine with two different
      # namedtuples with identical name and fields.
      shallow_is_namedtuple = is_namedtuple(shallow_tree, False)
      input_is_namedtuple = is_namedtuple(input_tree, False)
      if shallow_is_namedtuple and input_is_namedtuple:
        if not same_namedtuples(shallow_tree, input_tree):
          raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
              input_type=type(input_tree),
              shallow_type=type(shallow_tree)))

      elif isinstance(shallow_tree, list) and isinstance(input_tree, list):
        # List subclasses are considered the same,
        # e.g. python list vs. _ListWrapper.
        pass

      elif ((_is_composite_tensor(shallow_tree) or
             _is_composite_tensor(input_tree)) and
            (_is_type_spec(shallow_tree) or _is_type_spec(input_tree))):
        pass  # Compatibility will be checked below.

      elif not (isinstance(shallow_tree, _collections_abc.Mapping) and
                isinstance(input_tree, _collections_abc.Mapping)):
        raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
            input_type=type(input_tree),
            shallow_type=type(shallow_tree)))

    if _is_composite_tensor(shallow_tree) or _is_composite_tensor(input_tree):
      # Composite tensors may only be paired with composite tensors or
      # TypeSpecs; this is enforced regardless of `check_types`.
      if not (
          (_is_composite_tensor(input_tree) or _is_type_spec(input_tree)) and
          (_is_composite_tensor(shallow_tree) or _is_type_spec(shallow_tree))):
        raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
            input_type=type(input_tree),
            shallow_type=type(shallow_tree)))
      # pylint: disable=protected-access
      type_spec_1 = (shallow_tree if _is_type_spec(shallow_tree) else
                     shallow_tree._type_spec)._without_tensor_names()
      type_spec_2 = (input_tree if _is_type_spec(input_tree) else
                     input_tree._type_spec)._without_tensor_names()
      # pylint: enable=protected-access
      result = type_spec_1.most_specific_common_supertype([type_spec_2])
      if result is None:
        raise ValueError("Incompatible CompositeTensor TypeSpecs: %s vs. %s" %
                         (type_spec_1, type_spec_2))

    elif _is_type_spec(shallow_tree):
      if not _is_type_spec(input_tree):
        raise TypeError("If shallow structure is a TypeSpec, input must also "
                        "be a TypeSpec.  Input has type: %s."
                        % type(input_tree))
    else:
      # Any length difference (input shorter or longer) is rejected here. A
      # separate "input smaller than shallow" branch after this check could
      # never run, so only this one error is raised. Mappings are also checked
      # here, which is why a shallow mapping cannot be a subset of the input.
      if len(input_tree) != len(shallow_tree):
        raise ValueError(
            _STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
                input_length=len(input_tree), shallow_length=len(shallow_tree)))

    if isinstance(shallow_tree, _collections_abc.Mapping):
      # Equal lengths do not guarantee equal keys, so check them explicitly.
      absent_keys = set(shallow_tree) - set(input_tree)
      if absent_keys:
        raise ValueError(_SHALLOW_TREE_HAS_INVALID_KEYS
                         .format(sorted(absent_keys)))

    for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
                                            _yield_value(input_tree)):
      assert_shallow_structure(shallow_branch, input_branch,
                               check_types=check_types,
                               expand_composites=expand_composites)
1158
1159
@tf_export("__internal__.nest.flatten_up_to", v1=[])
def flatten_up_to(shallow_tree, input_tree, check_types=True,
                  expand_composites=False):
  """Flattens `input_tree` up to `shallow_tree`.

  Refer to [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest)
  for the definition of a structure.

  Flattening stops at the atoms of `shallow_tree`. Anything in `input_tree`
  that is nested deeper than that is kept intact in the output list.

  If `shallow_tree` and `input_tree` are atoms, this returns a
  single-item list: `[input_tree]`.

  Use Case:

  To flatten a structure only partially, keeping some of its nesting, pass a
  shallow structure `shallow_tree` describing how deep to flatten.

  The input, `input_tree`, can be thought of as having the same structure layout
  as `shallow_tree`, but with leaf nodes that are themselves tree structures.

  Examples:

  ```python
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]

  flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)

  # Output is:
  # [[2, 2], [3, 3], [4, 9], [5, 5]]
  # [True, True, False, True]
  ```

  ```python
  input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
  shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]

  input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree)
  input_tree_flattened = flatten(input_tree)

  # Output is:
  # [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
  # ['a', 1, 'b', 2, 'c', 3, 'd', 4]
  ```

  Edge Cases for atoms:

  ```python
  flatten_up_to(0, 0)  # Output: [0]
  flatten_up_to(0, [0, 1, 2])  # Output: [[0, 1, 2]]
  flatten_up_to([0, 1, 2], 0)  # Output: TypeError
  flatten_up_to([0, 1, 2], [0, 1, 2])  # Output: [0, 1, 2]
  ```

  Args:
    shallow_tree: a possibly pruned structure of input_tree.
    input_tree: an atom or a nested structure.
      Note, numpy arrays are considered atoms.
    check_types: bool. If True, check that each node in shallow_tree has the
      same type as the corresponding node in input_tree.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    A Python list holding the values of `input_tree` found at the leaves of
    `shallow_tree`, in flattening order.

  Raises:
    TypeError: If `shallow_tree` is a nested structure but `input_tree` is not.
    TypeError: If the structure types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the structure lengths of `shallow_tree` are different from
      `input_tree`.
  """
  if expand_composites:
    nested_predicate = _is_nested_or_composite
  else:
    nested_predicate = _is_nested

  assert_shallow_structure(shallow_tree,
                           input_tree,
                           check_types=check_types,
                           expand_composites=expand_composites)

  # `_yield_flat_up_to` produces (path, value) pairs; only values are wanted.
  path_value_pairs = _yield_flat_up_to(shallow_tree, input_tree,
                                       nested_predicate)
  return [value for _path, value in path_value_pairs]
1248
1249
def flatten_with_tuple_paths_up_to(shallow_tree,
                                   input_tree,
                                   check_types=True,
                                   expand_composites=False):
  """Flattens `input_tree` up to `shallow_tree`, pairing values with paths.

  Any further depth in structure in `input_tree` is retained as structures in
  the partially flattened output.

  Returns a list of `(path, value)` pairs. `value` is the node of `input_tree`
  found at a leaf of `shallow_tree`, and `path` is the tuple path to that leaf.

  If `shallow_tree` is an atom, this returns a single-item list:
  `[((), input_tree)]`.

  Use Case:

  Sometimes we may wish to partially flatten a nested sequence, retaining some
  of the nested structure. We achieve this by specifying a shallow structure,
  `shallow_tree`, we wish to flatten up to.

  The input, `input_tree`, can be thought of as having the same structure layout
  as `shallow_tree`, but with leaf nodes that are themselves tree structures.

  Examples:

  ```python
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]

  flattened_input_tree = flatten_with_tuple_paths_up_to(shallow_tree,
                                                        input_tree)
  flattened_shallow_tree = flatten_with_tuple_paths_up_to(shallow_tree,
                                                          shallow_tree)

  # Output is:
  # [((0, 0), [2, 2]),
  #  ((0, 1), [3, 3]),
  #  ((1, 0), [4, 9]),
  #  ((1, 1), [5, 5])]
  #
  # [((0, 0), True),
  #  ((0, 1), True),
  #  ((1, 0), False),
  #  ((1, 1), True)]
  ```

  ```python
  input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
  shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]

  input_tree_flattened_as_shallow_tree = flatten_with_tuple_paths_up_to(
      shallow_tree, input_tree)
  input_tree_flattened = flatten(input_tree)

  # Output is:
  # [((0, 0), ('a', 1)),
  #  ((0, 1, 0), ('b', 2)),
  #  ((0, 1, 1, 0), ('c', 3)),
  #  ((0, 1, 1, 1, 0), ('d', 4))]
  # ['a', 1, 'b', 2, 'c', 3, 'd', 4]
  ```

  Non-Sequence Edge Cases:

  ```python
  flatten_with_tuple_paths_up_to(0, 0)  # Output: [((), 0)]

  flatten_with_tuple_paths_up_to(0, [0, 1, 2])  # Output: [((), [0, 1, 2])]

  flatten_with_tuple_paths_up_to([0, 1, 2], 0)  # Output: TypeError

  flatten_with_tuple_paths_up_to([0, 1, 2], [0, 1, 2])
  # Output: [((0,), 0), ((1,), 1), ((2,), 2)]
  ```

  Args:
    shallow_tree: a possibly pruned structure of input_tree.
    input_tree: an atom or a nested structure.
      Note, numpy arrays are considered atoms.
    check_types: bool. If True, check that each node in shallow_tree has the
      same type as the corresponding node in input_tree.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    A Python list of `(path, value)` tuples: the partially flattened version
    of `input_tree` according to the structure of `shallow_tree`, with each
    value paired with its tuple path.

  Raises:
    TypeError: If `shallow_tree` is a nested structure but `input_tree` is not.
    TypeError: If the structure types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the structure lengths of `shallow_tree` are different from
      `input_tree`.
  """
  is_nested_fn = _is_nested_or_composite if expand_composites else _is_nested
  assert_shallow_structure(shallow_tree,
                           input_tree,
                           check_types=check_types,
                           expand_composites=expand_composites)
  return list(_yield_flat_up_to(shallow_tree, input_tree, is_nested_fn))
1352
1353
@tf_export("__internal__.nest.map_structure_up_to", v1=[])
def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
  """Applies a function or op to a number of partially flattened inputs.

  The `inputs` are flattened up to `shallow_tree` before being mapped.

  Use Case:

  Sometimes we wish to apply a function to a partially flattened
  structure (for example when the function itself takes structure inputs). We
  achieve this by specifying a shallow structure, `shallow_tree` we wish to
  flatten up to.

  The `inputs`, can be thought of as having the same structure layout as
  `shallow_tree`, but with leaf nodes that are themselves tree structures.

  This function therefore will return something with the same base structure as
  `shallow_tree`.

  Examples:

  ```python
  shallow_tree = [None, None]
  inp_val = [[1, 2], 3]
  out = map_structure_up_to(shallow_tree, lambda x: 2 * x, inp_val)

  # Output is: [[1, 2, 1, 2], 6]
  ```

  ```python
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  op_tuple = collections.namedtuple("op_tuple", "add, mul")
  inp_val = ab_tuple(a=2, b=3)
  inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
  out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul,
                            inp_val, inp_ops)

  # Output is: ab_tuple(a=6, b=15)
  ```

  ```python
  data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
  name_list = ['evens', ['odds', 'primes']]
  out = map_structure_up_to(
      name_list,
      lambda name, sec: "first_{}_{}".format(len(sec), name),
      name_list, data_list)

  # Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']]
  ```

  Extra keyword arguments are forwarded to `func`:

  ```python
  out = map_structure_up_to([None, None], lambda x, scale: x * scale,
                            [1, 2], scale=10)

  # Output is: [10, 20]
  ```

  Args:
    shallow_tree: a shallow structure, common to all the inputs.
    func: callable which will be applied to each input individually.
    *inputs: structures that are compatible with shallow_tree. The function
        `func` is applied to corresponding structures due to partial flattening
        of each input, so the function must support arity of `len(inputs)`.
    **kwargs: kwargs to feed to func(). Special kwarg
      `check_types` is not passed to func, but instead determines whether the
      types of iterables within the structures have to be same (e.g.
      `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow
      this set this argument to `False`. Special kwarg `expand_composites` is
      not passed to func either; if `True`, composite tensors are expanded
      into their component tensors.

  Raises:
    TypeError: If `shallow_tree` is a nested structure but `input_tree` is not.
    TypeError: If the structure types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the structure lengths of `shallow_tree` are different from
      `input_tree`.

  Returns:
    result of repeatedly applying `func`, with the same structure layout as
    `shallow_tree`.
  """
  return map_structure_with_tuple_paths_up_to(
      shallow_tree,
      # `map_structure_with_tuple_paths_up_to` calls this as
      # `wrapper(path, *values, **kwargs)`. Drop the path, but forward the
      # remaining keyword arguments to `func` as documented above.
      lambda _, *values, **func_kwargs: func(*values, **func_kwargs),
      *inputs,
      **kwargs)
1433
1434
def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs):
  """Applies a function or op to a number of partially flattened inputs.

  Like map_structure_up_to(), except that the 'func' argument takes a path
  tuple as its first argument, followed by the corresponding values from
  *inputs.

  Example:

  ```python
  lowercase = {'a': 'a', 'b': ('b0', 'b1')}
  uppercase = {'a': 'A', 'b': ('B0', 'B1')}

  def print_path_and_values(path, *values):
    print("path: {}, values: {}".format(path, values))

  shallow_tree = {'a': None, 'b': None}
  map_structure_with_tuple_paths_up_to(shallow_tree,
                                       print_path_and_values,
                                       lowercase,
                                       uppercase)
  path: ('a',), values: ('a', 'A')
  path: ('b',), values: (('b0', 'b1'), ('B0', 'B1'))

  shallow_tree = {'a': None, 'b': [None, None]}
  map_structure_with_tuple_paths_up_to(shallow_tree,
                                       print_path_and_values,
                                       lowercase,
                                       uppercase,
                                       check_types=False)
  path: ('a',), values: ('a', 'A')
  path: ('b', 0), values: ('b0', 'B0')
  path: ('b', 1), values: ('b1', 'B1')
  ```

  Note that a mapping in `shallow_tree` must have the same number of keys as
  the corresponding mapping in every input. A shallow mapping that names only
  a subset of an input's keys raises a `ValueError`.

  Args:
    shallow_tree: a shallow structure, common to all the inputs.
    func: callable that takes args (path, inputs_0_value, ... , inputs_N_value),
      where path is a tuple path to an atom in shallow_tree, and
      inputs_i_value is the corresponding value from inputs[i].
    *inputs: structures that are all structurally compatible with
        shallow_tree.
    **kwargs: kwargs to feed to func(). Special kwarg
      `check_types` is not passed to func, but instead determines whether the
      types of iterables within the structures have to be same (e.g.
      `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow
      this set this argument to `False`. Special kwarg `expand_composites` is
      not passed to func either; if `True`, composite tensors are expanded
      into their component tensors.

  Raises:
    TypeError: If `shallow_tree` is a nested structure but one of `*inputs` is
      not.
    TypeError: If the structure types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the structure lengths of `shallow_tree` are different from
      `input_tree`.
    ValueError: If no `inputs` are provided.

  Returns:
    Result of repeatedly applying `func`. Has the same structure layout as
    `shallow_tree`.
  """
  if not inputs:
    raise ValueError("Cannot map over no sequences")

  check_types = kwargs.pop("check_types", True)
  expand_composites = kwargs.pop("expand_composites", False)
  is_nested_fn = _is_nested_or_composite if expand_composites else _is_nested

  # Validate every input before `func` is called on anything.
  for input_tree in inputs:
    assert_shallow_structure(
        shallow_tree,
        input_tree,
        check_types=check_types,
        expand_composites=expand_composites)

  # The inputs were validated above, so flatten them directly with
  # `_yield_flat_up_to`. Going through `flatten_up_to` would run
  # `assert_shallow_structure` on every input a second time.
  flat_values = [
      [v for _, v in _yield_flat_up_to(shallow_tree, tree, is_nested_fn)]
      for tree in inputs
  ]
  # Paths depend only on `shallow_tree`, so the first input is enough.
  flat_paths = [
      path
      for path, _ in _yield_flat_up_to(shallow_tree, inputs[0], is_nested_fn)
  ]
  # Apply `func` to each (path, value_0, ..., value_N) group, then repack the
  # results into the layout of `shallow_tree`.
  results = [
      func(*args, **kwargs) for args in zip(flat_paths, *flat_values)
  ]
  return pack_sequence_as(structure=shallow_tree, flat_sequence=results,
                          expand_composites=expand_composites)
1533
1534
@tf_export("__internal__.nest.get_traverse_shallow_structure", v1=[])
def get_traverse_shallow_structure(traverse_fn, structure,
                                   expand_composites=False):
  """Generates a shallow structure from a `traverse_fn` and `structure`.

  `traverse_fn` must accept any possible subtree of `structure` and return
  a depth=1 structure containing `True` or `False` values, describing which
  of the top-level subtrees may be traversed.  It may also
  return scalar `True` or `False` "traversal is OK / not OK for all subtrees."

  Examples are available in the unit tests (nest_test.py).

  Args:
    traverse_fn: Function taking a substructure and returning either a scalar
      `bool` (whether to traverse that substructure or not) or a depth=1
      shallow structure of the same type, describing which parts of the
      substructure to traverse.
    structure: The structure to traverse.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors. The setting is forwarded to every recursive call, so
      it applies at all depths of `structure`.

  Returns:
    A shallow structure containing python bools, which can be passed to
    `map_structure_up_to` and `flatten_up_to`.

  Raises:
    TypeError: if `traverse_fn` returns a nested structure for an atom input,
      or a structure with depth higher than 1 for a nested structure input,
      or if any leaf values in the returned structure or scalar are not type
      `bool`.
  """
  is_nested_fn = _is_nested_or_composite if expand_composites else _is_nested
  to_traverse = traverse_fn(structure)

  # Atoms: `traverse_fn` must answer with a plain bool, returned unchanged.
  if not is_nested_fn(structure):
    if not isinstance(to_traverse, bool):
      raise TypeError("traverse_fn returned structure: %s for non-structure: %s"
                      % (to_traverse, structure))
    return to_traverse

  level_traverse = []
  if isinstance(to_traverse, bool):
    if not to_traverse:
      # Do not traverse this substructure at all.  Exit early.
      return False
    # Traverse the entire substructure: recurse into every child.
    for branch in _yield_value(structure):
      level_traverse.append(
          get_traverse_shallow_structure(
              traverse_fn, branch, expand_composites=expand_composites))
  elif not is_nested_fn(to_traverse):
    raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s"
                    % (to_traverse, structure))
  else:
    # Traverse some subset of this substructure. `to_traverse` must be a
    # shallow structure of `structure` whose top-level values are bools.
    assert_shallow_structure(to_traverse, structure,
                             expand_composites=expand_composites)
    for t, branch in zip(_yield_value(to_traverse),
                         _yield_value(structure)):
      if not isinstance(t, bool):
        raise TypeError(
            "traverse_fn didn't return a depth=1 structure of bools.  saw: %s "
            " for structure: %s" % (to_traverse, structure))
      if t:
        # Forward `expand_composites` so deeper levels use the same notion of
        # "nested" as this level, matching the full-traversal branch above.
        level_traverse.append(
            get_traverse_shallow_structure(
                traverse_fn, branch, expand_composites=expand_composites))
      else:
        level_traverse.append(False)
  # Package the per-child results in the shape of `structure`.
  return _sequence_like(structure, level_traverse)
1604
1605
@tf_export("__internal__.nest.yield_flat_paths", v1=[])
def yield_flat_paths(nest, expand_composites=False):
  """Yields the path leading to each atom of a nested structure.

  Refer to [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest)
  for the definition of a structure.

  A path is a sequence of the indices and/or dictionary keys that lead from
  the root of `nest` to one atom. Its elements can be converted with `str`.

  Paths are produced in the same order as the atoms returned by
  `nest.flatten`. This is handy for naming Tensors such that the TF scope
  structure matches the structure of `nest`.

  For example, given `value = Foo(a=3, b=Bar(c=23, d=42))`:

  ```shell
  nest.flatten(value)
  [3, 23, 42]
  list(nest.yield_flat_paths(value))
  [('a',), ('b', 'c'), ('b', 'd')]
  ```

  ```shell
  list(nest.yield_flat_paths({'a': [3]}))
  [('a', 0)]
  list(nest.yield_flat_paths({'a': 3}))
  [('a',)]
  ```

  Args:
    nest: the value to produce a flattened paths list for.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Yields:
    Tuples containing index or key values which form the path to a specific
    leaf value in the nested structure.
  """
  if expand_composites:
    nested_predicate = _is_nested_or_composite
  else:
    nested_predicate = _is_nested
  # Walking `nest` up to itself visits every atom; only the paths are kept.
  path_value_pairs = _yield_flat_up_to(nest, nest, nested_predicate)
  yield from (path for path, _unused_value in path_value_pairs)
1649
1650
def flatten_with_joined_string_paths(structure, separator="/",
                                     expand_composites=False):
  """Pairs every atom of `structure` with a string describing its location.

  Each path from `nest.yield_flat_paths` is rendered by converting its
  elements with `str` and joining them with `separator`. The pairs come out
  in the same order as the atoms of `nest.flatten`, so a nested structure can
  be flattened without losing track of where each atom was located.

  Args:
    structure: the nested structure to flatten.
    separator: string placed between the levels of each path; defaults to
      '/'.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    A list of (string, atom) tuples.
  """
  paths = yield_flat_paths(structure, expand_composites=expand_composites)
  # Lazily render each path; `zip` below drives both sequences in lockstep.
  joined_paths = (separator.join(map(str, path)) for path in paths)
  atoms = flatten(structure, expand_composites=expand_composites)
  return list(zip(joined_paths, atoms))
1678
1679
def flatten_with_tuple_paths(structure, expand_composites=False):
  """Pairs every atom of `structure` with the tuple path that locates it.

  The pairs are ordered exactly like the atoms of `nest.flatten`, so a nested
  structure can be flattened while remembering where each atom came from.
  See `nest.yield_flat_paths` for how tuple paths are formed.

  Args:
    structure: the nested structure to flatten.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    A list of `(tuple_path, atom)` tuples. Each `tuple_path` is a tuple
    of indices and/or dictionary keys that uniquely specify the path to
    `atom` within `structure`.
  """
  paths = yield_flat_paths(structure, expand_composites=expand_composites)
  atoms = flatten(structure, expand_composites=expand_composites)
  return [(path, atom) for path, atom in zip(paths, atoms)]
1702
1703
@tf_export("__internal__.nest.list_to_tuple", v1=[])
def list_to_tuple(structure):
  """Rebuilds `structure` with every `list` replaced by a `tuple`.

  tf.data's fork of nest treats lists as atoms, whereas tf.nest recurses into
  them as structures. Keras follows the tf.nest convention, so it must
  replace lists with tuples at every depth before handing structures to
  Dataset.from_generator.

  Args:
    structure: A nested structure to be remapped.

  Returns:
    `structure` repacked so that each list becomes a tuple of the same
    elements; any other container is rebuilt via `_sequence_like`.
  """
  def _rebuild(instance, args):
    # Non-list containers are rebuilt normally; lists turn into tuples.
    if not isinstance(instance, list):
      return _sequence_like(instance, args)
    return tuple(args)

  flat_atoms = flatten(structure)
  return _pack_sequence_as(structure, flat_atoms, False, sequence_fn=_rebuild)
1726
1727
# Import-time registration of collection base classes (plus wrapt's
# ObjectProxy) with `_pywrap_utils`, each under a string name.
# NOTE(review): presumably this is how the native nest helpers learn to
# recognize Mapping / Sequence / MappingView values (and proxied objects) as
# the collection kinds described in the module docstring -- confirm against
# the `_pywrap_utils.RegisterType` implementation.
_pywrap_utils.RegisterType("Mapping", _collections_abc.Mapping)
_pywrap_utils.RegisterType("MutableMapping", _collections_abc.MutableMapping)
_pywrap_utils.RegisterType("Sequence", _collections_abc.Sequence)
_pywrap_utils.RegisterType("MappingView", _collections_abc.MappingView)
_pywrap_utils.RegisterType("ObjectProxy", _wrapt.ObjectProxy)
1733