# Source: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/lib/common_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for common values and methods of TensorFlow Debugger."""
16import json
17
18from tensorflow.python.debug.lib import common
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import test_util
21from tensorflow.python.platform import googletest
22
23
class CommonTest(test_util.TensorFlowTestCase):
  """Tests for `common.get_run_key`.

  Each test builds named constant tensors and asks `common.get_run_key` for the
  run key of a (feed_dict, fetches) pair. It decodes that key with `json.loads`
  and checks the tensor names it carries. Element 0 of the decoded key is
  compared against the feed tensor names and element 1 against the fetch
  tensor names. All comparisons are order-insensitive (`assertItemsEqual`).
  """

  # Reason string passed to `test_util.run_v1_only` on every test below.
  _V1_ONLY_REASON = "Relies on tensor name, which is unavailable in TF2"

  def _decodeRunKey(self, feed_dict, fetches):
    """Computes the run key for `feed_dict`/`fetches` and JSON-decodes it."""
    encoded = common.get_run_key(feed_dict, fetches)
    return json.loads(encoded)

  @test_util.run_v1_only(_V1_ONLY_REASON)
  def testOnFeedOneFetch(self):
    """A single feed and a single fetch each contribute exactly their name."""
    # NOTE(review): the method name reads like a typo for
    # "testOneFeedOneFetch"; kept unchanged to preserve the test's identity.
    tensor_a = constant_op.constant(10.0, name="a")
    tensor_b = constant_op.constant(20.0, name="b")
    decoded = self._decodeRunKey({"a": tensor_a}, [tensor_b])
    feed_names, fetch_names = decoded[0], decoded[1]
    self.assertItemsEqual(["a:0"], feed_names)
    self.assertItemsEqual(["b:0"], fetch_names)

  @test_util.run_v1_only(_V1_ONLY_REASON)
  def testGetRunKeyFlat(self):
    """The same tensor may be both fed and fetched in a flat fetch list."""
    tensor_a = constant_op.constant(10.0, name="a")
    tensor_b = constant_op.constant(20.0, name="b")
    decoded = self._decodeRunKey({"a": tensor_a}, [tensor_a, tensor_b])
    feed_names, fetch_names = decoded[0], decoded[1]
    self.assertItemsEqual(["a:0"], feed_names)
    self.assertItemsEqual(["a:0", "b:0"], fetch_names)

  @test_util.run_v1_only(_V1_ONLY_REASON)
  def testGetRunKeyNestedFetches(self):
    """Names from nested list/dict fetches all end up in element 1."""
    # Constants are created in the order a, b, c, d.
    # NOTE(review): "d" reuses 30.0 like "c"; only the names are checked here.
    tensors = {
        tensor_name: constant_op.constant(value, name=tensor_name)
        for tensor_name, value in (
            ("a", 10.0), ("b", 20.0), ("c", 30.0), ("d", 30.0))
    }
    nested_fetches = {
        "set1": [tensors["a"], tensors["b"]],
        "set2": {"c": tensors["c"], "d": tensors["d"]},
    }
    decoded = self._decodeRunKey({}, nested_fetches)
    feed_names, fetch_names = decoded[0], decoded[1]
    self.assertItemsEqual([], feed_names)
    self.assertItemsEqual(["a:0", "b:0", "c:0", "d:0"], fetch_names)
55
56
if __name__ == "__main__":
  # Run the test cases defined in this module through the googletest entry
  # point imported from tensorflow.python.platform.
  googletest.main()
59