xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/parallel_for/test_util.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test utility."""
16
17import numpy as np
18
19from tensorflow.python.ops import variables
20from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
21from tensorflow.python.platform import test
22from tensorflow.python.util import nest
23
24
class PForTestCase(test.TestCase):
  """Base class for pfor test cases.

  Provides helpers that evaluate two structures of targets together and
  compare them element-wise, plus `_test_loop_fn`, which checks that running
  `loop_fn` through `pfor` gives the same values as running it through
  `for_loop`.
  """

  def _run_targets(self, targets1, targets2=None, run_init=True):
    """Evaluates `targets1` and (optionally) `targets2` in one call.

    Args:
      targets1: A (possibly nested) structure of tensors to evaluate.
      targets2: An optional (possibly nested) structure of tensors. When
        non-empty, it must flatten to the same number of elements as
        `targets1`.
      run_init: If True, evaluates `variables.global_variables_initializer()`
        before evaluating the targets.

    Returns:
      The result of `self.evaluate` on the flattened `targets1` followed by
      the flattened `targets2`.

    Raises:
      AssertionError: If `targets2` is non-empty and its flattened length
        differs from that of `targets1`.
    """
    targets1 = nest.flatten(targets1)
    targets2 = [] if targets2 is None else nest.flatten(targets2)
    if targets2:
      # A test assertion (not a bare `assert`) so the check survives
      # `python -O`; it still raises AssertionError on mismatch.
      self.assertEqual(
          len(targets1), len(targets2),
          "targets1 and targets2 must flatten to the same number of tensors.")
    if run_init:
      init = variables.global_variables_initializer()
      self.evaluate(init)
    return self.evaluate(targets1 + targets2)

  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
    """Evaluates two target structures and asserts they match element-wise.

    Values whose dtype is object, bytes or unicode are compared exactly with
    `assertAllEqual`; everything else is compared with `assertAllClose`.

    Args:
      targets1: A (possibly nested) structure of tensors.
      targets2: A (possibly nested) structure of tensors, paired element-wise
        with `targets1` after flattening.
      rtol: Relative tolerance passed to `assertAllClose`.
      atol: Absolute tolerance passed to `assertAllClose`.
    """
    outputs = self._run_targets(targets1, targets2)
    # Flatten again so evaluated SparseTensorValues are split into their
    # components and compared individually.
    outputs = nest.flatten(outputs)
    # First half holds the targets1 values, second half the targets2 values.
    n = len(outputs) // 2
    for out1, out2 in zip(outputs[:n], outputs[n:]):
      # np.asarray so values without a `.dtype` (e.g. a bytes scalar returned
      # for a scalar string tensor) do not raise AttributeError.
      if np.asarray(out2).dtype.kind in ("O", "S", "U"):
        # String/object values cannot be compared with tolerances.
        self.assertAllEqual(out2, out1)
      else:
        self.assertAllClose(out2, out1, rtol=rtol, atol=atol)

  def _test_loop_fn(self,
                    loop_fn,
                    iters,
                    parallel_iterations=None,
                    fallback_to_while_loop=False,
                    rtol=1e-4,
                    atol=1e-5):
    """Checks that `pfor(loop_fn)` matches `for_loop(loop_fn)`.

    Args:
      loop_fn: The per-iteration function given to both `pfor` and
        `for_loop`.
      iters: Iteration count forwarded to both `pfor` and `for_loop`.
      parallel_iterations: Forwarded to both `pfor` and `for_loop`.
      fallback_to_while_loop: Forwarded to `pfor` only.
      rtol: Relative tolerance used when comparing the two results.
      atol: Absolute tolerance used when comparing the two results.
    """
    pfor_outputs = pfor_control_flow_ops.pfor(
        loop_fn,
        iters=iters,
        fallback_to_while_loop=fallback_to_while_loop,
        parallel_iterations=parallel_iterations)
    # for_loop is given the output dtypes explicitly; reuse the ones from the
    # pfor outputs so both runs produce the same structure.
    loop_fn_dtypes = nest.map_structure(lambda x: x.dtype, pfor_outputs)
    for_loop_outputs = pfor_control_flow_ops.for_loop(
        loop_fn,
        loop_fn_dtypes,
        iters=iters,
        parallel_iterations=parallel_iterations)
    self.run_and_assert_equal(
        pfor_outputs, for_loop_outputs, rtol=rtol, atol=atol)
64