# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions that work with structures.

A structure is either:

* one of the recognized Python collections, holding _nested structures_;
* a value of any other type, typically a TensorFlow data type like Tensor,
  Variable, or of compatible types such as int, float, ndarray, etc. These are
  commonly referred to as _atoms_ of the structure.

A structure of type `T` is a structure whose atomic items are of type `T`.
For example, a structure of `tf.Tensor` only contains `tf.Tensor` as its atoms.

Historically a _nested structure_ was called a _nested sequence_ in TensorFlow.
A nested structure is sometimes called a _nest_ or a _tree_, but the formal
name _nested structure_ is preferred.

Refer to
[Nesting Data Structures](https://en.wikipedia.org/wiki/Nesting_(computing)#Data_structures).

The following collection types are recognized by `tf.nest` as nested
structures:

* `collections.abc.Sequence` (except `str` and `bytes`).
  This includes `list`, `tuple`, and `namedtuple`.
* `collections.abc.Mapping` (with sortable keys).
  This includes `dict` and `collections.OrderedDict`.
* `collections.abc.MappingView` (with sortable keys).
* [`attr.s` classes](https://www.attrs.org/).

Any other values are considered **atoms**. Not all collection types are
considered nested structures. For example, the following types are
considered atoms:

* `set`; `{"a", "b"}` is an atom, while `["a", "b"]` is a nested structure.
* [`dataclass` classes](https://docs.python.org/library/dataclasses.html)
* `tf.Tensor`
* `numpy.ndarray`

`tf.nest.is_nested` checks whether an object is a nested structure or an atom.
For example:

  >>> tf.nest.is_nested("1234")
  False
  >>> tf.nest.is_nested([1, 3, [4, 5]])
  True
  >>> tf.nest.is_nested(((7, 8), (5, 6)))
  True
  >>> tf.nest.is_nested([])
  True
  >>> tf.nest.is_nested({"a": 1, "b": 2})
  True
  >>> tf.nest.is_nested({"a": 1, "b": 2}.keys())
  True
  >>> tf.nest.is_nested({"a": 1, "b": 2}.values())
  True
  >>> tf.nest.is_nested({"a": 1, "b": 2}.items())
  True
  >>> tf.nest.is_nested(set([1, 2]))
  False
  >>> ones = tf.ones([2, 3])
  >>> tf.nest.is_nested(ones)
  False

Note: A proper structure shall form a tree. The user shall ensure there are no
cyclic references within the items in the structure,
i.e., no references in the structure of the input of these functions
should be recursive. The behavior is undefined if there is a cycle.

"""

import collections as _collections

import six as _six
import wrapt as _wrapt

from tensorflow.python.platform import tf_logging
from tensorflow.python.util import _pywrap_nest
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util.compat import collections_abc as _collections_abc
from tensorflow.python.util.tf_export import tf_export


_SHALLOW_TREE_HAS_INVALID_KEYS = (
    "The shallow_tree's keys are not a subset of the input_tree's keys. The "
    "shallow_tree has the following keys that are not in the input_tree: {}.")

_STRUCTURES_HAVE_MISMATCHING_TYPES = (
    "The two structures don't have the same sequence type. Input structure has "
    "type {input_type}, while shallow structure has type {shallow_type}.")

_STRUCTURES_HAVE_MISMATCHING_LENGTHS = (
    "The two structures don't have the same sequence length. Input "
    "structure has length {input_length}, while shallow structure has length "
    "{shallow_length}."
)

_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE = (
    "The input_tree has fewer items than the shallow_tree. Input structure "
    "has length {input_size}, while shallow structure has length "
    "{shallow_size}.")

_IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ = (
    "If shallow structure is a sequence, input must also be a sequence. "
    "Input has type: {}.")


def _get_attrs_items(obj):
  """Returns a list of (name, value) pairs from an attrs instance.

  The pairs follow the order of the class's `__attrs_attrs__` tuple. They are
  deliberately NOT sorted by name: `_sequence_like` rebuilds attrs instances
  positionally with `instance_type(*args)`, so the order produced here must
  match the order the values are passed back in.

  Args:
    obj: an object.

  Returns:
    A list of (attr_name, attr_value) pairs, in `__attrs_attrs__` order.
  """
  attrs = getattr(obj.__class__, "__attrs_attrs__")
  return [(a.name, getattr(obj, a.name)) for a in attrs]


def _sorted(dict_):
  """Returns a sorted list of the dict keys, with error if keys not sortable."""
  try:
    return sorted(dict_.keys())
  except TypeError as err:
    # Keep the original comparison error as the explicit cause.
    raise TypeError("nest only supports dicts with sortable keys.") from err


# TODO(b/225045380): Move to a "leaf" library to use in trace_type.
def is_namedtuple(instance, strict=False):
  """Returns True iff `instance` is a `namedtuple`.

  Args:
    instance: An instance of a Python object.
    strict: If True, `instance` is considered to be a `namedtuple` only if
        it is a "plain" namedtuple. For instance, a class inheriting
        from a `namedtuple` will be considered to be a `namedtuple`
        iff `strict=False`.

  Returns:
    True if `instance` is a `namedtuple`.
  """
  return _pywrap_utils.IsNamedtuple(instance, strict)

_is_namedtuple = is_namedtuple  # This function was private up to TF2.5.

_is_mapping_view = _pywrap_utils.IsMappingView
_is_attrs = _pywrap_utils.IsAttrs
_is_composite_tensor = _pywrap_utils.IsCompositeTensor
_is_type_spec = _pywrap_utils.IsTypeSpec
_is_mutable_mapping = _pywrap_utils.IsMutableMapping
_is_mapping = _pywrap_utils.IsMapping


# TODO(b/225045380): Move to a "leaf" library to use in trace_type.
@tf_export("__internal__.nest.is_attrs", v1=[])
def is_attrs(obj):
  """Returns True if its input is an instance of an `attr.s` decorated class."""
  return _is_attrs(obj)


@tf_export("__internal__.nest.is_mapping", v1=[])
def is_mapping(obj):
  """Returns True if its input is a `collections.abc.Mapping`."""
  return _is_mapping(obj)


@tf_export("__internal__.nest.sequence_like", v1=[])
def _sequence_like(instance, args):
  """Converts the sequence `args` to the same type as `instance`.

  Args:
    instance: an instance of `tuple`, `list`, `namedtuple`, `dict`,
        `collections.OrderedDict`, or `composite_tensor.Composite_Tensor`
        or `type_spec.TypeSpec`.
    args: items to be converted to the `instance` type.

  Returns:
    `args` with the type of `instance`.

  Raises:
    TypeError: If `instance` is a mapping with non-sortable keys, or a
      non-mutable Mapping whose type cannot be constructed from a single
      iterable of key-value pairs.
  """
  if _is_mutable_mapping(instance):
    # Pack dictionaries in a deterministic order by sorting the keys.
    # Notice this means that we ignore the original order of `OrderedDict`
    # instances. This is intentional, to avoid potential bugs caused by mixing
    # ordered and plain dicts (e.g., flattening a dict but using a
    # corresponding `OrderedDict` to pack it back).
    result = dict(zip(_sorted(instance), args))
    instance_type = type(instance)
    if instance_type == _collections.defaultdict:
      # Carry over default_factory, which a bare `defaultdict()` would drop.
      d = _collections.defaultdict(instance.default_factory)
    else:
      d = instance_type()
    # Insert in the instance's own iteration order so the result mirrors it.
    for key in instance:
      d[key] = result[key]
    return d
  elif _is_mapping(instance):
    result = dict(zip(_sorted(instance), args))
    instance_type = type(instance)
    if not getattr(instance_type, "__supported_by_tf_nest__", False):
      tf_logging.log_first_n(
          tf_logging.WARN, "Mapping types may not work well with tf.nest. "
          "Prefer using MutableMapping for {}".format(instance_type), 1)
    try:
      return instance_type((key, result[key]) for key in instance)
    except TypeError as err:
      raise TypeError("Error creating an object of type {} like {}. Note that "
                      "it must accept a single positional argument "
                      "representing an iterable of key-value pairs, in "
                      "addition to self. Cause: {}".format(
                          type(instance), instance, err)) from err
  elif _is_mapping_view(instance):
    # We can't directly construct mapping views, so we create a list instead
    return list(args)
  elif is_namedtuple(instance) or _is_attrs(instance):
    # Rebuild from the wrapped object's type when `instance` is a proxy.
    if isinstance(instance, _wrapt.ObjectProxy):
      instance_type = type(instance.__wrapped__)
    else:
      instance_type = type(instance)
    return instance_type(*args)
  elif _is_composite_tensor(instance):
    assert len(args) == 1
    spec = instance._type_spec  # pylint: disable=protected-access
    return spec._from_components(args[0])  # pylint: disable=protected-access
  elif _is_type_spec(instance):
    # Pack a CompositeTensor's components according to a TypeSpec.
    assert len(args) == 1
    return instance._from_components(args[0])  # pylint: disable=protected-access
  elif isinstance(instance, _six.moves.range):
    return _sequence_like(list(instance), args)
  elif isinstance(instance, _wrapt.ObjectProxy):
    # For object proxies, first create the underlying type and then re-wrap it
    # in the proxy type.
    return type(instance)(_sequence_like(instance.__wrapped__, args))
  else:
    # Not a namedtuple
    return type(instance)(args)


def _yield_value(iterable):
  """Yields only the values produced by `_yield_sorted_items(iterable)`."""
  for _, v in _yield_sorted_items(iterable):
    yield v


def _yield_sorted_items(iterable):
  """Yield (key, value) pairs for `iterable` in a deterministic order.

  For Sequences, the key will be an int, the array index of a value, and the
  pairs are yielded in index order.
  For Mappings, the key will be the dictionary key, and the pairs are yielded
  in sorted key order.
  For attrs instances and namedtuples, the key will be the attribute/field
  name, and the pairs are yielded in declaration order (`__attrs_attrs__` /
  `_fields`), NOT sorted by name.
  For CompositeTensors and TypeSpecs, a single pair is yielded whose key is the
  name of the spec's `value_type`.

  Args:
    iterable: an iterable.

  Yields:
    The iterable's (key, value) pairs, in the deterministic order described
    above.
  """
  # Ordered to check common structure types (list, tuple, dict) first.
  if isinstance(iterable, list):
    for item in enumerate(iterable):
      yield item
  # namedtuples handled separately to avoid expensive namedtuple check.
  elif type(iterable) == tuple:  # pylint: disable=unidiomatic-typecheck
    for item in enumerate(iterable):
      yield item
  elif isinstance(iterable, (dict, _collections_abc.Mapping)):
    # Iterate through dictionaries in a deterministic order by sorting the
    # keys. Notice this means that we ignore the original order of `OrderedDict`
    # instances. This is intentional, to avoid potential bugs caused by mixing
    # ordered and plain dicts (e.g., flattening a dict but using a
    # corresponding `OrderedDict` to pack it back).
    for key in _sorted(iterable):
      yield key, iterable[key]
  elif _is_attrs(iterable):
    for item in _get_attrs_items(iterable):
      yield item
  elif is_namedtuple(iterable):
    for field in iterable._fields:
      yield field, getattr(iterable, field)
  elif _is_composite_tensor(iterable):
    type_spec = iterable._type_spec  # pylint: disable=protected-access
    yield type_spec.value_type.__name__, type_spec._to_components(iterable)  # pylint: disable=protected-access
  elif _is_type_spec(iterable):
    # Note: to allow CompositeTensors and their TypeSpecs to have matching
    # structures, we need to use the same key string here.
    yield iterable.value_type.__name__, iterable._component_specs  # pylint: disable=protected-access
  else:
    for item in enumerate(iterable):
      yield item


_is_nested = _pywrap_utils.IsNested

_is_nested_or_composite = _pywrap_utils.IsNestedOrComposite


@tf_export("nest.is_nested")
def is_nested(seq):
  """Returns true if its input is a nested structure.

  Refer to [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest)
  for the definition of a nested structure.

  Args:
    seq: the value to test.

  Returns:
    True if the input is a nested structure.
  """
  return _is_nested(seq)


def is_nested_or_composite(seq):
  """Returns true if its input is a nested structure or a composite.

  Refer to [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest)
  for the definition of a nested structure.

  Args:
    seq: the value to test.

  Returns:
    True if the input is a nested structure or a composite.
  """
  return _is_nested_or_composite(seq)


# FIXME(feyu): Remove the back-compat names before closing b/201685523, after
# all users of is_sequence are moved to the new names.
(cl/405503918) 348def is_sequence(seq): 349 return _is_nested(seq) 350 351 352def is_sequence_or_composite(seq): 353 return _is_nested_or_composite(seq) 354 355 356@tf_export("nest.flatten") 357def flatten(structure, expand_composites=False): 358 """Returns a flat list from a given structure. 359 360 Refer to [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest) 361 for the definition of a structure. 362 363 If the structure is an atom, then returns a single-item list: [structure]. 364 365 This is the inverse of the `nest.pack_sequence_as` method that takes in a 366 flattened list and re-packs it into the nested structure. 367 368 In the case of dict instances, the sequence consists of the values, sorted by 369 key to ensure deterministic behavior. This is true also for OrderedDict 370 instances: their sequence order is ignored, the sorting order of keys is used 371 instead. The same convention is followed in `nest.pack_sequence_as`. This 372 correctly repacks dicts and OrderedDicts after they have been flattened, and 373 also allows flattening an OrderedDict and then repacking it back using a 374 corresponding plain dict, or vice-versa. Dictionaries with non-sortable keys 375 cannot be flattened. 376 377 Users must not modify any collections used in nest while this function is 378 running. 379 380 Examples: 381 382 1. Python dict (ordered by key): 383 384 >>> dict = { "key3": "value3", "key1": "value1", "key2": "value2" } 385 >>> tf.nest.flatten(dict) 386 ['value1', 'value2', 'value3'] 387 388 2. For a nested python tuple: 389 390 >>> tuple = ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0) 391 >>> tf.nest.flatten(tuple) 392 [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] 393 394 3. For a nested dictionary of dictionaries: 395 396 >>> dict = { "key3": {"c": (1.0, 2.0), "a": (3.0)}, 397 ... "key1": {"m": "val1", "g": "val2"} } 398 >>> tf.nest.flatten(dict) 399 ['val2', 'val1', 3.0, 1.0, 2.0] 400 401 4. 
Numpy array (will not flatten): 402 403 >>> array = np.array([[1, 2], [3, 4]]) 404 >>> tf.nest.flatten(array) 405 [array([[1, 2], 406 [3, 4]])] 407 408 5. `tf.Tensor` (will not flatten): 409 410 >>> tensor = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]) 411 >>> tf.nest.flatten(tensor) 412 [<tf.Tensor: shape=(3, 3), dtype=float32, numpy= 413 array([[1., 2., 3.], 414 [4., 5., 6.], 415 [7., 8., 9.]], dtype=float32)>] 416 417 6. `tf.RaggedTensor`: This is a composite tensor thats representation consists 418 of a flattened list of 'values' and a list of 'row_splits' which indicate how 419 to chop up the flattened list into different rows. For more details on 420 `tf.RaggedTensor`, please visit 421 https://www.tensorflow.org/api_docs/python/tf/RaggedTensor. 422 423 with `expand_composites=False`, we just return the RaggedTensor as is. 424 425 >>> tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]]) 426 >>> tf.nest.flatten(tensor, expand_composites=False) 427 [<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2]]>] 428 429 with `expand_composites=True`, we return the component Tensors that make up 430 the RaggedTensor representation (the values and row_splits tensors) 431 432 >>> tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]]) 433 >>> tf.nest.flatten(tensor, expand_composites=True) 434 [<tf.Tensor: shape=(7,), dtype=int32, numpy=array([3, 1, 4, 1, 5, 9, 2], 435 dtype=int32)>, 436 <tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 4, 4, 7])>] 437 438 Args: 439 structure: an atom or a nested structure. Note, numpy arrays are considered 440 atoms and are not flattened. 441 expand_composites: If true, then composite tensors such as 442 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 443 component tensors. 444 445 Returns: 446 A Python list, the flattened version of the input. 447 448 Raises: 449 TypeError: The nest is or contains a dict with non-sortable keys. 
450 """ 451 if structure is None: 452 return [None] 453 expand_composites = bool(expand_composites) 454 return _pywrap_utils.Flatten(structure, expand_composites) 455 456 457# See the swig file (util.i) for documentation. 458same_namedtuples = _pywrap_utils.SameNamedtuples 459_same_namedtuples = same_namedtuples # This function was private up to TF2.5. 460 461 462class _DotString(object): 463 464 __slots__ = [] 465 466 def __str__(self): 467 return "." 468 469 def __repr__(self): 470 return "." 471 472 473_DOT = _DotString() 474 475 476@tf_export("nest.assert_same_structure") 477def assert_same_structure(nest1, nest2, check_types=True, 478 expand_composites=False): 479 """Asserts that two structures are nested in the same way. 480 481 Refer to [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest) 482 for the definition of a structure. 483 484 Note the method does not check the types of atoms inside the structures. 485 486 Examples: 487 488 * These atom vs. atom comparisons will pass: 489 490 >>> tf.nest.assert_same_structure(1.5, tf.Variable(1, tf.uint32)) 491 >>> tf.nest.assert_same_structure("abc", np.array([1, 2])) 492 493 * These nested structure vs. nested structure comparisons will pass: 494 495 >>> structure1 = (((1, 2), 3), 4, (5, 6)) 496 >>> structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) 497 >>> structure3 = [(("a", "b"), "c"), "d", ["e", "f"]] 498 >>> tf.nest.assert_same_structure(structure1, structure2) 499 >>> tf.nest.assert_same_structure(structure1, structure3, check_types=False) 500 501 >>> import collections 502 >>> tf.nest.assert_same_structure( 503 ... collections.namedtuple("bar", "a b")(1, 2), 504 ... collections.namedtuple("foo", "a b")(2, 3), 505 ... check_types=False) 506 507 >>> tf.nest.assert_same_structure( 508 ... collections.namedtuple("bar", "a b")(1, 2), 509 ... { "a": 1, "b": 2 }, 510 ... check_types=False) 511 512 >>> tf.nest.assert_same_structure( 513 ... { "a": 1, "b": 2, "c": 3 }, 514 ... 
{ "c": 6, "b": 5, "a": 4 }) 515 516 >>> ragged_tensor1 = tf.RaggedTensor.from_row_splits( 517 ... values=[3, 1, 4, 1, 5, 9, 2, 6], 518 ... row_splits=[0, 4, 4, 7, 8, 8]) 519 >>> ragged_tensor2 = tf.RaggedTensor.from_row_splits( 520 ... values=[3, 1, 4], 521 ... row_splits=[0, 3]) 522 >>> tf.nest.assert_same_structure( 523 ... ragged_tensor1, 524 ... ragged_tensor2, 525 ... expand_composites=True) 526 527 * These examples will raise exceptions: 528 529 >>> tf.nest.assert_same_structure([0, 1], np.array([0, 1])) 530 Traceback (most recent call last): 531 ... 532 ValueError: The two structures don't have the same nested structure 533 534 >>> tf.nest.assert_same_structure( 535 ... collections.namedtuple('bar', 'a b')(1, 2), 536 ... collections.namedtuple('foo', 'a b')(2, 3)) 537 Traceback (most recent call last): 538 ... 539 TypeError: The two structures don't have the same nested structure 540 541 Args: 542 nest1: an atom or a nested structure. 543 nest2: an atom or a nested structure. 544 check_types: if `True` (default) types of structures are checked as well, 545 including the keys of dictionaries. If set to `False`, for example a list 546 and a tuple of objects will look the same if they have the same size. Note 547 that namedtuples with identical name and fields are always considered to 548 have the same shallow structure. Two types will also be considered the 549 same if they are both list subtypes (which allows "list" and 550 "_ListWrapper" from trackable dependency tracking to compare equal). 551 `check_types=True` only checks type of sub-structures. The types of atoms 552 are not checked. 553 expand_composites: If true, then composite tensors such as 554 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 555 component tensors. 556 557 Raises: 558 ValueError: If the two structures do not have the same number of atoms or 559 if the two structures are not nested in the same way. 
560 TypeError: If the two structures differ in the type of sequence in any of 561 their substructures. Only possible if `check_types` is `True`. 562 """ 563 # Convert to bool explicitly as otherwise pybind will not be able# to handle 564 # type mismatch message correctly. See GitHub issue 42329 for details. 565 check_types = bool(check_types) 566 expand_composites = bool(expand_composites) 567 try: 568 _pywrap_utils.AssertSameStructure(nest1, nest2, check_types, 569 expand_composites) 570 except (ValueError, TypeError) as e: 571 str1 = str(map_structure(lambda _: _DOT, nest1)) 572 str2 = str(map_structure(lambda _: _DOT, nest2)) 573 raise type(e)("%s\n" 574 "Entire first structure:\n%s\n" 575 "Entire second structure:\n%s" 576 % (str(e), str1, str2)) 577 578 579def flatten_dict_items(dictionary): 580 """Returns a dictionary with flattened keys and values. 581 582 This function flattens the keys and values of a dictionary, which can be 583 arbitrarily nested structures, and returns the flattened version of such 584 structures: 585 586 ```python 587 example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))} 588 result = {4: "a", 5: "b", 6: "c", 8: "d"} 589 flatten_dict_items(example_dictionary) == result 590 ``` 591 592 The input dictionary must satisfy two properties: 593 594 1. Its keys and values should have the same exact nested structure. 595 2. The set of all flattened keys of the dictionary must not contain repeated 596 keys. 597 598 Args: 599 dictionary: the dictionary to zip 600 601 Returns: 602 The zipped dictionary. 603 604 Raises: 605 TypeError: If the input is not a dictionary. 606 ValueError: If any key and value do not have the same structure layout, or 607 if keys are not unique. 608 """ 609 return _pywrap_nest.FlattenDictItems(dictionary) 610 611 612def _packed_nest_with_indices(structure, 613 flat, 614 index, 615 is_nested_fn, 616 sequence_fn=None): 617 """Helper function for pack_sequence_as. 618 619 Args: 620 structure: structure to mimic. 
621 flat: Flattened values to output substructure for. 622 index: Index at which to start reading from flat. 623 is_nested_fn: Function used to test if a value should be treated as a 624 nested structure. 625 sequence_fn: Function used to generate a new strcuture instance. 626 627 Returns: 628 The tuple (new_index, child), where: 629 * new_index - the updated index into `flat` having processed `structure`. 630 * packed - the subset of `flat` corresponding to `structure`, 631 having started at `index`, and packed into the same nested 632 format. 633 634 Raises: 635 ValueError: if `structure` contains more atoms than `flat` 636 (assuming indexing starts from `index`). 637 """ 638 packed = [] 639 sequence_fn = sequence_fn or _sequence_like 640 for s in _yield_value(structure): 641 if is_nested_fn(s): 642 new_index, child = _packed_nest_with_indices(s, flat, index, is_nested_fn, 643 sequence_fn) 644 packed.append(sequence_fn(s, child)) 645 index = new_index 646 else: 647 packed.append(flat[index]) 648 index += 1 649 return index, packed 650 651 652def _pack_sequence_as(structure, flat_sequence, expand_composites, 653 sequence_fn=None): 654 """Implements sequence packing, with the option to alter the structure.""" 655 is_nested_fn = _is_nested_or_composite if expand_composites else _is_nested 656 sequence_fn = sequence_fn or _sequence_like 657 def truncate(value, length): 658 value_str = str(value) 659 return value_str[:length] + (value_str[length:] and "...") 660 661 if not is_nested_fn(flat_sequence): 662 raise TypeError( 663 "Attempted to pack value:\n {}\ninto a structure, but found " 664 "incompatible type `{}` instead.".format( 665 truncate(flat_sequence, 100), type(flat_sequence))) 666 667 if not is_nested_fn(structure): 668 if len(flat_sequence) != 1: 669 raise ValueError( 670 "The target structure is of type `{}`\n {}\nHowever the input " 671 "is a sequence ({}) of length {}.\n {}\nnest cannot " 672 "guarantee that it is safe to map one to the other.".format( 
673 type(structure), truncate(structure, 100), type(flat_sequence), 674 len(flat_sequence), truncate(flat_sequence, 100))) 675 return flat_sequence[0] 676 677 try: 678 final_index, packed = _packed_nest_with_indices(structure, flat_sequence, 0, 679 is_nested_fn, sequence_fn) 680 if final_index < len(flat_sequence): 681 raise IndexError 682 except IndexError: 683 flat_structure = flatten(structure, expand_composites=expand_composites) 684 if len(flat_structure) != len(flat_sequence): 685 raise ValueError( 686 "Could not pack sequence. Structure had %d atoms, but " 687 "flat_sequence had %d items. Structure: %s, flat_sequence: %s." % 688 (len(flat_structure), len(flat_sequence), structure, flat_sequence)) 689 return sequence_fn(structure, packed) 690 691 692@tf_export("nest.pack_sequence_as") 693def pack_sequence_as(structure, flat_sequence, expand_composites=False): 694 """Returns a given flattened sequence packed into a given structure. 695 696 Refer to [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest) 697 for the definition of a structure. 698 699 If `structure` is an atom, `flat_sequence` must be a single-item list; 700 in this case the return value is `flat_sequence[0]`. 701 702 If `structure` is or contains a dict instance, the keys will be sorted to 703 pack the flat sequence in deterministic order. This is true also for 704 `OrderedDict` instances: their sequence order is ignored, the sorting order of 705 keys is used instead. The same convention is followed in `flatten`. 706 This correctly repacks dicts and `OrderedDict`s after they have been 707 flattened, and also allows flattening an `OrderedDict` and then repacking it 708 back using a corresponding plain dict, or vice-versa. 709 Dictionaries with non-sortable keys cannot be flattened. 710 711 Examples: 712 713 1. 
Python dict: 714 715 >>> structure = { "key3": "", "key1": "", "key2": "" } 716 >>> flat_sequence = ["value1", "value2", "value3"] 717 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 718 {'key3': 'value3', 'key1': 'value1', 'key2': 'value2'} 719 720 2. For a nested python tuple: 721 722 >>> structure = (('a','b'), ('c','d','e'), 'f') 723 >>> flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] 724 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 725 ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0) 726 727 3. For a nested dictionary of dictionaries: 728 729 >>> structure = { "key3": {"c": ('alpha', 'beta'), "a": ('gamma')}, 730 ... "key1": {"e": "val1", "d": "val2"} } 731 >>> flat_sequence = ['val2', 'val1', 3.0, 1.0, 2.0] 732 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 733 {'key3': {'c': (1.0, 2.0), 'a': 3.0}, 'key1': {'e': 'val1', 'd': 'val2'}} 734 735 4. Numpy array (considered a scalar): 736 737 >>> structure = ['a'] 738 >>> flat_sequence = [np.array([[1, 2], [3, 4]])] 739 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 740 [array([[1, 2], 741 [3, 4]])] 742 743 5. tf.Tensor (considered a scalar): 744 745 >>> structure = ['a'] 746 >>> flat_sequence = [tf.constant([[1., 2., 3.], [4., 5., 6.]])] 747 >>> tf.nest.pack_sequence_as(structure, flat_sequence) 748 [<tf.Tensor: shape=(2, 3), dtype=float32, 749 numpy= array([[1., 2., 3.], [4., 5., 6.]], dtype=float32)>] 750 751 6. `tf.RaggedTensor`: This is a composite tensor thats representation consists 752 of a flattened list of 'values' and a list of 'row_splits' which indicate how 753 to chop up the flattened list into different rows. For more details on 754 `tf.RaggedTensor`, please visit 755 https://www.tensorflow.org/api_docs/python/tf/RaggedTensor. 756 757 With `expand_composites=False`, we treat RaggedTensor as a scalar. 758 759 >>> structure = { "foo": tf.ragged.constant([[1, 2], [3]]), 760 ... 
"bar": tf.constant([[5]]) } 761 >>> flat_sequence = [ "one", "two" ] 762 >>> tf.nest.pack_sequence_as(structure, flat_sequence, 763 ... expand_composites=False) 764 {'foo': 'two', 'bar': 'one'} 765 766 With `expand_composites=True`, we expect that the flattened input contains 767 the tensors making up the ragged tensor i.e. the values and row_splits 768 tensors. 769 770 >>> structure = { "foo": tf.ragged.constant([[1., 2.], [3.]]), 771 ... "bar": tf.constant([[5.]]) } 772 >>> tensors = tf.nest.flatten(structure, expand_composites=True) 773 >>> print(tensors) 774 [<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]], 775 dtype=float32)>, 776 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 2., 3.], 777 dtype=float32)>, 778 <tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 2, 3])>] 779 >>> verified_tensors = [tf.debugging.check_numerics(t, 'invalid tensor: ') 780 ... if t.dtype==tf.float32 else t 781 ... for t in tensors] 782 >>> tf.nest.pack_sequence_as(structure, verified_tensors, 783 ... expand_composites=True) 784 {'foo': <tf.RaggedTensor [[1.0, 2.0], [3.0]]>, 785 'bar': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]], 786 dtype=float32)>} 787 788 Args: 789 structure: Nested structure, whose structure is given by nested lists, 790 tuples, and dicts. Note: numpy arrays and strings are considered 791 scalars. 792 flat_sequence: flat sequence to pack. 793 expand_composites: If true, then composite tensors such as 794 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 795 component tensors. 796 797 Returns: 798 packed: `flat_sequence` converted to have the same recursive structure as 799 `structure`. 800 801 Raises: 802 ValueError: If `flat_sequence` and `structure` have different 803 atom counts. 804 TypeError: `structure` is or contains a dict with non-sortable keys. 
805 """ 806 return _pack_sequence_as(structure, flat_sequence, expand_composites) 807 808 809@tf_export("nest.map_structure") 810def map_structure(func, *structure, **kwargs): 811 """Creates a new structure by applying `func` to each atom in `structure`. 812 813 Refer to [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest) 814 for the definition of a structure. 815 816 Applies `func(x[0], x[1], ...)` where x[i] enumerates all atoms in 817 `structure[i]`. All items in `structure` must have the same arity, 818 and the return value will contain results with the same structure layout. 819 820 Examples: 821 822 * A single Python dict: 823 824 >>> a = {"hello": 24, "world": 76} 825 >>> tf.nest.map_structure(lambda p: p * 2, a) 826 {'hello': 48, 'world': 152} 827 828 * Multiple Python dictionaries: 829 830 >>> d1 = {"hello": 24, "world": 76} 831 >>> d2 = {"hello": 36, "world": 14} 832 >>> tf.nest.map_structure(lambda p1, p2: p1 + p2, d1, d2) 833 {'hello': 60, 'world': 90} 834 835 * A single Python list: 836 837 >>> a = [24, 76, "ab"] 838 >>> tf.nest.map_structure(lambda p: p * 2, a) 839 [48, 152, 'abab'] 840 841 * Scalars: 842 843 >>> tf.nest.map_structure(lambda x, y: x + y, 3, 4) 844 7 845 846 * Empty structures: 847 848 >>> tf.nest.map_structure(lambda x: x + 1, ()) 849 () 850 851 * Check the types of iterables: 852 853 >>> s1 = (((1, 2), 3), 4, (5, 6)) 854 >>> s1_list = [[[1, 2], 3], 4, [5, 6]] 855 >>> tf.nest.map_structure(lambda x, y: None, s1, s1_list) 856 Traceback (most recent call last): 857 ... 858 TypeError: The two structures don't have the same nested structure 859 860 * Type check is set to False: 861 862 >>> s1 = (((1, 2), 3), 4, (5, 6)) 863 >>> s1_list = [[[1, 2], 3], 4, [5, 6]] 864 >>> tf.nest.map_structure(lambda x, y: None, s1, s1_list, check_types=False) 865 (((None, None), None), None, (None, None)) 866 867 Args: 868 func: A callable that accepts as many arguments as there are structures. 869 *structure: atom or nested structure. 
870 **kwargs: Valid keyword args are: 871 * `check_types`: If set to `True` (default) the types of iterables within 872 the structures have to be same (e.g. `map_structure(func, [1], (1,))` 873 raises a `TypeError` exception). To allow this set this argument to 874 `False`. Note that namedtuples with identical name and fields are always 875 considered to have the same shallow structure. 876 * `expand_composites`: If set to `True`, then composite tensors such as 877 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 878 component tensors. If `False` (the default), then composite tensors are 879 not expanded. 880 881 Returns: 882 A new structure with the same arity as `structure[0]`, whose atoms 883 correspond to `func(x[0], x[1], ...)` where `x[i]` is the atom in the 884 corresponding location in `structure[i]`. If there are different structure 885 types and `check_types` is `False` the structure types of the first 886 structure will be used. 887 888 Raises: 889 TypeError: If `func` is not callable or if the structures do not match 890 each other by depth tree. 891 ValueError: If no structure is provided or if the structures do not match 892 each other by type. 893 ValueError: If wrong keyword arguments are provided. 
894 """ 895 if not callable(func): 896 raise TypeError("func must be callable, got: %s" % func) 897 898 if not structure: 899 raise ValueError("Must provide at least one structure") 900 901 check_types = kwargs.pop("check_types", True) 902 expand_composites = kwargs.pop("expand_composites", False) 903 904 if kwargs: 905 raise ValueError( 906 "Only valid keyword arguments are `check_types` and " 907 "`expand_composites`, not: `%s`" % ("`, `".join(kwargs.keys()))) 908 909 for other in structure[1:]: 910 assert_same_structure(structure[0], other, check_types=check_types, 911 expand_composites=expand_composites) 912 913 flat_structure = (flatten(s, expand_composites) for s in structure) 914 entries = zip(*flat_structure) 915 916 return pack_sequence_as( 917 structure[0], [func(*x) for x in entries], 918 expand_composites=expand_composites) 919 920 921def map_structure_with_paths(func, *structure, **kwargs): 922 """Applies `func` to each entry in `structure` and returns a new structure. 923 924 Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in 925 `structure[i]` and `path` is the common path to x[i] in the structures. All 926 structures in `structure` must have the same arity, and the return value will 927 contain the results with the same structure layout. Special kwarg 928 `check_types` determines whether the types of iterables within the structure 929 must be the same-- see **kwargs definition below. 930 931 Args: 932 func: A callable with the signature func(path, *values, **kwargs) that is 933 evaluated on the leaves of the structure. 934 *structure: A variable number of compatible structures to process. 935 **kwargs: Optional kwargs to be passed through to func. Special kwarg 936 `check_types` is not passed to func, but instead determines whether the 937 types of iterables within the structures have to be same (e.g., 938 `map_structure(func, [1], (1,))` raises a `TypeError` exception). By 939 default, the types must match. 
To allow iteration over structures of 940 different types (but common arity), set this kwarg to `False`. 941 942 Returns: 943 A structure of the same form as the input structures whose leaves are the 944 result of evaluating func on corresponding leaves of the input structures. 945 946 Raises: 947 TypeError: If `func` is not callable or if the structures do not match 948 each other by depth tree. 949 TypeError: If `check_types` is not `False` and the two structures differ in 950 the type of sequence in any of their substructures. 951 ValueError: If no structures are provided. 952 """ 953 def wrapper_func(tuple_path, *inputs, **kwargs): 954 string_path = "/".join(str(s) for s in tuple_path) 955 return func(string_path, *inputs, **kwargs) 956 957 return map_structure_with_tuple_paths_up_to(structure[0], 958 wrapper_func, 959 *structure, 960 **kwargs) 961 962 963def map_structure_with_tuple_paths(func, *structure, **kwargs): 964 """Applies `func` to each entry in `structure` and returns a new structure. 965 966 Applies `func(tuple_path, x[0], x[1], ..., **kwargs)` where `x[i]` is an entry 967 in `structure[i]` and `tuple_path` is a tuple of indices and/or dictionary 968 keys (as returned by `nest.yield_flat_paths`), which uniquely specifies the 969 common path to x[i] in the structures. All structures in `structure` must have 970 the same arity, and the return value will contain the results in the same 971 structure. Special kwarg `check_types` determines whether the types of 972 iterables within the structure must be the same-- see **kwargs definition 973 below. 974 975 Args: 976 func: A callable with the signature `func(tuple_path, *values, **kwargs)` 977 that is evaluated on the leaves of the structure. 978 *structure: A variable number of compatible structures to process. 979 **kwargs: Optional kwargs to be passed through to func. 
Special kwarg 980 `check_types` is not passed to func, but instead determines whether the 981 types of iterables within the structures have to be same (e.g. 982 `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow 983 this set this argument to `False`. 984 985 Returns: 986 A structure of the same form as the input structures whose leaves are the 987 result of evaluating func on corresponding leaves of the input structures. 988 989 Raises: 990 TypeError: If `func` is not callable or if the structures do not match 991 each other by depth tree. 992 TypeError: If `check_types` is not `False` and the two structures differ in 993 the type of sequence in any of their substructures. 994 ValueError: If no structures are provided. 995 """ 996 return map_structure_with_tuple_paths_up_to(structure[0], 997 func, 998 *structure, 999 **kwargs) 1000 1001 1002def _yield_flat_up_to(shallow_tree, input_tree, is_nested_fn, path=()): 1003 """Yields (path, value) pairs of input_tree flattened up to shallow_tree. 1004 1005 Args: 1006 shallow_tree: Nested structure. Traverse no further than its leaf nodes. 1007 input_tree: Nested structure. Return the paths and values from this tree. 1008 Must have the same upper structure as shallow_tree. 1009 is_nested_fn: Function used to test if a value should be treated as a 1010 nested structure. 1011 path: Tuple. Optional argument, only used when recursing. The path from the 1012 root of the original shallow_tree, down to the root of the shallow_tree 1013 arg of this recursive call. 1014 1015 Yields: 1016 Pairs of (path, value), where path the tuple path of a leaf node in 1017 shallow_tree, and value is the value of the corresponding node in 1018 input_tree. 
1019 """ 1020 if not is_nested_fn(shallow_tree): 1021 yield (path, input_tree) 1022 else: 1023 input_tree = dict(_yield_sorted_items(input_tree)) 1024 for shallow_key, shallow_subtree in _yield_sorted_items(shallow_tree): 1025 subpath = path + (shallow_key,) 1026 input_subtree = input_tree[shallow_key] 1027 for leaf_path, leaf_value in _yield_flat_up_to( 1028 shallow_subtree, input_subtree, is_nested_fn, path=subpath): 1029 yield (leaf_path, leaf_value) 1030 1031 1032def assert_shallow_structure(shallow_tree, 1033 input_tree, 1034 check_types=True, 1035 expand_composites=False): 1036 """Asserts that `shallow_tree` is a shallow structure of `input_tree`. 1037 1038 That is, this function tests if the `input_tree` structure can be created from 1039 the `shallow_tree` structure by replacing its leaf nodes with deeper 1040 tree structures. 1041 1042 Examples: 1043 1044 The following code will raise an exception: 1045 ```python 1046 shallow_tree = {"a": "A", "b": "B"} 1047 input_tree = {"a": 1, "c": 2} 1048 assert_shallow_structure(shallow_tree, input_tree) 1049 ``` 1050 1051 The following code will raise an exception: 1052 ```python 1053 shallow_tree = ["a", "b"] 1054 input_tree = ["c", ["d", "e"], "f"] 1055 assert_shallow_structure(shallow_tree, input_tree) 1056 ``` 1057 1058 Args: 1059 shallow_tree: an arbitrarily nested structure. 1060 input_tree: an arbitrarily nested structure. 1061 check_types: if `True` (default) the sequence types of `shallow_tree` and 1062 `input_tree` have to be the same. Note that even with check_types==True, 1063 this function will consider two different namedtuple classes with the same 1064 name and _fields attribute to be the same class. 1065 expand_composites: If true, then composite tensors such as 1066 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1067 component tensors. 1068 Raises: 1069 TypeError: If `shallow_tree` is a sequence but `input_tree` is not. 
1070 TypeError: If the sequence types of `shallow_tree` are different from 1071 `input_tree`. Only raised if `check_types` is `True`. 1072 ValueError: If the sequence lengths of `shallow_tree` are different from 1073 `input_tree`. 1074 """ 1075 is_nested_fn = _is_nested_or_composite if expand_composites else _is_nested 1076 if is_nested_fn(shallow_tree): 1077 if not is_nested_fn(input_tree): 1078 raise TypeError( 1079 "If shallow structure is a sequence, input must also be a sequence. " 1080 "Input has type: %s." % type(input_tree)) 1081 1082 if isinstance(shallow_tree, _wrapt.ObjectProxy): 1083 shallow_type = type(shallow_tree.__wrapped__) 1084 else: 1085 shallow_type = type(shallow_tree) 1086 1087 if check_types and not isinstance(input_tree, shallow_type): 1088 # Duck-typing means that nest should be fine with two different 1089 # namedtuples with identical name and fields. 1090 shallow_is_namedtuple = is_namedtuple(shallow_tree, False) 1091 input_is_namedtuple = is_namedtuple(input_tree, False) 1092 if shallow_is_namedtuple and input_is_namedtuple: 1093 if not same_namedtuples(shallow_tree, input_tree): 1094 raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format( 1095 input_type=type(input_tree), 1096 shallow_type=type(shallow_tree))) 1097 1098 elif isinstance(shallow_tree, list) and isinstance(input_tree, list): 1099 # List subclasses are considered the same, 1100 # e.g. python list vs. _ListWrapper. 1101 pass 1102 1103 elif ((_is_composite_tensor(shallow_tree) or 1104 _is_composite_tensor(input_tree)) and 1105 (_is_type_spec(shallow_tree) or _is_type_spec(input_tree))): 1106 pass # Compatibility will be checked below. 
1107 1108 elif not (isinstance(shallow_tree, _collections_abc.Mapping) and 1109 isinstance(input_tree, _collections_abc.Mapping)): 1110 raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format( 1111 input_type=type(input_tree), 1112 shallow_type=type(shallow_tree))) 1113 1114 if _is_composite_tensor(shallow_tree) or _is_composite_tensor(input_tree): 1115 if not ( 1116 (_is_composite_tensor(input_tree) or _is_type_spec(input_tree)) and 1117 (_is_composite_tensor(shallow_tree) or _is_type_spec(shallow_tree))): 1118 raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format( 1119 input_type=type(input_tree), 1120 shallow_type=type(shallow_tree))) 1121 # pylint: disable=protected-access 1122 type_spec_1 = (shallow_tree if _is_type_spec(shallow_tree) else 1123 shallow_tree._type_spec)._without_tensor_names() 1124 type_spec_2 = (input_tree if _is_type_spec(input_tree) else 1125 input_tree._type_spec)._without_tensor_names() 1126 # pylint: enable=protected-access 1127 result = type_spec_1.most_specific_common_supertype([type_spec_2]) 1128 if result is None: 1129 raise ValueError("Incompatible CompositeTensor TypeSpecs: %s vs. %s" % 1130 (type_spec_1, type_spec_2)) 1131 1132 elif _is_type_spec(shallow_tree): 1133 if not _is_type_spec(input_tree): 1134 raise TypeError("If shallow structure is a TypeSpec, input must also " 1135 "be a TypeSpec. Input has type: %s." 
1136 % type(input_tree)) 1137 else: 1138 if len(input_tree) != len(shallow_tree): 1139 raise ValueError( 1140 _STRUCTURES_HAVE_MISMATCHING_LENGTHS.format( 1141 input_length=len(input_tree), shallow_length=len(shallow_tree))) 1142 elif len(input_tree) < len(shallow_tree): 1143 raise ValueError( 1144 _INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format( 1145 input_size=len(input_tree), shallow_size=len(shallow_tree))) 1146 1147 if isinstance(shallow_tree, _collections_abc.Mapping): 1148 absent_keys = set(shallow_tree) - set(input_tree) 1149 if absent_keys: 1150 raise ValueError(_SHALLOW_TREE_HAS_INVALID_KEYS 1151 .format(sorted(absent_keys))) 1152 1153 for shallow_branch, input_branch in zip(_yield_value(shallow_tree), 1154 _yield_value(input_tree)): 1155 assert_shallow_structure(shallow_branch, input_branch, 1156 check_types=check_types, 1157 expand_composites=expand_composites) 1158 1159 1160@tf_export("__internal__.nest.flatten_up_to", v1=[]) 1161def flatten_up_to(shallow_tree, input_tree, check_types=True, 1162 expand_composites=False): 1163 """Flattens `input_tree` up to `shallow_tree`. 1164 1165 Refer to [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest) 1166 for the definition of a structure. 1167 1168 Any further depth in structure in `input_tree` is retained as structures in 1169 the partially flatten output. 1170 1171 If `shallow_tree` and `input_tree` are atoms, this returns a 1172 single-item list: `[input_tree]`. 1173 1174 Use Case: 1175 1176 Sometimes we may wish to partially flatten a structure, retaining some 1177 of the nested structure. We achieve this by specifying a shallow structure, 1178 `shallow_tree`, we wish to flatten up to. 1179 1180 The input, `input_tree`, can be thought of as having the same structure layout 1181 as `shallow_tree`, but with leaf nodes that are themselves tree structures. 
1182 1183 Examples: 1184 1185 ```python 1186 input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] 1187 shallow_tree = [[True, True], [False, True]] 1188 1189 flattened_input_tree = flatten_up_to(shallow_tree, input_tree) 1190 flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree) 1191 1192 # Output is: 1193 # [[2, 2], [3, 3], [4, 9], [5, 5]] 1194 # [True, True, False, True] 1195 ``` 1196 1197 ```python 1198 input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]] 1199 shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]] 1200 1201 input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree) 1202 input_tree_flattened = flatten(input_tree) 1203 1204 # Output is: 1205 # [('a', 1), ('b', 2), ('c', 3), ('d', 4)] 1206 # ['a', 1, 'b', 2, 'c', 3, 'd', 4] 1207 ``` 1208 1209 Edge Cases for atoms: 1210 1211 ```python 1212 flatten_up_to(0, 0) # Output: [0] 1213 flatten_up_to(0, [0, 1, 2]) # Output: [[0, 1, 2]] 1214 flatten_up_to([0, 1, 2], 0) # Output: TypeError 1215 flatten_up_to([0, 1, 2], [0, 1, 2]) # Output: [0, 1, 2] 1216 ``` 1217 1218 Args: 1219 shallow_tree: a possibly pruned structure of input_tree. 1220 input_tree: an atom or a nested structure. 1221 Note, numpy arrays are considered atoms. 1222 check_types: bool. If True, check that each node in shallow_tree has the 1223 same type as the corresponding node in input_tree. 1224 expand_composites: If true, then composite tensors such as 1225 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1226 component tensors. 1227 1228 Returns: 1229 A Python list, the partially flattened version of `input_tree` according to 1230 the structure of `shallow_tree`. 1231 1232 Raises: 1233 TypeError: If `shallow_tree` is a nested structure but `input_tree` is not. 1234 TypeError: If the structure types of `shallow_tree` are different from 1235 `input_tree`. 1236 ValueError: If the structure lengths of `shallow_tree` are different from 1237 `input_tree`. 
1238 """ 1239 is_nested_fn = _is_nested_or_composite if expand_composites else _is_nested 1240 assert_shallow_structure(shallow_tree, 1241 input_tree, 1242 check_types=check_types, 1243 expand_composites=expand_composites) 1244 # Discard paths returned by _yield_flat_up_to. 1245 return [ 1246 v for _, v in _yield_flat_up_to(shallow_tree, input_tree, is_nested_fn) 1247 ] 1248 1249 1250def flatten_with_tuple_paths_up_to(shallow_tree, 1251 input_tree, 1252 check_types=True, 1253 expand_composites=False): 1254 """Flattens `input_tree` up to `shallow_tree`. 1255 1256 Any further depth in structure in `input_tree` is retained as structures in 1257 the partially flattened output. 1258 1259 Returns a list of (path, value) pairs, where value a leaf node in the 1260 flattened tree, and path is the tuple path of that leaf in input_tree. 1261 1262 If `shallow_tree` and `input_tree` are not sequences, this returns a 1263 single-item list: `[((), input_tree)]`. 1264 1265 Use Case: 1266 1267 Sometimes we may wish to partially flatten a nested sequence, retaining some 1268 of the nested structure. We achieve this by specifying a shallow structure, 1269 `shallow_tree`, we wish to flatten up to. 1270 1271 The input, `input_tree`, can be thought of as having the same structure layout 1272 as `shallow_tree`, but with leaf nodes that are themselves tree structures. 
1273 1274 Examples: 1275 1276 ```python 1277 input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] 1278 shallow_tree = [[True, True], [False, True]] 1279 1280 flattened_input_tree = flatten_with_tuple_paths_up_to(shallow_tree, 1281 input_tree) 1282 flattened_shallow_tree = flatten_with_tuple_paths_up_to(shallow_tree, 1283 shallow_tree) 1284 1285 # Output is: 1286 # [((0, 0), [2, 2]), 1287 # ((0, 1), [3, 3]), 1288 # ((1, 0), [4, 9]), 1289 # ((1, 1), [5, 5])] 1290 # 1291 # [((0, 0), True), 1292 # ((0, 1), True), 1293 # ((1, 0), False), 1294 # ((1, 1), True)] 1295 ``` 1296 1297 ```python 1298 input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]] 1299 shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]] 1300 1301 input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree) 1302 input_tree_flattened = flatten(input_tree) 1303 1304 # Output is: 1305 # [((0, 0), ('a', 1)), 1306 # ((0, 1, 0), ('b', 2)), 1307 # ((0, 1, 1, 0), ('c', 3)), 1308 # ((0, 1, 1, 1), ('d', 4))] 1309 # ['a', 1, 'b', 2, 'c', 3, 'd', 4] 1310 ``` 1311 1312 Non-Sequence Edge Cases: 1313 1314 ```python 1315 flatten_with_tuple_paths_up_to(0, 0) # Output: [(), 0] 1316 1317 flatten_with_tuple_paths_up_to(0, [0, 1, 2]) # Output: [(), [0, 1, 2]] 1318 1319 flatten_with_tuple_paths_up_to([0, 1, 2], 0) # Output: TypeError 1320 1321 flatten_with_tuple_paths_up_to([0, 1, 2], [0, 1, 2]) 1322 # Output: [((0,) 0), ((1,), 1), ((2,), 2)] 1323 ``` 1324 1325 Args: 1326 shallow_tree: a possibly pruned structure of input_tree. 1327 input_tree: an atom or a nested structure. 1328 Note, numpy arrays are considered atoms. 1329 check_types: bool. If True, check that each node in shallow_tree has the 1330 same type as the corresponding node in input_tree. 1331 expand_composites: If true, then composite tensors such as 1332 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1333 component tensors. 
1334 1335 Returns: 1336 A Python list, the partially flattened version of `input_tree` according to 1337 the structure of `shallow_tree`. 1338 1339 Raises: 1340 TypeError: If `shallow_tree` is a nested structure but `input_tree` is not. 1341 TypeError: If the structure types of `shallow_tree` are different from 1342 `input_tree`. 1343 ValueError: If the structure lengths of `shallow_tree` are different from 1344 `input_tree`. 1345 """ 1346 is_nested_fn = _is_nested_or_composite if expand_composites else _is_nested 1347 assert_shallow_structure(shallow_tree, 1348 input_tree, 1349 check_types=check_types, 1350 expand_composites=expand_composites) 1351 return list(_yield_flat_up_to(shallow_tree, input_tree, is_nested_fn)) 1352 1353 1354@tf_export("__internal__.nest.map_structure_up_to", v1=[]) 1355def map_structure_up_to(shallow_tree, func, *inputs, **kwargs): 1356 """Applies a function or op to a number of partially flattened inputs. 1357 1358 The `inputs` are flattened up to `shallow_tree` before being mapped. 1359 1360 Use Case: 1361 1362 Sometimes we wish to apply a function to a partially flattened 1363 structure (for example when the function itself takes structure inputs). We 1364 achieve this by specifying a shallow structure, `shallow_tree` we wish to 1365 flatten up to. 1366 1367 The `inputs`, can be thought of as having the same structure layout as 1368 `shallow_tree`, but with leaf nodes that are themselves tree structures. 1369 1370 This function therefore will return something with the same base structure as 1371 `shallow_tree`. 
1372 1373 Examples: 1374 1375 ```python 1376 shallow_tree = [None, None] 1377 inp_val = [1, 2, 3] 1378 out = map_structure_up_to(shallow_tree, lambda x: 2 * x, inp_val) 1379 1380 # Output is: [2, 4] 1381 ``` 1382 1383 ```python 1384 ab_tuple = collections.namedtuple("ab_tuple", "a, b") 1385 op_tuple = collections.namedtuple("op_tuple", "add, mul") 1386 inp_val = ab_tuple(a=2, b=3) 1387 inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) 1388 out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul, 1389 inp_val, inp_ops) 1390 1391 # Output is: ab_tuple(a=6, b=15) 1392 ``` 1393 1394 ```python 1395 data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] 1396 name_list = ['evens', ['odds', 'primes']] 1397 out = map_structure_up_to( 1398 name_list, 1399 lambda name, sec: "first_{}_{}".format(len(sec), name), 1400 name_list, data_list) 1401 1402 # Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']] 1403 ``` 1404 1405 Args: 1406 shallow_tree: a shallow structure, common to all the inputs. 1407 func: callable which will be applied to each input individually. 1408 *inputs: structures that are compatible with shallow_tree. The function 1409 `func` is applied to corresponding structures due to partial flattening 1410 of each input, so the function must support arity of `len(inputs)`. 1411 **kwargs: kwargs to feed to func(). Special kwarg 1412 `check_types` is not passed to func, but instead determines whether the 1413 types of iterables within the structures have to be same (e.g. 1414 `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow 1415 this set this argument to `False`. 1416 1417 Raises: 1418 TypeError: If `shallow_tree` is a nested structure but `input_tree` is not. 1419 TypeError: If the structure types of `shallow_tree` are different from 1420 `input_tree`. 1421 ValueError: If the structure lengths of `shallow_tree` are different from 1422 `input_tree`. 
1423 1424 Returns: 1425 result of repeatedly applying `func`, with the same structure layout as 1426 `shallow_tree`. 1427 """ 1428 return map_structure_with_tuple_paths_up_to( 1429 shallow_tree, 1430 lambda _, *values: func(*values), # Discards the path arg. 1431 *inputs, 1432 **kwargs) 1433 1434 1435def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs): 1436 """Applies a function or op to a number of partially flattened inputs. 1437 1438 Like map_structure_up_to(), except that the 'func' argument takes a path 1439 tuple as its first argument, followed by the corresponding values from 1440 *inputs. 1441 1442 Example: 1443 1444 ```python 1445 lowercase = {'a': 'a', 'b': ('b0', 'b1')} 1446 uppercase = {'a': 'A', 'b': ('B0', 'B1')} 1447 1448 def print_path_and_values(path, *values): 1449 print("path: {}, values: {}".format(path, values)) 1450 1451 shallow_tree = {'a': None} 1452 map_structure_with_tuple_paths_up_to(shallow_tree, 1453 print_path_and_values, 1454 lowercase, 1455 uppercase) 1456 path: ('a',), values: ('a', 'A') 1457 path: ('b', 0), values: ('b0', 'B0') 1458 path: ('b', 1), values: ('b1', 'B1') 1459 1460 shallow_tree = {'b': None} 1461 map_structure_with_tuple_paths_up_to(shallow_tree, 1462 print_path_and_values, 1463 lowercase, 1464 uppercase, 1465 check_types=False) 1466 path: ('b', 1), values: (('bo', 'b1'), ('B0', 'B1')) 1467 1468 shallow_tree = {'a': None, 'b': {1: None}} 1469 map_structure_with_tuple_paths_up_to(shallow_tree, 1470 print_path_and_values, 1471 lowercase, 1472 uppercase, 1473 check_types=False) 1474 path: ('a',), values: ('a', 'A') 1475 path: ('b', 1), values: ('b1', B1') 1476 ``` 1477 1478 Args: 1479 shallow_tree: a shallow structure, common to all the inputs. 1480 func: callable that takes args (path, inputs_0_value, ... , inputs_N_value), 1481 where path is a tuple path to an atom in shallow_tree, and 1482 inputs_i_value is the corresponding value from inputs[i]. 
1483 *inputs: structures that are all structurally compatible with 1484 shallow_tree. 1485 **kwargs: kwargs to feed to func(). Special kwarg 1486 `check_types` is not passed to func, but instead determines whether the 1487 types of iterables within the structures have to be same (e.g. 1488 `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow 1489 this set this argument to `False`. 1490 1491 Raises: 1492 TypeError: If `shallow_tree` is a nested structure but one of `*inputs` is 1493 not. 1494 TypeError: If the structure types of `shallow_tree` are different from 1495 `input_tree`. 1496 ValueError: If the structure lengths of `shallow_tree` are different from 1497 `input_tree`. 1498 1499 Returns: 1500 Result of repeatedly applying `func`. Has the same structure layout as 1501 `shallow_tree`. 1502 """ 1503 if not inputs: 1504 raise ValueError("Cannot map over no sequences") 1505 1506 check_types = kwargs.pop("check_types", True) 1507 expand_composites = kwargs.pop("expand_composites", False) 1508 is_nested_fn = _is_nested_or_composite if expand_composites else _is_nested 1509 1510 for input_tree in inputs: 1511 assert_shallow_structure( 1512 shallow_tree, 1513 input_tree, 1514 check_types=check_types, 1515 expand_composites=expand_composites) 1516 1517 # Flatten each input separately, apply the function to corresponding items, 1518 # then repack based on the structure of the first input. 
1519 flat_value_gen = ( 1520 flatten_up_to( # pylint: disable=g-complex-comprehension 1521 shallow_tree, 1522 input_tree, 1523 check_types, 1524 expand_composites=expand_composites) for input_tree in inputs) 1525 flat_path_gen = ( 1526 path 1527 for path, _ in _yield_flat_up_to(shallow_tree, inputs[0], is_nested_fn)) 1528 results = [ 1529 func(*args, **kwargs) for args in zip(flat_path_gen, *flat_value_gen) 1530 ] 1531 return pack_sequence_as(structure=shallow_tree, flat_sequence=results, 1532 expand_composites=expand_composites) 1533 1534 1535@tf_export("__internal__.nest.get_traverse_shallow_structure", v1=[]) 1536def get_traverse_shallow_structure(traverse_fn, structure, 1537 expand_composites=False): 1538 """Generates a shallow structure from a `traverse_fn` and `structure`. 1539 1540 `traverse_fn` must accept any possible subtree of `structure` and return 1541 a depth=1 structure containing `True` or `False` values, describing which 1542 of the top-level subtrees may be traversed. It may also 1543 return scalar `True` or `False` "traversal is OK / not OK for all subtrees." 1544 1545 Examples are available in the unit tests (nest_test.py). 1546 1547 Args: 1548 traverse_fn: Function taking a substructure and returning either a scalar 1549 `bool` (whether to traverse that substructure or not) or a depth=1 1550 shallow structure of the same type, describing which parts of the 1551 substructure to traverse. 1552 structure: The structure to traverse. 1553 expand_composites: If true, then composite tensors such as 1554 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1555 component tensors. 1556 1557 Returns: 1558 A shallow structure containing python bools, which can be passed to 1559 `map_structure_up_to` and `flatten_up_to`. 1560 1561 Raises: 1562 TypeError: if `traverse_fn` returns a nested structure for an atom input. 
1563 or a structure with depth higher than 1 for a nested structure input, 1564 or if any leaf values in the returned structure or scalar are not type 1565 `bool`. 1566 """ 1567 is_nested_fn = _is_nested_or_composite if expand_composites else _is_nested 1568 to_traverse = traverse_fn(structure) 1569 if not is_nested_fn(structure): 1570 if not isinstance(to_traverse, bool): 1571 raise TypeError("traverse_fn returned structure: %s for non-structure: %s" 1572 % (to_traverse, structure)) 1573 return to_traverse 1574 level_traverse = [] 1575 if isinstance(to_traverse, bool): 1576 if not to_traverse: 1577 # Do not traverse this substructure at all. Exit early. 1578 return False 1579 else: 1580 # Traverse the entire substructure. 1581 for branch in _yield_value(structure): 1582 level_traverse.append( 1583 get_traverse_shallow_structure(traverse_fn, branch, 1584 expand_composites=expand_composites)) 1585 elif not is_nested_fn(to_traverse): 1586 raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s" 1587 % (to_traverse, structure)) 1588 else: 1589 # Traverse some subset of this substructure. 1590 assert_shallow_structure(to_traverse, structure, 1591 expand_composites=expand_composites) 1592 for t, branch in zip(_yield_value(to_traverse), 1593 _yield_value(structure)): 1594 if not isinstance(t, bool): 1595 raise TypeError( 1596 "traverse_fn didn't return a depth=1 structure of bools. saw: %s " 1597 " for structure: %s" % (to_traverse, structure)) 1598 if t: 1599 level_traverse.append( 1600 get_traverse_shallow_structure(traverse_fn, branch)) 1601 else: 1602 level_traverse.append(False) 1603 return _sequence_like(structure, level_traverse) 1604 1605 1606@tf_export("__internal__.nest.yield_flat_paths", v1=[]) 1607def yield_flat_paths(nest, expand_composites=False): 1608 """Yields paths for some nested structure. 1609 1610 Refer to [tf.nest](https://www.tensorflow.org/api_docs/python/tf/nest) 1611 for the definition of a structure. 
1612 1613 Paths are lists of objects which can be str-converted, which may include 1614 integers or other types which are used as indices in a dict. 1615 1616 The flat list will be in the corresponding order as if you called 1617 `nest.flatten` on the structure. This is handy for naming Tensors such 1618 the TF scope structure matches the tuple structure. 1619 1620 E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))` 1621 1622 ```shell 1623 nest.flatten(value) 1624 [3, 23, 42] 1625 list(nest.yield_flat_paths(value)) 1626 [('a',), ('b', 'c'), ('b', 'd')] 1627 ``` 1628 1629 ```shell 1630 list(nest.yield_flat_paths({'a': [3]})) 1631 [('a', 0)] 1632 list(nest.yield_flat_paths({'a': 3})) 1633 [('a',)] 1634 ``` 1635 1636 Args: 1637 nest: the value to produce a flattened paths list for. 1638 expand_composites: If true, then composite tensors such as 1639 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1640 component tensors. 1641 1642 Yields: 1643 Tuples containing index or key values which form the path to a specific 1644 leaf value in the nested structure. 1645 """ 1646 is_nested_fn = _is_nested_or_composite if expand_composites else _is_nested 1647 for k, _ in _yield_flat_up_to(nest, nest, is_nested_fn): 1648 yield k 1649 1650 1651def flatten_with_joined_string_paths(structure, separator="/", 1652 expand_composites=False): 1653 """Returns a list of (string path, atom) tuples. 1654 1655 The order of tuples produced matches that of `nest.flatten`. This allows you 1656 to flatten a nested structure while keeping information about where in the 1657 structure each atom was located. See `nest.yield_flat_paths` 1658 for more information. 1659 1660 Args: 1661 structure: the nested structure to flatten. 1662 separator: string to separate levels of hierarchy in the results, defaults 1663 to '/'. 
1664 expand_composites: If true, then composite tensors such as 1665 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1666 component tensors. 1667 1668 Returns: 1669 A list of (string, atom) tuples. 1670 """ 1671 flat_paths = yield_flat_paths(structure, expand_composites=expand_composites) 1672 def stringify_and_join(path_elements): 1673 return separator.join(str(path_element) for path_element in path_elements) 1674 1675 flat_string_paths = (stringify_and_join(path) for path in flat_paths) 1676 return list(zip(flat_string_paths, 1677 flatten(structure, expand_composites=expand_composites))) 1678 1679 1680def flatten_with_tuple_paths(structure, expand_composites=False): 1681 """Returns a list of `(tuple_path, atom)` tuples. 1682 1683 The order of pairs produced matches that of `nest.flatten`. This allows you 1684 to flatten a nested structure while keeping information about where in the 1685 structure each atom was located. See `nest.yield_flat_paths` 1686 for more information about tuple paths. 1687 1688 Args: 1689 structure: the nested structure to flatten. 1690 expand_composites: If true, then composite tensors such as 1691 `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their 1692 component tensors. 1693 1694 Returns: 1695 A list of `(tuple_path, atom)` tuples. Each `tuple_path` is a tuple 1696 of indices and/or dictionary keys that uniquely specify the path to 1697 `atom` within `structure`. 1698 """ 1699 return list(zip(yield_flat_paths(structure, 1700 expand_composites=expand_composites), 1701 flatten(structure, expand_composites=expand_composites))) 1702 1703 1704@tf_export("__internal__.nest.list_to_tuple", v1=[]) 1705def list_to_tuple(structure): 1706 """Replace all lists with tuples. 1707 1708 The fork of nest that tf.data uses treats lists as atoms, while 1709 tf.nest treats them as structures to recurse into. 
Keras has chosen to adopt 1710 the latter convention, and must therefore deeply replace all lists with tuples 1711 before passing structures to Dataset.from_generator. 1712 1713 Args: 1714 structure: A nested structure to be remapped. 1715 1716 Returns: 1717 structure mapped to replace all lists with tuples. 1718 """ 1719 def sequence_fn(instance, args): 1720 if isinstance(instance, list): 1721 return tuple(args) 1722 return _sequence_like(instance, args) 1723 1724 return _pack_sequence_as(structure, flatten(structure), False, 1725 sequence_fn=sequence_fn) 1726 1727 1728_pywrap_utils.RegisterType("Mapping", _collections_abc.Mapping) 1729_pywrap_utils.RegisterType("MutableMapping", _collections_abc.MutableMapping) 1730_pywrap_utils.RegisterType("Sequence", _collections_abc.Sequence) 1731_pywrap_utils.RegisterType("MappingView", _collections_abc.MappingView) 1732_pywrap_utils.RegisterType("ObjectProxy", _wrapt.ObjectProxy) 1733