xref: /aosp_15_r20/external/federated-compute/fcp/demo/federated_program_test.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""End-to-end test running a simple Federated Program."""
15
16import asyncio
17import os
18import tempfile
19import unittest
20
21from absl import flags
22from absl.testing import absltest
23import tensorflow as tf
24import tensorflow_federated as tff
25
26from fcp import demo
27from fcp.client import client_runner_example_data_pb2
28from fcp.protos import plan_pb2
29
# Population name shared by the server-side context and every test client.
POPULATION_NAME: str = 'test/population'
# Collection URI that the data source selects and the clients serve data under.
COLLECTION_URI: str = 'app:/example'
32
33
@tff.federated_computation()
def initialize() -> tff.Value:
  """Creates the starting state for the federated program.

  Returns:
    The integer 0, placed at `tff.SERVER`.
  """
  initial_state = tff.federated_value(0, tff.SERVER)
  return initial_state
38
39
@tff.federated_computation(
    tff.type_at_server(tf.int32),
    tff.type_at_clients(tff.SequenceType(tf.string)))
def sum_counts(state, client_data):
  """Sums the value of all 'count' features across all clients.

  Args:
    state: The running total, an int32 placed at `tff.SERVER`.
    client_data: For each client, a sequence of strings; each string is parsed
      as a serialized `tf.train.Example` carrying an int64 'count' feature.

  Returns:
    A `(new_state, metrics)` pair, both placed at `tff.SERVER`. `new_state` is
    `state` plus the sum of every client's 'count' values. `metrics` is a
    one-element tuple holding the number of participating clients.
  """

  @tf.function
  def reduce_counts(s: tf.int32, example: tf.string) -> tf.int32:
    """Adds the 'count' feature of one serialized example to accumulator `s`."""
    features = {'count': tf.io.FixedLenFeature((), tf.int64)}
    count = tf.io.parse_example(example, features=features)['count']
    # The feature is parsed as int64; cast it to match the int32 accumulator.
    return s + tf.cast(count, tf.int32)

  @tff.tf_computation
  def client_work(client_data):
    """Computes one client's total 'count' by folding over its examples."""
    return client_data.reduce(0, reduce_counts)

  client_counts = tff.federated_map(client_work, client_data)
  aggregated_count = tff.federated_sum(client_counts)

  # Every client contributes 1, so the server-side sum is the cohort size.
  num_clients = tff.federated_sum(tff.federated_value(1, tff.CLIENTS))
  # Wrap in a one-element tuple so the metrics form a tuple-typed structure.
  metrics = tff.federated_zip((num_clients,))
  return state + aggregated_count, metrics
62
63
async def program_logic(init: tff.Computation, comp: tff.Computation,
                        data_source: tff.program.FederatedDataSource,
                        total_rounds: int, number_of_clients: int,
                        release_manager: tff.program.ReleaseManager) -> None:
  """Initializes and runs a computation, releasing metrics and final state.

  Must be called from within a federated context; this is checked before any
  work is done. The state produced by `init()` is threaded through
  `total_rounds` invocations of `comp`. Each invocation runs on a cohort of
  `number_of_clients` clients selected from `data_source`.

  Args:
    init: Computation producing the initial state.
    comp: Computation mapping `(state, cohort)` to `(new_state, metrics)`.
    data_source: Source from which each round's cohort is selected.
    total_rounds: Number of rounds to run.
    number_of_clients: Number of clients selected per round.
    release_manager: Receives each round's metrics under key
      'metrics/<round>' and the final state under key 'result'.
  """
  tff.program.check_in_federated_context()
  cohort_source = data_source.iterator()
  current_state = init()
  for round_index in range(total_rounds):
    cohort = cohort_source.select(number_of_clients)
    current_state, round_metrics = comp(current_state, cohort)
    metrics_type = comp.type_signature.result[1]
    await release_manager.release(
        round_metrics, metrics_type, key=f'metrics/{round_index}')
  final_state_type = comp.type_signature.result[0]
  await release_manager.release(current_state, final_state_type, key='result')
79
80
async def run_client(population_name: str, server_url: str, num_rounds: int,
                     collection_uri: str,
                     examples: list[tf.train.Example]) -> int:
  """Runs a client and returns its return code.

  Launches the `client_runner_main` binary from the test's source directory.
  The `examples` are serialized into a `ClientRunnerExampleData` proto under
  `collection_uri`, written to a temporary file, and handed to the client via
  `--example_data_path`. If this coroutine is interrupted while waiting (e.g.
  cancelled), the client process is killed rather than left running.

  Args:
    population_name: Population the client checks in to.
    server_url: URL of the server the client connects to.
    num_rounds: Value passed to the client's `--num_rounds` flag.
    collection_uri: Collection URI under which `examples` are served.
    examples: Examples the client serves for `collection_uri`.

  Returns:
    The client process's return code.
  """
  client_runner = os.path.join(
      flags.FLAGS.test_srcdir,
      'com_google_fcp',
      'fcp',
      'client',
      'client_runner_main')

  example_data = client_runner_example_data_pb2.ClientRunnerExampleData(
      examples_by_collection_uri={
          collection_uri:
              client_runner_example_data_pb2.ClientRunnerExampleData
              .ExampleList(examples=[e.SerializeToString() for e in examples])
      })

  # Unfortunately, since there's no convenient way to tell when the server has
  # actually started serving the computation, we cannot delay starting the
  # client until the server's ready to assign it a task. This isn't an issue in
  # a production setting, where there's a steady stream of clients connecting,
  # but it is a problem in this unit test, where each client only connects to
  # the server a fixed number of times. To work around this, we give the server
  # a little extra time to become ready; this delay doesn't significantly slow
  # down the test since there are many other time-consuming steps.
  await asyncio.sleep(1)

  with tempfile.NamedTemporaryFile() as tmpfile:
    tmpfile.write(example_data.SerializeToString())
    # Flush so the client process sees the full contents via the file name.
    tmpfile.flush()
    process = await asyncio.create_subprocess_exec(
        client_runner, f'--population={population_name}',
        f'--server={server_url}', f'--example_data_path={tmpfile.name}',
        f'--num_rounds={num_rounds}', '--sleep_after_round_secs=1',
        '--use_http_federated_compute_protocol', '--use_tflite_training')
    try:
      return await process.wait()
    finally:
      # If the wait was interrupted (e.g. the task was cancelled because
      # another coroutine in the test failed), don't orphan the client.
      if process.returncode is None:
        try:
          process.kill()
        except ProcessLookupError:
          pass  # The process already exited on its own.
118
119
def create_examples(counts: list[int]) -> list[tf.train.Example]:
  """Builds one `tf.train.Example` per entry in `counts`.

  Args:
    counts: Values to store, one per example, in an int64 'count' feature.

  Returns:
    Examples in the same order as `counts`, each with a single-value 'count'
    feature.
  """
  return [
      tf.train.Example(
          features=tf.train.Features(
              feature={
                  'count':
                      tf.train.Feature(
                          int64_list=tf.train.Int64List(value=[count]))
              }))
      for count in counts
  ]
128
129
class FederatedProgramTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase):
  """Runs `program_logic` end to end against real client processes."""

  async def test_multiple_rounds(self):
    """Runs two rounds with two clients and verifies every released value."""
    data_source = demo.FederatedDataSource(
        POPULATION_NAME,
        plan_pb2.ExampleSelector(collection_uri=COLLECTION_URI))
    comp = demo.FederatedComputation(sum_counts, name='sum_counts')
    release_manager = tff.program.MemoryReleaseManager()
    num_rounds = 2
    # One inner list per client: the 'count' values that client serves.
    client_counts = [
        [0, 3, 5, 1],
        [2, 4],
    ]
    num_clients = len(client_counts)
    base_context = tff.backends.native.create_sync_local_cpp_execution_context()

    with demo.FederatedContext(POPULATION_NAME,
                               base_context=base_context) as ctx:
      server_url = f'http://localhost:{ctx.server_port}'
      client_runs = [
          run_client(POPULATION_NAME, server_url, num_rounds, COLLECTION_URI,
                     create_examples(counts)) for counts in client_counts
      ]
      with tff.framework.get_context_stack().install(ctx):
        program = program_logic(initialize, comp, data_source, num_rounds,
                                num_clients, release_manager)
        # The first result belongs to the program; the rest are exit codes.
        _, *return_codes = await asyncio.gather(program, *client_runs)
        # Every client process is expected to exit successfully.
        self.assertListEqual(return_codes, [0] * num_clients)

    released = release_manager.values()
    expected_total = num_rounds * sum(sum(counts) for counts in client_counts)
    self.assertSequenceEqual(released['result'],
                             (expected_total, tff.type_at_server(tf.int32)))
    expected_metrics = (
        (num_clients,),
        tff.type_at_server(tff.StructWithPythonType([tf.int32], tuple)),
    )
    for round_index in range(num_rounds):
      self.assertSequenceEqual(released[f'metrics/{round_index}'],
                               expected_metrics)
169
170
if __name__ == '__main__':
  # Run via absltest rather than unittest.main; `run_client` reads
  # flags.FLAGS.test_srcdir, so absl flags presumably must be parsed first.
  absltest.main()
173