xref: /aosp_15_r20/external/federated-compute/fcp/artifact_building/data_spec_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Tests for data_spec.py."""
15
16import collections
17
18from absl.testing import absltest
19
20import tensorflow as tf
21import tensorflow_federated as tff
22
23from fcp.artifact_building import data_spec
24from fcp.protos import plan_pb2
25
# Shared example selector for every test below; only the collection URI is
# populated, pointing at a fake collection.
_TEST_EXAMPLE_SELECTOR = plan_pb2.ExampleSelector(
    collection_uri='app://fake_uri')
29
30
class DataSpecTest(absltest.TestCase):
  """Tests for `data_spec.DataSpec` and `data_spec.is_data_spec_or_structure`."""

  def test_construction_with_valid_arguments(self):
    """The constructor keeps both arguments by identity."""

    def batch_fn(dataset):
      return dataset.batch(10)

    spec = data_spec.DataSpec(_TEST_EXAMPLE_SELECTOR, batch_fn)
    self.assertIs(spec.example_selector_proto, _TEST_EXAMPLE_SELECTOR)
    self.assertIs(spec.preprocessing_fn, batch_fn)

  def test_is_data_spec_or_structure(self):
    """A bare DataSpec, a list or a dict of them is accepted; ints are not."""

    def batch_fn(dataset):
      return dataset.batch(10)

    spec = data_spec.DataSpec(_TEST_EXAMPLE_SELECTOR, batch_fn)

    # Checked in the same order as the individual assertions they replace.
    accepted = (spec, [spec, spec], {'a': spec})
    rejected = (10, {'a': 10})
    for value in accepted:
      self.assertTrue(data_spec.is_data_spec_or_structure(value))
    for value in rejected:
      self.assertFalse(data_spec.is_data_spec_or_structure(value))

  def test_type_signature(self):
    """`type_signature` is a sequence of the elements `preprocessing_fn` yields."""

    def parsing_fn(serialized_example):
      features = tf.io.parse_example(
          serialized_example,
          {'key': tf.io.FixedLenFeature(shape=[1], dtype=tf.int64)},
      )
      return collections.OrderedDict(key=features['key'])

    def preprocessing_fn(dataset):
      return dataset.map(parsing_fn)

    spec = data_spec.DataSpec(_TEST_EXAMPLE_SELECTOR, preprocessing_fn)

    element_type = tff.types.to_type(
        collections.OrderedDict(
            key=tf.TensorSpec(shape=(1,), dtype=tf.int64)
        )
    )
    self.assertEqual(spec.type_signature, tff.SequenceType(element_type))
67
68
if __name__ == '__main__':
  # Opt into TF2 behaviors explicitly before handing control to absltest,
  # so the tests run the same way regardless of how TF was configured.
  tf.compat.v1.enable_v2_behavior()
  absltest.main()
72