# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utility."""

import numpy as np

from tensorflow.python.ops import variables
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest


class PForTestCase(test.TestCase):
  """Base class for test cases comparing `pfor` results to `for_loop`."""

  def _run_targets(self, targets1, targets2=None, run_init=True):
    """Evaluates two structures of targets in a single `evaluate` call.

    Args:
      targets1: A nested structure of tensors to evaluate.
      targets2: Optional nested structure of tensors. When given and it
        flattens to a non-empty list, it must flatten to the same number of
        elements as `targets1`.
      run_init: If True, runs `variables.global_variables_initializer()`
        before evaluating the targets.

    Returns:
      The result of `self.evaluate` on the flattened `targets1` followed by
      the flattened `targets2`.
    """
    flat1 = nest.flatten(targets1)
    flat2 = [] if targets2 is None else nest.flatten(targets2)
    # A test assertion instead of a bare `assert`: it still raises
    # AssertionError on failure but is not stripped under `python -O`.
    self.assertTrue(
        not flat2 or len(flat1) == len(flat2),
        "targets1 and targets2 flatten to different lengths: %d vs %d" %
        (len(flat1), len(flat2)))
    if run_init:
      self.evaluate(variables.global_variables_initializer())
    return self.evaluate(flat1 + flat2)

  def run_and_assert_equal(self, targets1, targets2, rtol=1e-4, atol=1e-5):
    """Asserts that two structures of tensors evaluate to matching values.

    Args:
      targets1: A nested structure of tensors.
      targets2: A nested structure of tensors to compare element-wise with
        `targets1` after flattening.
      rtol: Relative tolerance used for non-object-dtype comparisons.
      atol: Absolute tolerance used for non-object-dtype comparisons.
    """
    outputs = self._run_targets(targets1, targets2)
    # Flatten again so composite results (e.g. SparseTensorValues) are
    # compared component by component.
    outputs = nest.flatten(outputs)
    half = len(outputs) // 2
    # Pair the i-th value from the first half with the i-th from the second.
    for value1, value2 in zip(outputs[:half], outputs[half:]):
      if value2.dtype != np.object_:
        self.assertAllClose(value2, value1, rtol=rtol, atol=atol)
      else:
        # Object-dtype values are not compared with numeric tolerances;
        # require exact equality instead.
        self.assertAllEqual(value2, value1)

  def _test_loop_fn(self,
                    loop_fn,
                    iters,
                    parallel_iterations=None,
                    fallback_to_while_loop=False,
                    rtol=1e-4,
                    atol=1e-5):
    """Checks that `pfor(loop_fn)` matches `for_loop(loop_fn)`.

    Args:
      loop_fn: The loop body passed to both `pfor` and `for_loop`.
      iters: Number of iterations.
      parallel_iterations: Forwarded to both `pfor` and `for_loop`.
      fallback_to_while_loop: Forwarded to `pfor`.
      rtol: Relative tolerance for the comparison.
      atol: Absolute tolerance for the comparison.
    """
    pfor_outputs = pfor_control_flow_ops.pfor(
        loop_fn,
        iters=iters,
        fallback_to_while_loop=fallback_to_while_loop,
        parallel_iterations=parallel_iterations)
    # The reference for_loop needs explicit output dtypes; take them from
    # the pfor result so both loops produce the same structure.
    loop_fn_dtypes = nest.map_structure(lambda x: x.dtype, pfor_outputs)
    for_loop_outputs = pfor_control_flow_ops.for_loop(
        loop_fn,
        loop_fn_dtypes,
        iters=iters,
        parallel_iterations=parallel_iterations)
    self.run_and_assert_equal(
        pfor_outputs, for_loop_outputs, rtol=rtol, atol=atol)