xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/proto_helpers_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for proto_helpers.py."""
15
16import collections
17
18import tensorflow as tf
19import tensorflow_federated as tff
20
21from fcp.artifact_building import proto_helpers
22from fcp.artifact_building import variable_helpers
23
24
class MakeMetricTest(tf.test.TestCase):
  """Tests for `proto_helpers.make_metric`."""

  def test_make_metric(self):
    with tf.Graph().as_default():
      # A single-field struct type `<bar=int32>` to create variables for.
      struct_type = tff.to_type(collections.OrderedDict([("bar", tf.int32)]))
      created_vars = variable_helpers.create_vars_for_tff_type(
          struct_type, name="foo"
      )
      metric = proto_helpers.make_metric(created_vars[0], "client")
      # The stat name combines the "client" prefix with the field name "bar".
      expected_metric = "variable_name: 'Identity:0' stat_name: 'client/bar'"
      self.assertProtoEquals(expected_metric, metric)
36
37
class MakeTensorSpecTest(tf.test.TestCase):
  """Tests for `proto_helpers.make_tensor_spec_from_tensor`.

  Covers three cases, each with and without a `shape_hint`:
    * A fully defined static shape. The hint is ignored, even if incompatible.
    * An unknown rank. The hint supplies the shape.
    * A partially defined shape. A compatible hint fills in the unknown
      dimensions; an incompatible hint raises `TypeError`.
  """

  def test_fully_defined_shape(self):
    with tf.Graph().as_default():
      test_tensor = tf.constant([[1], [2]])  # Shape [2, 1]
      with self.subTest("no_hint"):
        tensor_spec = proto_helpers.make_tensor_spec_from_tensor(test_tensor)
        self.assertProtoEquals(
            (
                "name: 'Const:0' "
                "shape { "
                "  dim { size: 2 } "
                "  dim { size: 1 } "
                "} "
                "dtype: DT_INT32"
            ),
            tensor_spec.experimental_as_proto(),
        )
      with self.subTest("ignored_hint"):
        # Supplied shape hint is incompatible, but ignored because tensor is
        # fully defined.
        tensor_spec = proto_helpers.make_tensor_spec_from_tensor(
            test_tensor, shape_hint=tf.TensorShape([1, 4])
        )
        self.assertProtoEquals(
            (
                "name: 'Const:0' "
                "shape { "
                "  dim { size: 2 } "
                "  dim { size: 1 } "
                "} "
                "dtype: DT_INT32"
            ),
            tensor_spec.experimental_as_proto(),
        )

  def test_undefined_shape(self):
    with tf.Graph().as_default():
      # Create a tensor of unknown rank via a placeholder (no shape given) and
      # an op that doesn't alter shape.
      test_tensor = tf.clip_by_value(
          tf.compat.v1.placeholder(dtype=tf.int32), 0, 1
      )
      with self.subTest("no_hint"):
        tensor_spec = proto_helpers.make_tensor_spec_from_tensor(test_tensor)
        self.assertProtoEquals(
            (
                "name: 'clip_by_value:0' "
                "shape { "
                " unknown_rank: true "
                "} "
                "dtype: DT_INT32"
            ),
            tensor_spec.experimental_as_proto(),
        )
      with self.subTest("hint"):
        # With no static shape information, the hint is used as the shape.
        tensor_spec = proto_helpers.make_tensor_spec_from_tensor(
            test_tensor, shape_hint=tf.TensorShape([1, 4])
        )
        self.assertProtoEquals(
            (
                "name: 'clip_by_value:0' "
                "shape { "
                "  dim { size: 1 } "
                "  dim { size: 4 } "
                "} "
                "dtype: DT_INT32"
            ),
            tensor_spec.experimental_as_proto(),
        )

  def test_partially_defined_shape(self):
    with tf.Graph().as_default():
      # Create a partially defined shape tensor via a placeholder and a reshape
      # to specify some dimensions (static shape [2, None]).
      test_tensor = tf.reshape(
          tf.compat.v1.placeholder(dtype=tf.int32), [2, -1]
      )
      with self.subTest("no_hint"):
        tensor_spec = proto_helpers.make_tensor_spec_from_tensor(test_tensor)
        self.assertProtoEquals(
            (
                "name: 'Reshape:0' "
                "shape { "
                "  dim { size: 2 } "
                "  dim { size: -1 } "
                "} "
                "dtype: DT_INT32"
            ),
            tensor_spec.experimental_as_proto(),
        )
      with self.subTest("hint"):
        # A compatible hint fills in the unknown second dimension.
        tensor_spec = proto_helpers.make_tensor_spec_from_tensor(
            test_tensor, shape_hint=tf.TensorShape([2, 4])
        )
        self.assertProtoEquals(
            (
                "name: 'Reshape:0' "
                "shape { "
                "  dim { size: 2 } "
                "  dim { size: 4 } "
                "} "
                "dtype: DT_INT32"
            ),
            tensor_spec.experimental_as_proto(),
        )
      with self.subTest("invalid_hint"):
        # The hint's first dimension (1) conflicts with the known size (2).
        with self.assertRaises(TypeError):
          _ = proto_helpers.make_tensor_spec_from_tensor(
              test_tensor, shape_hint=tf.TensorShape([1, 4])
          )
149
150
class MakeMeasurementTest(tf.test.TestCase):
  """Tests for `proto_helpers.make_measurement`."""

  def test_succeeds(self):
    with tf.Graph().as_default():
      tensor = tf.constant(1)
      tff_type = tff.types.TensorType(tensor.dtype, tensor.shape)
      m = proto_helpers.make_measurement(
          t=tensor, name="test", tff_type=tff_type
      )

      self.assertEqual(m.name, "test")
      # `m.tff_type` is checked against `SerializeToString()` output, i.e.
      # bytes rather than a proto message. `assertEqual` compares the bytes
      # directly and reports a readable diff on mismatch. `assertProtoEquals`
      # is intended for proto messages or text protos.
      expected_tff_type = tff.types.serialize_type(tff_type).SerializeToString()
      self.assertEqual(expected_tff_type, m.tff_type)

  def test_fails_for_non_matching_dtype(self):
    with tf.Graph().as_default():
      tensor = tf.constant(1.0)  # float32
      tff_type = tff.types.TensorType(tf.int32, tensor.shape)

      with self.assertRaisesRegex(ValueError, ".* does not match.*"):
        proto_helpers.make_measurement(t=tensor, name="test", tff_type=tff_type)

  def test_fails_for_non_matching_shape(self):
    with tf.Graph().as_default():
      tensor = tf.constant(1.0)  # scalar
      tff_type = tff.types.TensorType(tensor.dtype, shape=[5])

      with self.assertRaisesRegex(ValueError, ".* does not match.*"):
        proto_helpers.make_measurement(t=tensor, name="test", tff_type=tff_type)
181
182
if __name__ == "__main__":
  # Run the test cases defined in this module via TensorFlow's test runner.
  tf.test.main()
185