# Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for proto_helpers.py.""" import collections import tensorflow as tf import tensorflow_federated as tff from fcp.artifact_building import proto_helpers from fcp.artifact_building import variable_helpers class MakeMetricTest(tf.test.TestCase): def test_make_metric(self): with tf.Graph().as_default(): v = variable_helpers.create_vars_for_tff_type( tff.to_type(collections.OrderedDict([("bar", tf.int32)])), name="foo" ) self.assertProtoEquals( "variable_name: 'Identity:0' stat_name: 'client/bar'", proto_helpers.make_metric(v[0], "client"), ) class MakeTensorSpecTest(tf.test.TestCase): def test_fully_defined_shape(self): with tf.Graph().as_default(): test_tensor = tf.constant([[1], [2]]) # Shape [1, 2] with self.subTest("no_hint"): tensor_spec = proto_helpers.make_tensor_spec_from_tensor(test_tensor) self.assertProtoEquals( ( "name: 'Const:0' " "shape { " " dim { size: 2 } " " dim { size: 1 } " "} " "dtype: DT_INT32" ), tensor_spec.experimental_as_proto(), ) with self.subTest("ignored_hint"): # Supplied shape hint is incompatible, but ignored because tensor is # fully defined. tensor_spec = proto_helpers.make_tensor_spec_from_tensor( test_tensor, shape_hint=tf.TensorShape([1, 4]) ) self.assertProtoEquals( ( "name: 'Const:0' " "shape { " " dim { size: 2 } " " dim { size: 1 } " "} " "dtype: DT_INT32" ), tensor_spec.experimental_as_proto(), ) def test_undefined_shape(self): with tf.Graph().as_default(): # Create a undefined shape tensor via a placeholder and an op that doesn't # alter shape. test_tensor = tf.clip_by_value( tf.compat.v1.placeholder(dtype=tf.int32), 0, 1 ) with self.subTest("no_hint"): tensor_spec = proto_helpers.make_tensor_spec_from_tensor(test_tensor) self.assertProtoEquals( ( "name: 'clip_by_value:0' " "shape { " " unknown_rank: true " "} " "dtype: DT_INT32" ), tensor_spec.experimental_as_proto(), ) with self.subTest("hint"): tensor_spec = proto_helpers.make_tensor_spec_from_tensor( test_tensor, shape_hint=tf.TensorShape([1, 4]) ) self.assertProtoEquals( ( "name: 'clip_by_value:0' " "shape { " " dim { size: 1 } " " dim { size: 4 } " "} " "dtype: DT_INT32" ), tensor_spec.experimental_as_proto(), ) def test_partially_defined_shape(self): with tf.Graph().as_default(): # Create a partially defined shape tensor via a placeholder and a reshape # to specify some dimensions. test_tensor = tf.reshape( tf.compat.v1.placeholder(dtype=tf.int32), [2, -1] ) with self.subTest("no_hint"): tensor_spec = proto_helpers.make_tensor_spec_from_tensor(test_tensor) self.assertProtoEquals( ( "name: 'Reshape:0' " "shape { " " dim { size: 2 } " " dim { size: -1 } " "} " "dtype: DT_INT32" ), tensor_spec.experimental_as_proto(), ) with self.subTest("hint"): tensor_spec = proto_helpers.make_tensor_spec_from_tensor( test_tensor, shape_hint=tf.TensorShape([2, 4]) ) self.assertProtoEquals( ( "name: 'Reshape:0' " "shape { " " dim { size: 2 } " " dim { size: 4} " "} " "dtype: DT_INT32" ), tensor_spec.experimental_as_proto(), ) with self.subTest("invalid_hint"): with self.assertRaises(TypeError): _ = proto_helpers.make_tensor_spec_from_tensor( test_tensor, shape_hint=tf.TensorShape([1, 4]) ) class MakeMeasurementTest(tf.test.TestCase): def test_succeeds(self): with tf.Graph().as_default(): tensor = tf.constant(1) tff_type = tff.types.TensorType(tensor.dtype, tensor.shape) m = proto_helpers.make_measurement( t=tensor, name="test", tff_type=tff_type ) self.assertEqual(m.name, "test") self.assertProtoEquals( m.tff_type, tff.types.serialize_type(tff_type).SerializeToString() ) def test_fails_for_non_matching_dtype(self): with tf.Graph().as_default(): tensor = tf.constant(1.0) tff_type = tff.types.TensorType(tf.int32, tensor.shape) with self.assertRaisesRegex(ValueError, ".* does not match.*"): proto_helpers.make_measurement(t=tensor, name="test", tff_type=tff_type) def test_fails_for_non_matching_shape(self): with tf.Graph().as_default(): tensor = tf.constant(1.0) tff_type = tff.types.TensorType(tensor.dtype, shape=[5]) with self.assertRaisesRegex(ValueError, ".* does not match.*"): proto_helpers.make_measurement(t=tensor, name="test", tff_type=tff_type) if __name__ == "__main__": tf.test.main()