1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Experimental API for testing of tf.data.""" 16from google.protobuf import text_format 17from tensorflow.core.framework import attr_value_pb2 18from tensorflow.python.data.ops import dataset_ops 19from tensorflow.python.framework import dtypes 20from tensorflow.python.framework import ops 21from tensorflow.python.ops import gen_experimental_dataset_ops 22 23 24def assert_next(transformations): 25 """A transformation that asserts which transformations happen next. 26 27 Transformations should be referred to by their base name, not including 28 version suffix. For example, use "Batch" instead of "BatchV2". "Batch" will 29 match any of "Batch", "BatchV1", "BatchV2", etc. 30 31 Args: 32 transformations: A `tf.string` vector `tf.Tensor` identifying the 33 transformations that are expected to happen next. 34 35 Returns: 36 A `Dataset` transformation function, which can be passed to 37 `tf.data.Dataset.apply`. 38 """ 39 40 def _apply_fn(dataset): 41 """Function from `Dataset` to `Dataset` that applies the transformation.""" 42 return _AssertNextDataset(dataset, transformations) 43 44 return _apply_fn 45 46 47def assert_prev(transformations): 48 r"""Asserts which transformations, with which attributes, happened previously. 49 50 Each transformation is repesented as a tuple in the input. 51 52 The first element is the base op name of the transformation, not including 53 version suffix. For example, use "BatchDataset" instead of 54 "BatchDatasetV2". "BatchDataset" will match any of "BatchDataset", 55 "BatchDatasetV1", "BatchDatasetV2", etc. 56 57 The second element is a dict of attribute name-value pairs. Attributes 58 values must be of type bool, int, or string. 59 60 Example usage: 61 62 >>> dataset_ops.Dataset.from_tensors(0) \ 63 ... .map(lambda x: x) \ 64 ... .batch(1, deterministic=True, num_parallel_calls=8) \ 65 ... .assert_prev([("ParallelBatchDataset", {"deterministic": True}), \ 66 ... ("MapDataset", {})]) 67 68 Args: 69 transformations: A list of tuples identifying the (required) transformation 70 name, with (optional) attribute name-value pairs, that are expected to 71 have happened previously. 72 73 Returns: 74 A `Dataset` transformation function, which can be passed to 75 `tf.data.Dataset.apply`. 76 """ 77 78 def _apply_fn(dataset): 79 """Function from `Dataset` to `Dataset` that applies the transformation.""" 80 return _AssertPrevDataset(dataset, transformations) 81 82 return _apply_fn 83 84 85def non_serializable(): 86 """A non-serializable identity transformation. 87 88 Returns: 89 A `Dataset` transformation function, which can be passed to 90 `tf.data.Dataset.apply`. 91 """ 92 93 def _apply_fn(dataset): 94 """Function from `Dataset` to `Dataset` that applies the transformation.""" 95 return _NonSerializableDataset(dataset) 96 97 return _apply_fn 98 99 100def sleep(sleep_microseconds): 101 """Sleeps for `sleep_microseconds` before producing each input element. 102 103 Args: 104 sleep_microseconds: The number of microseconds to sleep before producing an 105 input element. 106 107 Returns: 108 A `Dataset` transformation function, which can be passed to 109 `tf.data.Dataset.apply`. 110 """ 111 112 def _apply_fn(dataset): 113 return _SleepDataset(dataset, sleep_microseconds) 114 115 return _apply_fn 116 117 118class _AssertNextDataset(dataset_ops.UnaryUnchangedStructureDataset): 119 """A `Dataset` that asserts which transformations happen next.""" 120 121 def __init__(self, input_dataset, transformations): 122 """See `assert_next()` for details.""" 123 self._input_dataset = input_dataset 124 if transformations is None: 125 raise ValueError( 126 "Invalid `transformations`. `transformations` should not be empty.") 127 128 self._transformations = ops.convert_to_tensor( 129 transformations, dtype=dtypes.string, name="transformations") 130 variant_tensor = ( 131 gen_experimental_dataset_ops.experimental_assert_next_dataset( 132 self._input_dataset._variant_tensor, # pylint: disable=protected-access 133 self._transformations, 134 **self._flat_structure)) 135 super(_AssertNextDataset, self).__init__(input_dataset, variant_tensor) 136 137 138class _AssertPrevDataset(dataset_ops.UnaryUnchangedStructureDataset): 139 """A `Dataset` that asserts which transformations happened previously.""" 140 141 def __init__(self, input_dataset, transformations): 142 """See `assert_prev()` for details.""" 143 self._input_dataset = input_dataset 144 if transformations is None: 145 raise ValueError("`transformations` cannot be empty") 146 147 def serialize_transformation(op_name, attributes): 148 proto = attr_value_pb2.NameAttrList(name=op_name) 149 if attributes is None or isinstance(attributes, set): 150 attributes = dict() 151 for (name, value) in attributes.items(): 152 if isinstance(value, bool): 153 proto.attr[name].b = value 154 elif isinstance(value, int): 155 proto.attr[name].i = value 156 elif isinstance(value, str): 157 proto.attr[name].s = value.encode() 158 else: 159 raise ValueError( 160 f"attribute value type ({type(value)}) must be bool, int, or str") 161 return text_format.MessageToString(proto) 162 163 self._transformations = ops.convert_to_tensor( 164 [serialize_transformation(*x) for x in transformations], 165 dtype=dtypes.string, 166 name="transformations") 167 variant_tensor = ( 168 gen_experimental_dataset_ops.assert_prev_dataset( 169 self._input_dataset._variant_tensor, # pylint: disable=protected-access 170 self._transformations, 171 **self._flat_structure)) 172 super(_AssertPrevDataset, self).__init__(input_dataset, variant_tensor) 173 174 175class _NonSerializableDataset(dataset_ops.UnaryUnchangedStructureDataset): 176 """A `Dataset` that performs non-serializable identity transformation.""" 177 178 def __init__(self, input_dataset): 179 """See `non_serializable()` for details.""" 180 self._input_dataset = input_dataset 181 variant_tensor = ( 182 gen_experimental_dataset_ops.experimental_non_serializable_dataset( 183 self._input_dataset._variant_tensor, # pylint: disable=protected-access 184 **self._flat_structure)) 185 super(_NonSerializableDataset, self).__init__(input_dataset, variant_tensor) 186 187 188class _SleepDataset(dataset_ops.UnaryUnchangedStructureDataset): 189 """A `Dataset` that sleeps before producing each upstream element.""" 190 191 def __init__(self, input_dataset, sleep_microseconds): 192 self._input_dataset = input_dataset 193 self._sleep_microseconds = sleep_microseconds 194 variant_tensor = gen_experimental_dataset_ops.sleep_dataset( 195 self._input_dataset._variant_tensor, # pylint: disable=protected-access 196 self._sleep_microseconds, 197 **self._flat_structure) 198 super(_SleepDataset, self).__init__(input_dataset, variant_tensor) 199