# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for data_spec.py."""

import collections

from absl.testing import absltest

import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import data_spec
from fcp.protos import plan_pb2

_TEST_EXAMPLE_SELECTOR = plan_pb2.ExampleSelector(
    collection_uri='app://fake_uri'
)


def _batch_by_ten(dataset):
  """Preprocessing function shared by tests: `dataset.batch(10)`."""
  return dataset.batch(10)


class DataSpecTest(absltest.TestCase):
  """Unit tests for `data_spec.DataSpec` and its structure helper."""

  def test_construction_with_valid_arguments(self):
    spec = data_spec.DataSpec(_TEST_EXAMPLE_SELECTOR, _batch_by_ten)
    # The constructor is expected to keep the exact objects it was given.
    self.assertIs(spec.example_selector_proto, _TEST_EXAMPLE_SELECTOR)
    self.assertIs(spec.preprocessing_fn, _batch_by_ten)

  def test_is_data_spec_or_structure(self):
    spec = data_spec.DataSpec(_TEST_EXAMPLE_SELECTOR, _batch_by_ten)

    # A bare DataSpec, and containers whose leaves are DataSpecs, qualify.
    for accepted in (spec, [spec, spec], {'a': spec}):
      self.assertTrue(data_spec.is_data_spec_or_structure(accepted))

    # Non-DataSpec values, bare or nested, do not.
    for rejected in (10, {'a': 10}):
      self.assertFalse(data_spec.is_data_spec_or_structure(rejected))

  def test_type_signature(self):

    def parse_key(serialized):
      # Decode a single fixed-length int64 feature named 'key'.
      features = tf.io.parse_example(
          serialized,
          {'key': tf.io.FixedLenFeature(shape=[1], dtype=tf.int64)},
      )
      return collections.OrderedDict(key=features['key'])

    spec = data_spec.DataSpec(
        _TEST_EXAMPLE_SELECTOR, lambda dataset: dataset.map(parse_key)
    )

    element_type = tff.types.to_type(
        collections.OrderedDict(
            key=tf.TensorSpec(shape=(1,), dtype=tf.int64)
        )
    )
    self.assertEqual(spec.type_signature, tff.SequenceType(element_type))


if __name__ == '__main__':
  tf.compat.v1.enable_v2_behavior()
  absltest.main()