xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/lib/dumping_callback_test_lib.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Shared library for testing tfdbg v2 dumping callback."""
16
17import os
18import shutil
19import tempfile
20import uuid
21
22from tensorflow.python.debug.lib import check_numerics_callback
23from tensorflow.python.debug.lib import debug_events_reader
24from tensorflow.python.debug.lib import dumping_callback
25from tensorflow.python.framework import test_util
26from tensorflow.python.framework import versions
27
28
class DumpingCallbackTestBase(test_util.TensorFlowTestCase):
  """Base test-case class for tfdbg v2 callbacks.

  Each test gets its own temporary dump directory (`self.dump_root`) and a
  freshly generated run ID string (`self.tfdbg_run_id`). On teardown the
  check-numerics and dumping callbacks are disabled and the dump directory is
  removed.
  """

  def setUp(self):
    """Create a per-test dump directory and a random tfdbg run ID."""
    super(DumpingCallbackTestBase, self).setUp()
    self.dump_root = tempfile.mkdtemp()
    self.tfdbg_run_id = str(uuid.uuid4())

  def tearDown(self):
    """Disable the debug callbacks, then delete the dump directory.

    The callbacks are switched off *before* the directory is removed so that
    nothing is still expected to write into `self.dump_root` while it is being
    deleted.
    NOTE(review): presumably `disable_dump_debug_info()` flushes/closes the
    writer attached to the dump root -- confirm against dumping_callback.

    Directory removal and the base-class teardown run in a `finally` clause so
    that they still happen if disabling a callback raises.
    """
    try:
      check_numerics_callback.disable_check_numerics()
      dumping_callback.disable_dump_debug_info()
    finally:
      # Best-effort cleanup: `ignore_errors=True` already tolerates a missing
      # or undeletable path, so no existence pre-check is needed.
      shutil.rmtree(self.dump_root, ignore_errors=True)
      super(DumpingCallbackTestBase, self).tearDown()

  def _readAndCheckMetadataFile(self):
    """Read and check the .metadata debug-events file.

    Opens a `DebugEventsReader` on `self.dump_root` and asserts that:
      * the reader reports a truthy tfdbg run ID,
      * the recorded TensorFlow version equals `versions.__version__`,
      * the tfdbg file version string starts with "debug.Event".
    """
    with debug_events_reader.DebugEventsReader(self.dump_root) as reader:
      self.assertTrue(reader.tfdbg_run_id())
      self.assertEqual(reader.tensorflow_version(), versions.__version__)
      self.assertTrue(reader.tfdbg_file_version().startswith("debug.Event"))
50