xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/tensor_utils_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1"""Tests for tensor_utils."""
2
3from absl.testing import absltest
4from absl.testing import parameterized
5
6import tensorflow as tf
7
8from google.protobuf import any_pb2
9from fcp.artifact_building import tensor_utils
10
11
class TensorUtilsTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the helpers in `fcp.artifact_building.tensor_utils`."""

  def test_bare_name(self):
    """Checks that `bare_name` strips `:`-suffixes and a leading `^`."""
    self.assertEqual(tensor_utils.bare_name('foo'), 'foo')
    self.assertEqual(tensor_utils.bare_name('foo:0'), 'foo')
    self.assertEqual(tensor_utils.bare_name('foo:1'), 'foo')
    self.assertEqual(tensor_utils.bare_name('^foo:1'), 'foo')
    self.assertEqual(tensor_utils.bare_name('^foo:output:2'), 'foo')
    with tf.Graph().as_default() as g:
      v = tf.Variable(0.0, name='foo')
      self.assertEqual(tensor_utils.bare_name(v), 'foo')

      @tf.function
      def foo(x):
        return tf.add(x, v.read_value(), 'add_op')

      # Tracing the tf.function adds it to the graph's function library.
      foo(tf.constant(1.0))

    # Inside the function library, node inputs refer to other nodes' outputs
    # with names of the form `node:output_arg:index`. Check that `bare_name`
    # recovers the original user-specified node names from such inputs.
    graph_def = g.as_graph_def()
    # Sanity-check that the serialized graph really uses that naming form, so
    # the lookup below is exercising the three-part names.
    graph_def_str = str(graph_def)
    self.assertIn('add_op:z:0', graph_def_str)
    self.assertIn('Read/ReadVariableOp:value:0', graph_def_str)
    # Every required node name must appear, once stripped, among the inputs of
    # the nodes in the traced function.
    required_names = {'add_op', 'Read/ReadVariableOp'}
    input_bare_names = {
        tensor_utils.bare_name(input_name)
        for node in graph_def.library.function[0].node_def
        for input_name in node.input
    }
    self.assertEmpty(required_names - input_bare_names)

  def test_bare_name_with_scope(self):
    """Checks that `bare_name` keeps the scope prefix of a name."""
    self.assertEqual(tensor_utils.bare_name('bar/foo:1'), 'bar/foo')

    with tf.Graph().as_default():
      with tf.compat.v1.variable_scope('bar'):
        v = tf.Variable(0.0, name='foo')
      self.assertEqual(tensor_utils.bare_name(v), 'bar/foo')

  def test_name_or_str_with_named_variable(self):
    """Checks `name_or_str` on a variable created with an explicit name."""
    with tf.Graph().as_default():
      v = tf.Variable(0.0, name='foo')
      self.assertEqual('foo:0', tensor_utils.name_or_str(v))

  def test_name_or_str_with_unnamed_variable(self):
    """Checks `name_or_str` on a variable that got TF's default name."""
    with tf.Graph().as_default():
      v = tf.Variable(0.0)
      self.assertEqual('Variable:0', tensor_utils.name_or_str(v))

  def test_import_graph_def_from_any(self):
    """Checks that a `GraphDef` packed into an `Any` is unpacked intact."""
    with tf.Graph().as_default() as g:
      tf.constant(0.0)
      graph_def = g.as_graph_def()
    graph_def_any = any_pb2.Any()
    graph_def_any.Pack(graph_def)
    # Graph object doesn't have equality, so we check that the graph defs match.
    self.assertEqual(
        tensor_utils.import_graph_def_from_any(graph_def_any), g.as_graph_def()
    )

  def test_save_and_restore_in_eager_mode(self):
    """Round-trips a scalar through `save`/`restore` without shape info."""
    filename = tf.constant(self.create_tempfile().full_path)
    tensor_name = 'a'
    tensor = tf.constant(1.0)
    tensor_utils.save(filename, [tensor_name], [tensor])
    restored_tensor = tensor_utils.restore(filename, tensor_name, tensor.dtype)
    self.assertAllEqual(tensor, restored_tensor)

  # Tensors are built lazily inside the test body (as in the graph-mode test
  # below) rather than when the decorator runs at module import time.
  @parameterized.named_parameters(
      ('scalar_tensor', lambda: tf.constant(1.0)),
      ('non_scalar_tensor', lambda: tf.constant([1.0, 2.0])),
  )
  def test_save_and_restore_with_shape_info_in_eager_mode(
      self, tensor_builder
  ):
    """Round-trips a tensor through `save`/`restore`, passing its shape."""
    tensor = tensor_builder()
    filename = tf.constant(self.create_tempfile().full_path)
    tensor_name = 'a'
    tensor_utils.save(filename, [tensor_name], [tensor])
    restored_tensor = tensor_utils.restore(
        filename, tensor_name, tensor.dtype, tensor.shape
    )
    self.assertAllEqual(tensor, restored_tensor)

  def _assert_op_in_graph(self, expected_op, graph):
    """Asserts that some top-level node in `graph` has op type `expected_op`."""
    graph_def = graph.as_graph_def()
    node_ops = [node.op for node in graph_def.node]
    self.assertIn(expected_op, node_ops)

  def _get_shape_and_slices_value(self, graph):
    """Returns the first string value of the `restore/shape_and_slices` node.

    The node name is the one this test expects `tensor_utils.restore` to
    create in `graph`; the test fails if no such node exists.
    """
    graph_def = graph.as_graph_def()
    nodes_by_name = {node.name: node for node in graph_def.node}
    self.assertIn('restore/shape_and_slices', nodes_by_name)
    return (
        nodes_by_name['restore/shape_and_slices']
        .attr['value']
        .tensor.string_val[0]
    )

  def _check_save_and_restore_in_graph_mode(
      self, tensor_builder, expected_shape_and_slices_value, pass_shape
  ):
    """Saves and restores a tensor inside a `tf.Graph` and checks the result.

    Args:
      tensor_builder: Zero-argument callable creating the tensor to save. It
        is invoked inside the test graph so the tensor belongs to that graph.
      expected_shape_and_slices_value: Expected bytes value of the
        `restore/shape_and_slices` constant in the resulting graph.
      pass_shape: If True, the tensor's static shape is passed to
        `tensor_utils.restore`; otherwise the shape argument is omitted.
    """
    temp_file = self.create_tempfile().full_path
    graph = tf.Graph()
    with graph.as_default():
      filename = tf.constant(temp_file)
      tensor_name = 'a'
      tensor = tensor_builder()
      save_op = tensor_utils.save(filename, [tensor_name], [tensor])
      restore_args = [filename, tensor_name, tensor.dtype]
      if pass_shape:
        restore_args.append(tensor.shape)
      restored = tensor_utils.restore(*restore_args)
    with tf.compat.v1.Session(graph=graph) as sess:
      sess.run(save_op)
      expected_tensor, restored_tensor = sess.run([tensor, restored])
      self.assertAllEqual(expected_tensor, restored_tensor)
      self._assert_op_in_graph(expected_op='SaveSlices', graph=graph)
      self._assert_op_in_graph(expected_op='RestoreV2', graph=graph)
      self.assertEqual(
          expected_shape_and_slices_value,
          self._get_shape_and_slices_value(graph),
      )

  def test_save_and_restore_in_graph_mode(self):
    """Without shape info, restore should use an empty shape_and_slices."""
    self._check_save_and_restore_in_graph_mode(
        lambda: tf.constant(1.0),
        expected_shape_and_slices_value=b'',
        pass_shape=False,
    )

  @parameterized.named_parameters(
      ('scalar_tensor', lambda: tf.constant(1.0), b''),
      ('non_scalar_tensor', lambda: tf.constant([1.0, 2.0]), b'2 :-'),
  )
  def test_save_and_restore_with_shape_info_in_graph_mode(
      self, tensor_builder, expected_shape_and_slices_value
  ):
    """With shape info, shape_and_slices should reflect the tensor's shape."""
    self._check_save_and_restore_in_graph_mode(
        tensor_builder,
        expected_shape_and_slices_value=expected_shape_and_slices_value,
        pass_shape=True,
    )
155
if __name__ == '__main__':
  # Hand control to absl's test runner, which discovers and runs the test
  # cases defined in this module.
  absltest.main()
158