# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to traverse the Dataset dependency structure."""
import queue

from tensorflow.python.framework import dtypes


OP_TYPES_ALLOWLIST = ["DummyIterationCounter"]
# We allowlist all ops that produce variant tensors as output. This is a bit
# of overkill but the other dataset _inputs() traversal strategies can't
# cover the case of function inputs that capture dataset variants.
TENSOR_TYPES_ALLOWLIST = [dtypes.variant]


def _traverse(dataset, op_filter_fn):
  """Traverse a dataset graph, returning nodes matching `op_filter_fn`.

  Starting from the op that produces `dataset`'s variant tensor, walks the
  graph breadth-first by following each op's inputs back to the ops that
  produce them. Every reachable op is visited exactly once.

  Args:
    dataset: Object exposing `_variant_tensor.op`, the op the walk starts
      from.
    op_filter_fn: Callable taking an op and returning a truthy value if the
      op should be included in the result.

  Returns:
    A list of the reachable ops for which `op_filter_fn` returned a truthy
    value, in breadth-first order and without duplicates.
  """
  root_op = dataset._variant_tensor.op  # pylint: disable=protected-access
  result = []
  bfs_q = queue.Queue()
  bfs_q.put(root_op)
  # Ops are marked as seen when they are enqueued rather than when they are
  # dequeued. Otherwise an op feeding several consumers would be enqueued once
  # per consumer before its first visit and would be reported more than once.
  visited = {root_op}
  while not bfs_q.empty():
    op = bfs_q.get()
    if op_filter_fn(op):
      result.append(op)
    for i in op.inputs:
      input_op = i.op
      if input_op not in visited:
        visited.add(input_op)
        bfs_q.put(input_op)
  return result


def obtain_capture_by_value_ops(dataset):
  """Given an input dataset, finds all allowlisted ops used for construction.

  Allowlisted ops are stateful ops which are known to be safe to capture by
  value.

  Args:
    dataset: Dataset to find allowlisted stateful ops for.

  Returns:
    A list of the ops used to construct this dataset whose first output has a
    dtype in `TENSOR_TYPES_ALLOWLIST` or whose op type is in
    `OP_TYPES_ALLOWLIST`.
  """

  def capture_by_value(op):
    return (op.outputs[0].dtype in TENSOR_TYPES_ALLOWLIST or
            op.type in OP_TYPES_ALLOWLIST)

  return _traverse(dataset, capture_by_value)


def obtain_all_variant_tensor_ops(dataset):
  """Given an input dataset, finds all dataset ops used for construction.

  A series of transformations would have created this dataset with each
  transformation including zero or more Dataset ops, each producing a dataset
  variant tensor. This method outputs all of them.

  Args:
    dataset: Dataset to find variant tensors for.

  Returns:
    A list of variant_tensor producing dataset ops used to construct this
    dataset.
  """
  return _traverse(dataset, lambda op: op.outputs[0].dtype == dtypes.variant)