xref: /aosp_15_r20/external/tensorflow/tensorflow/python/data/util/traverse.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helpers to traverse the Dataset dependency structure."""
16import queue
17
18from tensorflow.python.framework import dtypes
19
20
# Op types that are allowlisted by name, independent of the dtype of the
# tensors they produce.
OP_TYPES_ALLOWLIST = ["DummyIterationCounter"]
# We allowlist all ops that produce variant tensors as output. This is a bit
# of overkill but the other dataset _inputs() traversal strategies can't
# cover the case of function inputs that capture dataset variants.
TENSOR_TYPES_ALLOWLIST = [dtypes.variant]
26
27
def _traverse(dataset, op_filter_fn):
  """Traverse a dataset graph, returning nodes matching `op_filter_fn`.

  Performs a breadth-first walk that starts at the op producing the dataset's
  variant tensor and follows the data inputs (`op.inputs`) of every op that is
  reached. Each op is visited, and therefore reported, at most once, even when
  it is reachable through several paths.

  Args:
    dataset: Object whose `_variant_tensor.op` is the root of the traversal.
    op_filter_fn: Callable that takes an op and returns a truthy value for the
      ops that should be included in the result.

  Returns:
    A list of the ops accepted by `op_filter_fn`, in breadth-first order.
  """
  root_op = dataset._variant_tensor.op  # pylint: disable=protected-access
  result = []
  bfs_q = queue.Queue()
  bfs_q.put(root_op)
  # Ops are marked as seen when they are enqueued rather than when they are
  # dequeued. Otherwise an op shared by several consumers would be queued once
  # per consumer, reported repeatedly, and have its inputs expanded again.
  seen = {root_op}
  while not bfs_q.empty():
    op = bfs_q.get()
    if op_filter_fn(op):
      result.append(op)
    for tensor in op.inputs:
      input_op = tensor.op
      if input_op not in seen:
        seen.add(input_op)
        bfs_q.put(input_op)
  return result
44
45
def obtain_capture_by_value_ops(dataset):
  """Collects the allowlisted ops that were used to build `dataset`.

  Allowlisted ops are stateful ops which are known to be safe to capture by
  value. An op is accepted when the dtype of its first output appears in
  `TENSOR_TYPES_ALLOWLIST` or when its op type appears in
  `OP_TYPES_ALLOWLIST`.

  Args:
    dataset: Dataset whose graph is searched for allowlisted ops.

  Returns:
    A list of the allowlisted ops found while traversing the graph of
    `dataset`.
  """

  def _is_allowlisted(candidate):
    # The dtype of the first output is inspected before the op type.
    first_output_dtype = candidate.outputs[0].dtype
    if first_output_dtype in TENSOR_TYPES_ALLOWLIST:
      return True
    return candidate.type in OP_TYPES_ALLOWLIST

  return _traverse(dataset, _is_allowlisted)
65
66
def obtain_all_variant_tensor_ops(dataset):
  """Collects every dataset op that was used to build `dataset`.

  A dataset is the result of a chain of transformations, and each
  transformation contributes zero or more Dataset ops that each produce a
  dataset variant tensor. This function returns all of those ops.

  Args:
    dataset: Dataset whose variant-producing ops are wanted.

  Returns:
    A list of the ops reached from `dataset` whose first output has dtype
    `dtypes.variant`.
  """

  def _produces_variant(candidate):
    first_output = candidate.outputs[0]
    return first_output.dtype == dtypes.variant

  return _traverse(dataset, _produces_variant)
82