# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for proto_helpers.py."""

import collections

import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import proto_helpers
from fcp.artifact_building import variable_helpers


class MakeMetricTest(tf.test.TestCase):
  """Tests for `proto_helpers.make_metric`."""

  def test_make_metric(self):
    with tf.Graph().as_default():
      # Build graph variables for a single-field struct type `<bar=int32>`.
      v = variable_helpers.create_vars_for_tff_type(
          tff.to_type(collections.OrderedDict([("bar", tf.int32)])), name="foo"
      )
      # The stat name is the supplied prefix joined with the field name.
      self.assertProtoEquals(
          "variable_name: 'Identity:0' stat_name: 'client/bar'",
          proto_helpers.make_metric(v[0], "client"),
      )


class MakeTensorSpecTest(tf.test.TestCase):
  """Tests for `proto_helpers.make_tensor_spec_from_tensor`.

  Covers fully defined, unknown-rank, and partially defined tensor shapes,
  with and without a `shape_hint`.
  """

  def test_fully_defined_shape(self):
    with tf.Graph().as_default():
      test_tensor = tf.constant([[1], [2]])  # Shape [2, 1]
      with self.subTest("no_hint"):
        tensor_spec = proto_helpers.make_tensor_spec_from_tensor(test_tensor)
        self.assertProtoEquals(
            (
                "name: 'Const:0' "
                "shape { "
                "  dim { size: 2 } "
                "  dim { size: 1 } "
                "} "
                "dtype: DT_INT32"
            ),
            tensor_spec.experimental_as_proto(),
        )
      with self.subTest("ignored_hint"):
        # The supplied shape hint is incompatible, but it is ignored because
        # the tensor's shape is fully defined.
        tensor_spec = proto_helpers.make_tensor_spec_from_tensor(
            test_tensor, shape_hint=tf.TensorShape([1, 4])
        )
        self.assertProtoEquals(
            (
                "name: 'Const:0' "
                "shape { "
                "  dim { size: 2 } "
                "  dim { size: 1 } "
                "} "
                "dtype: DT_INT32"
            ),
            tensor_spec.experimental_as_proto(),
        )

  def test_undefined_shape(self):
    with tf.Graph().as_default():
      # Create an undefined-shape tensor via a placeholder and an op that
      # doesn't alter shape.
      test_tensor = tf.clip_by_value(
          tf.compat.v1.placeholder(dtype=tf.int32), 0, 1
      )
      with self.subTest("no_hint"):
        tensor_spec = proto_helpers.make_tensor_spec_from_tensor(test_tensor)
        self.assertProtoEquals(
            (
                "name: 'clip_by_value:0' "
                "shape { "
                "  unknown_rank: true "
                "} "
                "dtype: DT_INT32"
            ),
            tensor_spec.experimental_as_proto(),
        )
      with self.subTest("hint"):
        # With no static shape information, the hint is used verbatim.
        tensor_spec = proto_helpers.make_tensor_spec_from_tensor(
            test_tensor, shape_hint=tf.TensorShape([1, 4])
        )
        self.assertProtoEquals(
            (
                "name: 'clip_by_value:0' "
                "shape { "
                "  dim { size: 1 } "
                "  dim { size: 4 } "
                "} "
                "dtype: DT_INT32"
            ),
            tensor_spec.experimental_as_proto(),
        )

  def test_partially_defined_shape(self):
    with tf.Graph().as_default():
      # Create a partially defined shape tensor via a placeholder and a
      # reshape that fixes the first dimension to 2 and leaves the second
      # unknown.
      test_tensor = tf.reshape(
          tf.compat.v1.placeholder(dtype=tf.int32), [2, -1]
      )
      with self.subTest("no_hint"):
        tensor_spec = proto_helpers.make_tensor_spec_from_tensor(test_tensor)
        self.assertProtoEquals(
            (
                "name: 'Reshape:0' "
                "shape { "
                "  dim { size: 2 } "
                "  dim { size: -1 } "
                "} "
                "dtype: DT_INT32"
            ),
            tensor_spec.experimental_as_proto(),
        )
      with self.subTest("hint"):
        # A hint compatible with the known dimension fills in the unknown one.
        tensor_spec = proto_helpers.make_tensor_spec_from_tensor(
            test_tensor, shape_hint=tf.TensorShape([2, 4])
        )
        self.assertProtoEquals(
            (
                "name: 'Reshape:0' "
                "shape { "
                "  dim { size: 2 } "
                "  dim { size: 4} "
                "} "
                "dtype: DT_INT32"
            ),
            tensor_spec.experimental_as_proto(),
        )
      with self.subTest("invalid_hint"):
        # A hint that conflicts with the known first dimension (1 vs. 2) is
        # rejected.
        with self.assertRaises(TypeError):
          _ = proto_helpers.make_tensor_spec_from_tensor(
              test_tensor, shape_hint=tf.TensorShape([1, 4])
          )


class MakeMeasurementTest(tf.test.TestCase):
  """Tests for `proto_helpers.make_measurement`."""

  def test_succeeds(self):
    with tf.Graph().as_default():
      tensor = tf.constant(1)
      tff_type = tff.types.TensorType(tensor.dtype, tensor.shape)
      m = proto_helpers.make_measurement(
          t=tensor, name="test", tff_type=tff_type
      )

    self.assertEqual(m.name, "test")
    # `tff_type` holds the serialized TFF type as raw bytes rather than a
    # proto message, so compare bytes directly instead of using
    # assertProtoEquals (which is meant for proto messages / text protos).
    self.assertEqual(
        m.tff_type, tff.types.serialize_type(tff_type).SerializeToString()
    )

  def test_fails_for_non_matching_dtype(self):
    with tf.Graph().as_default():
      # float32 tensor paired with an int32 TFF type.
      tensor = tf.constant(1.0)
      tff_type = tff.types.TensorType(tf.int32, tensor.shape)

      with self.assertRaisesRegex(ValueError, ".* does not match.*"):
        proto_helpers.make_measurement(t=tensor, name="test", tff_type=tff_type)

  def test_fails_for_non_matching_shape(self):
    with tf.Graph().as_default():
      # Scalar tensor paired with a TFF type of shape [5].
      tensor = tf.constant(1.0)
      tff_type = tff.types.TensorType(tensor.dtype, shape=[5])

      with self.assertRaisesRegex(ValueError, ".* does not match.*"):
        proto_helpers.make_measurement(t=tensor, name="test", tff_type=tff_type)


if __name__ == "__main__":
  tf.test.main()