xref: /aosp_15_r20/external/tensorflow/tensorflow/python/data/experimental/ops/testing.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Experimental API for testing of tf.data."""
16from google.protobuf import text_format
17from tensorflow.core.framework import attr_value_pb2
18from tensorflow.python.data.ops import dataset_ops
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import ops
21from tensorflow.python.ops import gen_experimental_dataset_ops
22
23
def assert_next(transformations):
  """Returns a transformation asserting which transformations happen next.

  Refer to transformations by their base name, without any version suffix:
  e.g. "Batch" rather than "BatchV2". The base name "Batch" matches "Batch",
  "BatchV1", "BatchV2", and so on.

  Args:
    transformations: A `tf.string` vector `tf.Tensor` identifying the
      transformations that are expected to happen next.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _transformation_fn(dataset):
    """Applies the assertion to `dataset`, yielding a new `Dataset`."""
    return _AssertNextDataset(dataset, transformations)

  return _transformation_fn
45
46
def assert_prev(transformations):
  r"""Asserts which transformations, with which attributes, happened previously.

    Each transformation is represented as a tuple in the input.

    The first element is the base op name of the transformation, without any
    version suffix. For example, use "BatchDataset" rather than
    "BatchDatasetV2"; "BatchDataset" matches "BatchDataset",
    "BatchDatasetV1", "BatchDatasetV2", and so on.

    The second element is a dict mapping attribute names to values. Attribute
    values must be of type bool, int, or string.

    Example usage:

    >>> dataset_ops.Dataset.from_tensors(0) \
    ... .map(lambda x: x) \
    ... .batch(1, deterministic=True, num_parallel_calls=8) \
    ... .assert_prev([("ParallelBatchDataset", {"deterministic": True}), \
    ...               ("MapDataset", {})])

  Args:
    transformations: A list of tuples identifying the (required) transformation
      name, with (optional) attribute name-value pairs, that are expected to
      have happened previously.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _transformation_fn(dataset):
    """Applies the assertion to `dataset`, yielding a new `Dataset`."""
    return _AssertPrevDataset(dataset, transformations)

  return _transformation_fn
83
84
def non_serializable():
  """Returns an identity transformation that cannot be serialized.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _transformation_fn(dataset):
    """Wraps `dataset` in the non-serializable identity `Dataset`."""
    return _NonSerializableDataset(dataset)

  return _transformation_fn
98
99
def sleep(sleep_microseconds):
  """Sleeps for `sleep_microseconds` before producing each input element.

  Args:
    sleep_microseconds: The number of microseconds to sleep before producing an
      input element.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _transformation_fn(dataset):
    """Wraps `dataset` so each element is delayed by the configured time."""
    return _SleepDataset(dataset, sleep_microseconds)

  return _transformation_fn
116
117
class _AssertNextDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that asserts which transformations happen next.

  Wraps `input_dataset` in the `ExperimentalAssertNextDataset` op, which
  fails at runtime if the downstream transformations do not match
  `transformations`.
  """

  def __init__(self, input_dataset, transformations):
    """See `assert_next()` for details.

    Args:
      input_dataset: The `Dataset` to wrap.
      transformations: A `tf.string` vector (or convertible value) naming the
        transformations expected to follow. Must not be None.

    Raises:
      ValueError: If `transformations` is None.
    """
    self._input_dataset = input_dataset
    if transformations is None:
      # Bug fix: the previous message said "should not be empty" although the
      # guard only rejects None; the message now states the actual condition.
      raise ValueError(
          "Invalid `transformations`. `transformations` should not be None.")

    self._transformations = ops.convert_to_tensor(
        transformations, dtype=dtypes.string, name="transformations")
    variant_tensor = (
        gen_experimental_dataset_ops.experimental_assert_next_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._transformations,
            **self._flat_structure))
    super(_AssertNextDataset, self).__init__(input_dataset, variant_tensor)
136
137
class _AssertPrevDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that asserts which transformations happened previously."""

  def __init__(self, input_dataset, transformations):
    """See `assert_prev()` for details."""
    self._input_dataset = input_dataset
    if transformations is None:
      raise ValueError("`transformations` cannot be empty")

    def _to_text_proto(op_name, attributes):
      """Serializes one (op name, attributes) pair as a NameAttrList proto."""
      proto = attr_value_pb2.NameAttrList(name=op_name)
      # A literal `{}` with no entries would be a dict, but callers may also
      # pass None or a set; treat both as "no attributes".
      if attributes is None or isinstance(attributes, set):
        attributes = {}
      for name, value in attributes.items():
        # NOTE: bool must be tested before int, since bool subclasses int.
        if isinstance(value, bool):
          proto.attr[name].b = value
        elif isinstance(value, int):
          proto.attr[name].i = value
        elif isinstance(value, str):
          proto.attr[name].s = value.encode()
        else:
          raise ValueError(
              f"attribute value type ({type(value)}) must be bool, int, or str")
      return text_format.MessageToString(proto)

    serialized = [_to_text_proto(*pair) for pair in transformations]
    self._transformations = ops.convert_to_tensor(
        serialized, dtype=dtypes.string, name="transformations")
    variant_tensor = (
        gen_experimental_dataset_ops.assert_prev_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._transformations,
            **self._flat_structure))
    super(_AssertPrevDataset, self).__init__(input_dataset, variant_tensor)
173
174
class _NonSerializableDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """An identity `Dataset` backed by an op with no serialization support."""

  def __init__(self, input_dataset):
    """See `non_serializable()` for details."""
    self._input_dataset = input_dataset
    variant_tensor = (
        gen_experimental_dataset_ops.experimental_non_serializable_dataset(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            **self._flat_structure))
    super(_NonSerializableDataset, self).__init__(input_dataset, variant_tensor)
186
187
class _SleepDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """A `Dataset` that sleeps before producing each upstream element."""

  def __init__(self, input_dataset, sleep_microseconds):
    # See `sleep()` for the public contract.
    self._input_dataset = input_dataset
    self._sleep_microseconds = sleep_microseconds
    variant_tensor = gen_experimental_dataset_ops.sleep_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._sleep_microseconds,
        **self._flat_structure)
    super(_SleepDataset, self).__init__(input_dataset, variant_tensor)
199