# Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either expresus or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for federated_data_source.""" from absl.testing import absltest import tensorflow as tf import tensorflow_federated as tff from fcp.demo import federated_data_source as fds from fcp.protos import plan_pb2 from fcp.protos.federatedcompute import eligibility_eval_tasks_pb2 _TaskAssignmentMode = ( eligibility_eval_tasks_pb2.PopulationEligibilitySpec.TaskInfo.TaskAssignmentMode ) POPULATION_NAME = 'test/name' EXAMPLE_SELECTOR = plan_pb2.ExampleSelector(collection_uri='app://test') class FederatedDataSourceTest(absltest.TestCase): def test_invalid_population_name(self): with self.assertRaisesRegex(ValueError, r'population_name must match ".+"'): fds.FederatedDataSource('^invalid^', EXAMPLE_SELECTOR) def test_population_name(self): ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR) self.assertEqual(ds.population_name, POPULATION_NAME) def test_example_selector(self): ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR) self.assertEqual(ds.example_selector, EXAMPLE_SELECTOR) def test_default_task_assignment_mode(self): ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR) self.assertEqual( ds.task_assignment_mode, _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_SINGLE ) def test_task_assignment_mode(self): ds = fds.FederatedDataSource( POPULATION_NAME, EXAMPLE_SELECTOR, task_assignment_mode=_TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE, ) self.assertEqual( ds.task_assignment_mode, _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE, ) def test_federated_type(self): ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR) self.assertEqual( ds.federated_type, tff.FederatedType(tff.SequenceType(tf.string), tff.CLIENTS)) def test_federated_type_nested(self): nested_example_selector = { 'a': EXAMPLE_SELECTOR, 'b': EXAMPLE_SELECTOR, 'c': { '1': EXAMPLE_SELECTOR, '2': EXAMPLE_SELECTOR }, } ds = fds.FederatedDataSource(POPULATION_NAME, nested_example_selector) self.assertEqual( ds.federated_type, tff.FederatedType( tff.StructType([ ('a', tff.SequenceType(tf.string)), ('b', tff.SequenceType(tf.string)), ('c', tff.StructType([ ('1', tff.SequenceType(tf.string)), ('2', tff.SequenceType(tf.string)), ])), ]), tff.CLIENTS)) def test_capabilities(self): ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR) self.assertListEqual(ds.capabilities, [tff.program.Capability.SUPPORTS_REUSE]) def test_iterator_federated_type(self): ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR) self.assertEqual(ds.iterator().federated_type, ds.federated_type) def test_iterator_select(self): ds = fds.FederatedDataSource( POPULATION_NAME, EXAMPLE_SELECTOR, _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE, ) self.assertEqual( ds.iterator().select(10), fds.DataSelectionConfig( POPULATION_NAME, EXAMPLE_SELECTOR, _TaskAssignmentMode.TASK_ASSIGNMENT_MODE_MULTIPLE, 10, ), ) def test_iterator_select_with_invalid_num_clients(self): ds = fds.FederatedDataSource(POPULATION_NAME, EXAMPLE_SELECTOR) with self.assertRaisesRegex(ValueError, 'num_clients must be positive'): ds.iterator().select(num_clients=None) with self.assertRaisesRegex(ValueError, 'num_clients must be positive'): ds.iterator().select(num_clients=-5) with self.assertRaisesRegex(ValueError, 'num_clients must be positive'): ds.iterator().select(num_clients=0) if __name__ == '__main__': absltest.main()