# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utilities for tf.data benchmarking functionality."""
import time

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.util import nest
from tensorflow.python.eager import context
from tensorflow.python.platform import test


class DatasetBenchmarkBase(test.Benchmark):
  """Base class for dataset benchmarks."""

  def _run_eager_benchmark(self, iterable, iters, warmup):
    """Benchmarks the iterable in eager mode.

    Runs the iterable `iters` times. In each iteration, the benchmark measures
    the time it takes to execute the iterable once.

    Args:
      iterable: The tf op or tf.data Dataset to benchmark.
      iters: Number of times to repeat the timing.
      warmup: If true, performs an untimed run before each timed run to warm
        up caches.

    Returns:
      A float, representing the median (across `iters` runs) of the time it
      takes to execute the iterable once.

    Raises:
      RuntimeError: When executed in graph mode.
    """

    deltas = []
    if not context.executing_eagerly():
      raise RuntimeError(
          "Eager mode benchmarking is not supported in graph mode.")

    for _ in range(iters):
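      # Optional untimed warm-up: build a fresh iterator and pull one element
      # so that one-time setup costs do not skew the timed run below.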
      if warmup:
        iterator = iter(iterable)
        next(iterator)

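      # Timed run: measure how long it takes to create an iterator and fetch
      # a single element from it.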
      iterator = iter(iterable)
      start = time.time()
      next(iterator)
      end = time.time()
      deltas.append(end - start)
    return np.median(deltas)

  def _run_graph_benchmark(self,
                           iterable,
                           iters,
                           warmup,
                           session_config,
                           initializer=None):
    """Benchmarks the iterable in graph mode.

    Runs the iterable `iters` times. In each iteration, the benchmark measures
    the time it takes to execute the iterable once.

    Args:
      iterable: The tf op or tf.data Dataset to benchmark.
      iters: Number of times to repeat the timing.
      warmup: If true, warms up the session caches with an untimed run before
        each timed run.
      session_config: A ConfigProto protocol buffer with configuration options
        for the session. Applicable only for benchmarking in graph mode.
      initializer: The initializer op required to initialize the iterable.

    Returns:
      A float, representing the median (across `iters` runs) of the time it
      takes to execute the iterable once.

    Raises:
      RuntimeError: When executed in eager mode.
    """

    deltas = []
    if context.executing_eagerly():
      raise RuntimeError(
          "Graph mode benchmarking is not supported in eager mode.")

    for _ in range(iters):
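      # Create a fresh session for each timing iteration.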
      with session.Session(config=session_config) as sess:
        if warmup:
          # Run once to warm up the session caches.
          if initializer:
            sess.run(initializer)
          sess.run(iterable)

        if initializer:
          sess.run(initializer)
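        # Timed run: measure a single execution of the iterable.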
        start = time.time()
        sess.run(iterable)
        end = time.time()
      deltas.append(end - start)
    return np.median(deltas)

  def run_op_benchmark(self, op, iters=1, warmup=True, session_config=None):
    """Benchmarks the op.

    Runs the op `iters` times. In each iteration, the benchmark measures
    the time it takes to execute the op once.

    Args:
      op: The tf op to benchmark.
      iters: Number of times to repeat the timing.
      warmup: If true, performs an untimed run before each timed run to warm
        up caches.
      session_config: A ConfigProto protocol buffer with configuration options
        for the session. Applicable only for benchmarking in graph mode.

    Returns:
      A float, representing the per-execution wall time of the op in seconds.
      This is the median (across `iters` runs) of the time it takes to execute
      the op once.
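
    Example:
      A minimal, illustrative sketch from inside a benchmark subclass
      (assumes `math_ops` is imported from `tensorflow.python.ops`):

        op = math_ops.range(1000)
        wall_time = self.run_op_benchmark(op=op, iters=10, warmup=True)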
132    """
133
134    if context.executing_eagerly():
135      return self._run_eager_benchmark(iterable=op, iters=iters, warmup=warmup)
136
137    return self._run_graph_benchmark(
138        iterable=op, iters=iters, warmup=warmup, session_config=session_config)
139
140  def run_benchmark(self,
141                    dataset,
142                    num_elements,
143                    iters=1,
144                    warmup=True,
145                    apply_default_optimizations=False,
146                    session_config=None):
147    """Benchmarks the dataset.
148
149    Runs the dataset `iters` times. In each iteration, the benchmark measures
150    the time it takes to go through `num_elements` elements of the dataset.
151
152    Args:
153      dataset: Dataset to benchmark.
154      num_elements: Number of dataset elements to iterate through each benchmark
155        iteration.
156      iters: Number of times to repeat the timing.
157      warmup: If true, warms up the session caches by running an untimed run.
158      apply_default_optimizations: Determines whether default optimizations
159        should be applied.
160      session_config: A ConfigProto protocol buffer with configuration options
161        for the session. Applicable only for benchmarking in graph mode.
162
163    Returns:
164      A float, representing the per-element wall time of the dataset in seconds.
165      This is the median time (with respect to `iters`) it takes for the dataset
166      to go through `num_elements` elements, divided by `num_elements.`
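
    Example:
      A minimal, illustrative sketch from inside a benchmark subclass
      (`dataset_ops` is imported at the top of this module):

        dataset = dataset_ops.Dataset.range(1000)
        per_element_time = self.run_benchmark(
            dataset=dataset, num_elements=1000, iters=5)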
167    """
168
169    # The options that have been applied to the dataset are preserved so that
170    # they are not overwritten while benchmarking.
171    options = options_lib.Options()
172    options.experimental_optimization.apply_default_optimizations = (
173        apply_default_optimizations)
174    dataset = dataset.with_options(options)
175
176    # NOTE: We use `dataset.skip()` to perform the iterations in C++, avoiding
177    # the overhead of having to execute a TensorFlow op for each step of the
178    # input pipeline. Note that this relies on the underlying implementation of
179    # `skip` to execute upstream computation. If it is optimized in the future,
180    # we will have to change this code.
181    dataset = dataset.skip(num_elements - 1)
182
183    if context.executing_eagerly():
184      median_duration = self._run_eager_benchmark(
185          iterable=dataset, iters=iters, warmup=warmup)
186      return median_duration / float(num_elements)
187
188    iterator = dataset_ops.make_initializable_iterator(dataset)
189    next_element = iterator.get_next()
190    op = nest.flatten(next_element)[0].op
191    median_duration = self._run_graph_benchmark(
192        iterable=op,
193        iters=iters,
194        warmup=warmup,
195        session_config=session_config,
196        initializer=iterator.initializer)
197    return median_duration / float(num_elements)
198
199  def run_and_report_benchmark(self,
200                               dataset,
201                               num_elements,
202                               name,
203                               iters=5,
204                               extras=None,
205                               warmup=True,
206                               apply_default_optimizations=False,
207                               session_config=None):
208    """Benchmarks the dataset and reports the stats.
209
210    Runs the dataset `iters` times. In each iteration, the benchmark measures
211    the time it takes to go through `num_elements` elements of the dataset.
212    This is followed by logging/printing the benchmark stats.
213
214    Args:
215      dataset: Dataset to benchmark.
216      num_elements: Number of dataset elements to iterate through each benchmark
217        iteration.
218      name: Name of the benchmark.
219      iters: Number of times to repeat the timing.
220      extras: A dict which maps string keys to additional benchmark info.
221      warmup: If true, warms up the session caches by running an untimed run.
222      apply_default_optimizations: Determines whether default optimizations
223        should be applied.
224      session_config: A ConfigProto protocol buffer with configuration options
225        for the session. Applicable only for benchmarking in graph mode.
226
227    Returns:
228      A float, representing the per-element wall time of the dataset in seconds.
229      This is the median time (with respect to `iters`) it takes for the dataset
230      to go through `num_elements` elements, divided by `num_elements.`
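
    Example:
      A minimal, illustrative benchmark subclass (the dataset, benchmark name,
      and method name are placeholders):

        class MyDatasetBenchmark(DatasetBenchmarkBase):

          def benchmark_range(self):
            dataset = dataset_ops.Dataset.range(1000)
            self.run_and_report_benchmark(
                dataset=dataset,
                num_elements=1000,
                name="range",
                iters=5)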
231    """
232    wall_time = self.run_benchmark(
233        dataset=dataset,
234        num_elements=num_elements,
235        iters=iters,
236        warmup=warmup,
237        apply_default_optimizations=apply_default_optimizations,
238        session_config=session_config)
239    if extras is None:
240      extras = {}
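    # Tag the benchmark name and extras with the execution mode so that eager
    # and graph results are reported separately.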
    if context.executing_eagerly():
      name = "{}.eager".format(name)
      extras["implementation"] = "eager"
    else:
      name = "{}.graph".format(name)
      extras["implementation"] = "graph"
    extras["num_elements"] = num_elements
    self.report_benchmark(
        wall_time=wall_time, iters=iters, name=name, extras=extras)
    return wall_time