# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end test running a simple Federated Program."""

import asyncio
import os
import tempfile
import unittest

from absl import flags
from absl.testing import absltest
import tensorflow as tf
import tensorflow_federated as tff

from fcp import demo
from fcp.client import client_runner_example_data_pb2
from fcp.protos import plan_pb2

POPULATION_NAME = 'test/population'
COLLECTION_URI = 'app:/example'


@tff.federated_computation()
def initialize() -> tff.Value:
  """Returns the initial state: the integer 0 placed at the server."""
  return tff.federated_value(0, tff.SERVER)


@tff.federated_computation(
    tff.type_at_server(tf.int32),
    tff.type_at_clients(tff.SequenceType(tf.string)))
def sum_counts(state, client_data):
  """Sums the value of all 'count' features across all clients.

  Args:
    state: The running total, an int32 placed at the server.
    client_data: Per-client sequences of strings; each string is parsed as a
      serialized example carrying an int64 'count' feature.

  Returns:
    A `(new_state, metrics)` pair, where `new_state` is `state` plus the sum
    of every client's counts and `metrics` is a one-element struct holding
    the number of participating clients.
  """

  @tf.function
  def reduce_counts(s: tf.int32, example: tf.string) -> tf.int32:
    """Adds the 'count' feature of one serialized example to `s`."""
    features = {'count': tf.io.FixedLenFeature((), tf.int64)}
    count = tf.io.parse_example(example, features=features)['count']
    # The feature is stored as int64 but the accumulator is int32.
    return s + tf.cast(count, tf.int32)

  @tff.tf_computation
  def client_work(client_data):
    """Reduces a single client's examples to its local count total."""
    return client_data.reduce(0, reduce_counts)

  client_counts = tff.federated_map(client_work, client_data)
  aggregated_count = tff.federated_sum(client_counts)

  # Summing a constant 1 placed at every client counts the clients.
  num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
  metrics = tff.federated_zip((num_clients,))
  return state + aggregated_count, metrics


async def program_logic(init: tff.Computation, comp: tff.Computation,
                        data_source: tff.program.FederatedDataSource,
                        total_rounds: int, number_of_clients: int,
                        release_manager: tff.program.ReleaseManager) -> None:
  """Initializes and runs a computation, releasing metrics and final state.

  Args:
    init: Computation producing the initial state.
    comp: Computation invoked once per round as `comp(state, cohort_config)`;
      it returns a `(state, metrics)` pair.
    data_source: Source whose iterator selects the data for each round.
    total_rounds: Number of rounds to run.
    number_of_clients: Number of clients to select for each round.
    release_manager: Receives each round's metrics under the key
      'metrics/<round index>' and the final state under the key 'result'.
  """
  tff.program.check_in_federated_context()
  data_iterator = data_source.iterator()
  state = init()
  for i in range(total_rounds):
    cohort_config = data_iterator.select(number_of_clients)
    state, metrics = comp(state, cohort_config)
    # Element 1 of the computation's result type describes the metrics.
    await release_manager.release(
        metrics, comp.type_signature.result[1], key=f'metrics/{i}')
  # Element 0 of the computation's result type describes the state.
  await release_manager.release(
      state, comp.type_signature.result[0], key='result')


async def run_client(population_name: str, server_url: str, num_rounds: int,
                     collection_uri: str,
                     examples: list[tf.train.Example]) -> int:
  """Runs a client and returns its return code.

  Args:
    population_name: Value passed to the client's `--population` flag.
    server_url: Value passed to the client's `--server` flag.
    num_rounds: Value passed to the client's `--num_rounds` flag.
    collection_uri: Collection URI under which `examples` are registered in
      the example data handed to the client.
    examples: Examples to serialize into the client's example data file.

  Returns:
    The exit code of the `client_runner_main` subprocess.
  """
  client_runner = os.path.join(
      flags.FLAGS.test_srcdir,
      'com_google_fcp',
      'fcp',
      'client',
      'client_runner_main')

  example_data = client_runner_example_data_pb2.ClientRunnerExampleData(
      examples_by_collection_uri={
          collection_uri:
              client_runner_example_data_pb2.ClientRunnerExampleData
              .ExampleList(examples=[e.SerializeToString() for e in examples])
      })

  # Unfortunately, since there's no convenient way to tell when the server has
  # actually started serving the computation, we cannot delay starting the
  # client until the server's ready to assign it a task. This isn't an issue in
  # a production setting, where there's a steady stream of clients connecting,
  # but it is a problem in this unit test, where each client only connects to
  # the server a fixed number of times. To work around this, we give the server
  # a little extra time to become ready; this delay doesn't significantly slow
  # down the test since there are many other time-consuming steps.
  await asyncio.sleep(1)

  with tempfile.NamedTemporaryFile() as tmpfile:
    tmpfile.write(example_data.SerializeToString())
    # Flush so the subprocess sees the full contents when it opens the path.
    tmpfile.flush()
    # Named `process` rather than `subprocess` to avoid shadowing the name of
    # the standard-library module.
    process = await asyncio.create_subprocess_exec(
        client_runner, f'--population={population_name}',
        f'--server={server_url}', f'--example_data_path={tmpfile.name}',
        f'--num_rounds={num_rounds}', '--sleep_after_round_secs=1',
        '--use_http_federated_compute_protocol', '--use_tflite_training')
    # Wait inside the `with` block so the temporary file outlives the client.
    return await process.wait()


def create_examples(counts: list[int]) -> list[tf.train.Example]:
  """Creates a list of tf.train.Example with the provided 'count' features.

  Args:
    counts: One integer per example to create.

  Returns:
    A list with one example per entry of `counts`, in the same order, each
    holding that entry as the single int64 value of its 'count' feature.
  """
  examples = []
  for count in counts:
    example = tf.train.Example()
    example.features.feature['count'].int64_list.value.append(count)
    examples.append(example)
  return examples


class FederatedProgramTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
  """Runs `program_logic` against a demo server and real client binaries."""

  async def test_multiple_rounds(self):
    """Checks the released state and metrics after two rounds of `sum_counts`."""
    data_source = demo.FederatedDataSource(
        POPULATION_NAME,
        plan_pb2.ExampleSelector(collection_uri=COLLECTION_URI))
    comp = demo.FederatedComputation(sum_counts, name='sum_counts')
    release_manager = tff.program.MemoryReleaseManager()
    num_rounds = 2
    # One inner list of 'count' values per client.
    client_counts = [
        [0, 3, 5, 1],
        [2, 4],
    ]

    base_context = tff.backends.native.create_sync_local_cpp_execution_context()

    with demo.FederatedContext(
        POPULATION_NAME,
        base_context=base_context) as ctx:
      clients = [
          run_client(POPULATION_NAME, f'http://localhost:{ctx.server_port}',
                     num_rounds, COLLECTION_URI, create_examples(counts))
          for counts in client_counts
      ]
      with tff.framework.get_context_stack().install(ctx):
        program = program_logic(initialize, comp, data_source, num_rounds,
                                len(client_counts), release_manager)
        results = await asyncio.gather(program, *clients)
        # The first result belongs to `program_logic` (which returns None);
        # the remaining entries are the clients' return codes.
        return_codes = results[1:]
        # All clients should complete successfully.
        self.assertListEqual(return_codes, [0] * len(client_counts))

    # The expected final state is the total of all counts, once per round.
    total_count = sum(sum(counts) for counts in client_counts)
    self.assertSequenceEqual(release_manager.values()['result'],
                             (num_rounds * total_count,
                              tff.type_at_server(tf.int32)))
    for i in range(num_rounds):
      self.assertSequenceEqual(
          release_manager.values()[f'metrics/{i}'],
          ((len(client_counts),),
           tff.type_at_server(tff.StructWithPythonType([tf.int32], tuple))))


if __name__ == '__main__':
  absltest.main()