"""Tests for tensor_utils."""

from absl.testing import absltest
from absl.testing import parameterized

import tensorflow as tf

from google.protobuf import any_pb2
from fcp.artifact_building import tensor_utils


class TensorUtilsTest(parameterized.TestCase, tf.test.TestCase):

  def test_bare_name(self):
    self.assertEqual(tensor_utils.bare_name('foo'), 'foo')
    self.assertEqual(tensor_utils.bare_name('foo:0'), 'foo')
    self.assertEqual(tensor_utils.bare_name('foo:1'), 'foo')
    self.assertEqual(tensor_utils.bare_name('^foo:1'), 'foo')
    self.assertEqual(tensor_utils.bare_name('^foo:output:2'), 'foo')
    with tf.Graph().as_default() as g:
      v = tf.Variable(0.0, name='foo')
      self.assertEqual(tensor_utils.bare_name(v), 'foo')

      @tf.function
      def foo(x):
        return tf.add(x, v.read_value(), 'add_op')

      # Calling the tf.function inside `g` traces it into g's function library.
      foo(tf.constant(1.0))

    # Inspect the input tensor names (the outputs of other nodes) in the traced
    # function to ensure we can recover the original user-specified bare names.
    graph_def = g.as_graph_def()
    # Test that the graph def contains the `<node>:<output>:<index>`-style
    # tensor names that `bare_name` is expected to reduce to node names.
    graph_def_str = str(graph_def)
    self.assertIn('add_op:z:0', graph_def_str)
    self.assertIn('Read/ReadVariableOp:value:0', graph_def_str)
    # Ensure that we can locate every expected node, by its bare name, among
    # the inputs of the traced function's nodes.
    required_names = ['add_op', 'Read/ReadVariableOp']
    for node in graph_def.library.function[0].node_def:
      for input_name in node.input:
        bare = tensor_utils.bare_name(input_name)
        if bare in required_names:
          required_names.remove(bare)
    self.assertEmpty(required_names)

  def test_bare_name_with_scope(self):
    self.assertEqual(tensor_utils.bare_name('bar/foo:1'), 'bar/foo')

    with tf.Graph().as_default():
      with tf.compat.v1.variable_scope('bar'):
        v = tf.Variable(0.0, name='foo')
      self.assertEqual(tensor_utils.bare_name(v), 'bar/foo')

  def test_name_or_str_with_named_variable(self):
    with tf.Graph().as_default():
      v = tf.Variable(0.0, name='foo')
      self.assertEqual('foo:0', tensor_utils.name_or_str(v))

  def test_name_or_str_with_unnamed_variable(self):
    with tf.Graph().as_default():
      v = tf.Variable(0.0)
      self.assertEqual('Variable:0', tensor_utils.name_or_str(v))

  def test_import_graph_def_from_any(self):
    with tf.Graph().as_default() as g:
      tf.constant(0.0)
    graph_def = g.as_graph_def()
    graph_def_any = any_pb2.Any()
    graph_def_any.Pack(graph_def)
    # Graph object doesn't have equality, so we check that the graph defs match.
    self.assertEqual(
        tensor_utils.import_graph_def_from_any(graph_def_any), graph_def
    )

  def test_save_and_restore_in_eager_mode(self):
    filename = tf.constant(self.create_tempfile().full_path)
    tensor_name = 'a'
    tensor = tf.constant(1.0)
    tensor_utils.save(filename, [tensor_name], [tensor])
    restored_tensor = tensor_utils.restore(filename, tensor_name, tensor.dtype)
    self.assertAllEqual(tensor, restored_tensor)

  @parameterized.named_parameters(
      ('scalar_tensor', tf.constant(1.0)),
      ('non_scalar_tensor', tf.constant([1.0, 2.0])),
  )
  def test_save_and_restore_with_shape_info_in_eager_mode(self, tensor):
    filename = tf.constant(self.create_tempfile().full_path)
    tensor_name = 'a'
    tensor_utils.save(filename, [tensor_name], [tensor])
    restored_tensor = tensor_utils.restore(
        filename, tensor_name, tensor.dtype, tensor.shape
    )
    self.assertAllEqual(tensor, restored_tensor)

  def _assert_op_in_graph(self, expected_op, graph):
    """Asserts that some node in `graph` has op type `expected_op`.

    Args:
      expected_op: The op type string to look for (e.g. 'RestoreV2').
      graph: The `tf.Graph` whose top-level nodes are searched.
    """
    graph_def = graph.as_graph_def()
    node_ops = [node.op for node in graph_def.node]
    self.assertIn(expected_op, node_ops)

  def _get_shape_and_slices_value(self, graph):
    """Returns the first string value of the 'restore/shape_and_slices' node.

    Fails the test if no node with that name exists in `graph`. This assumes
    `tensor_utils.restore` emits a constant under that name whose 'value'
    attr holds the shape-and-slices spec passed to the restore op.

    Args:
      graph: The `tf.Graph` built by `tensor_utils.restore`.

    Returns:
      The first entry of the node's `value` tensor `string_val`, as bytes.
    """
    graph_def = graph.as_graph_def()
    node_name_to_value_dict = {node.name: node for node in graph_def.node}
    self.assertIn('restore/shape_and_slices', node_name_to_value_dict)
    return (
        node_name_to_value_dict['restore/shape_and_slices']
        .attr['value']
        .tensor.string_val[0]
    )

  def test_save_and_restore_in_graph_mode(self):
    temp_file = self.create_tempfile().full_path
    graph = tf.Graph()
    with graph.as_default():
      filename = tf.constant(temp_file)
      tensor_name = 'a'
      tensor = tf.constant(1.0)
      save_op = tensor_utils.save(filename, [tensor_name], [tensor])
      restored = tensor_utils.restore(filename, tensor_name, tensor.dtype)
    with tf.compat.v1.Session(graph=graph) as sess:
      sess.run(save_op)
      expected_tensor, restored_tensor = sess.run([tensor, restored])
      self.assertAllEqual(expected_tensor, restored_tensor)
      self._assert_op_in_graph(expected_op='SaveSlices', graph=graph)
      self._assert_op_in_graph(expected_op='RestoreV2', graph=graph)
      # Without shape info, the restore is expected to use an empty spec.
      self.assertEqual(b'', self._get_shape_and_slices_value(graph))

  # Tensors are built lazily (via lambdas) so they are created inside the
  # test's own graph rather than eagerly at decoration time.
  @parameterized.named_parameters(
      ('scalar_tensor', lambda: tf.constant(1.0), b''),
      ('non_scalar_tensor', lambda: tf.constant([1.0, 2.0]), b'2 :-'),
  )
  def test_save_and_restore_with_shape_info_in_graph_mode(
      self, tensor_builder, expected_shape_and_slices_value
  ):
    temp_file = self.create_tempfile().full_path
    graph = tf.Graph()
    with graph.as_default():
      filename = tf.constant(temp_file)
      tensor_name = 'a'
      tensor = tensor_builder()
      save_op = tensor_utils.save(filename, [tensor_name], [tensor])
      restored = tensor_utils.restore(
          filename, tensor_name, tensor.dtype, tensor.shape
      )
    with tf.compat.v1.Session(graph=graph) as sess:
      sess.run(save_op)
      expected_tensor, restored_tensor = sess.run([tensor, restored])
      self.assertAllEqual(expected_tensor, restored_tensor)
      self._assert_op_in_graph(expected_op='SaveSlices', graph=graph)
      self._assert_op_in_graph(expected_op='RestoreV2', graph=graph)
      self.assertEqual(
          expected_shape_and_slices_value,
          self._get_shape_and_slices_value(graph),
      )


if __name__ == '__main__':
  absltest.main()