// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package google.internal.federated.plan;

import "google/protobuf/any.proto";
import "tensorflow/core/framework/tensor.proto";
import "tensorflow/core/framework/types.proto";
import "tensorflow/core/protobuf/saver.proto";
import "tensorflow/core/protobuf/struct.proto";

option java_package = "com.google.internal.federated.plan";
option java_multiple_files = true;
option java_outer_classname = "PlanProto";

// Primitives
// ===========

// Represents an operation to save or restore from a checkpoint.  Some
// instances of this message may be used only for restore or only for
// save; others may be used for both directions. This is documented together
// with their usage.
//
// This op has four essential uses:
//   1. read and apply a checkpoint.
//   2. write a checkpoint.
//   3. read and apply from an aggregated side channel.
//   4. write to a side channel (grouped with write a checkpoint).
// We should consider splitting this into four separate messages.
message CheckpointOp {
  // An optional standard saver def. If not provided, only the
  // op(s) below will be executed. This must be a version 1 SaverDef.
  tensorflow.SaverDef saver_def = 1;

  // An optional operation to run before the saver_def is executed for
  // restore.
  string before_restore_op = 2;

  // An optional operation to run after the saver_def has been
  // executed for restore. If side_channel_tensors are provided, then
  // they should be provided in a feed_dict to this op.
  string after_restore_op = 3;

  // An optional operation to run before the saver_def is executed
  // for save.
  string before_save_op = 4;

  // An optional operation to run after the saver_def has been
  // executed for save. If there are side_channel_tensors, this op
  // should be run after the side_channel_tensors have been fetched.
  string after_save_op = 5;

  // In addition to being saved and restored from a checkpoint, one can
  // also save and restore via a side channel. The keys in this map are
  // the names of the tensors transmitted by the side channel. These (key)
  // tensors should be read off just before saving via the saver_def and used
  // by the code that handles the side channel. Any variables provided this
  // way should NOT be saved by the saver_def.
  //
  // For restoring, the variables that are provided by the side channel
  // are restored differently than those from a checkpoint. For those from
  // the side channel, these should be restored by calling the before_restore_op
  // with a feed dict whose keys are the restore_names in the SideChannel and
  // whose values are the values to be restored.
  map<string, SideChannel> side_channel_tensors = 6;

  // An optional name of a tensor into which a unique token for the current
  // session should be written.
  //
  // This session identifier allows TensorFlow ops such as `ServeSlices` or
  // `ExternalDataset` to refer to callbacks and other session-global objects
  // registered before running the session.
  string session_token_tensor_name = 7;
}
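
// Example (illustrative sketch only; all tensor and op names are
// hypothetical): a CheckpointOp that restores variables via a V1 SaverDef and
// feeds one secure-aggregation side channel through `before_restore_op` could
// be written in textproto as:
//
//   saver_def {
//     filename_tensor_name: "save/filename:0"
//     save_tensor_name: "save/control_dependency:0"
//     restore_op_name: "save/restore_all"
//     version: V1
//   }
//   before_restore_op: "restore_side_channels"
//   side_channel_tensors {
//     key: "update/sum_vector:0"
//     value {
//       secure_aggregand {
//         dimension { size: 16 }
//         dtype: DT_INT64
//         fixed_modulus { modulus: 4294967296 }
//       }
//       restore_name: "restored_sum_vector:0"
//     }
//   }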

message SideChannel {
  // A side channel whose variables are processed via SecureAggregation.
  // This side channel implements aggregation via sum over a set of
  // clients, so the restored tensor will be a sum of multiple clients'
  // inputs into the side channel. Hence this will restore during the
  // read_aggregated_update restore, not the per-client read_update restore.
  message SecureAggregand {
    message Dimension {
      int64 size = 1;
    }

    // Dimensions of the aggregand. This is used by the secure aggregation
    // protocol in its early rounds, not as redundant info which could be
    // obtained by reading the dimensions of the tensor itself.
    repeated Dimension dimension = 3;

    // The data type anticipated by the server-side graph.
    tensorflow.DataType dtype = 4;

    // SecureAggregation will compute sum modulo this modulus.
    message FixedModulus {
      uint64 modulus = 1;
    }

    // SecureAggregation will for each shard compute sum modulo m with m at
    // least (1 + shard_size * (base_modulus - 1)), then aggregate
    // shard results with non-modular addition. Here, shard_size is the number
    // of clients in the shard.
    //
    // Note that the modulus for each shard will be greater than the largest
    // possible (non-modular) sum of the inputs to that shard.  That is,
    // assuming each client has input on range [0, base_modulus), the result
    // will be identical to non-modular addition (i.e. federated_sum).
    //
    // While any m >= (1 + shard_size * (base_modulus - 1)) would be valid, the
    // current implementation takes
    // m = 2**ceil(log_2(1 + shard_size * (base_modulus - 1))), which is the
    // smallest possible value of m that is also a power of 2.  This choice is
    // made because (a) it uses the same number of bits per vector entry as any
    // valid smaller m, using the current on-the-wire encoding scheme, and (b)
    // it enables the underlying mask-generation PRNG to run in its most
    // computationally efficient mode, which can be up to 2x faster.
    message ModulusTimesShardSize {
      uint64 base_modulus = 1;
    }

    oneof modulus_scheme {
      // Bitwidth of the aggregand.
      //
      // This is the bitwidth of an input value (i.e. the bitwidth that
      // quantization should target).  The Secure Aggregation bitwidth (i.e.,
      // the bitwidth of the *sum* of the input values) will be a function of
      // this bitwidth and the number of participating clients, as negotiated
      // with the server when the protocol is initiated.
      //
      // Deprecated; prefer fixed_modulus instead.
      int32 quantized_input_bitwidth = 2 [deprecated = true];

      FixedModulus fixed_modulus = 5;
      ModulusTimesShardSize modulus_times_shard_size = 6;
    }

    reserved 1;
  }

  // What type of side channel is used.
  oneof type {
    SecureAggregand secure_aggregand = 1;
  }

  // When restoring, the name of the tensor to restore to.  This is the name
  // (key) supplied in the feed_dict in the before_restore_op in order to
  // restore the tensor provided by the side channel (which will be the
  // value in the feed_dict).
  string restore_name = 2;
}
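
// Worked example (illustrative only) of the ModulusTimesShardSize scheme
// above: with base_modulus = 256 and a shard of shard_size = 1000 clients,
// the per-shard modulus must satisfy
//   m >= 1 + 1000 * (256 - 1) = 255001,
// and the current implementation would pick the smallest power of 2 at or
// above that bound, i.e. m = 2**18 = 262144. A side channel using this scheme
// could be configured (with a hypothetical restore tensor name) as:
//
//   secure_aggregand {
//     dimension { size: 16 }
//     dtype: DT_INT32
//     modulus_times_shard_size { base_modulus: 256 }
//   }
//   restore_name: "aggregated_update:0"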

// Container for a metric used by the internal toolkit.
message Metric {
  // Name of an Op to run to read the value.
  string variable_name = 1;

  // A human-readable name for the statistic. Metric names are usually
  // camel case by convention, e.g., 'Loss', 'AbsLoss', or 'Accuracy'.
  // Must be 7-bit ASCII and under 122 characters.
  string stat_name = 2;

  // The human-readable name of another metric by which this metric should be
  // normalized, if any. If empty, this Metric should be aggregated with simple
  // summation. If not empty, the Metric is aggregated according to
  //   weighted_metric_sum = sum_i (metric_i * weight_i)
  //   weight_sum = sum_i weight_i
  //   average_metric_value = weighted_metric_sum / weight_sum
  string weight_name = 3;
}
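
// Example (illustrative only; the op names are hypothetical): a per-round
// accuracy metric that should be normalized by the number of examples each
// client processed could be declared as two Metric entries:
//
//   metrics {
//     variable_name: "read_accuracy_sum"
//     stat_name: "Accuracy"
//     weight_name: "NumExamples"
//   }
//   metrics {
//     variable_name: "read_num_examples"
//     stat_name: "NumExamples"
//   }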

// Represents instructions for how metrics are to be output to users,
// controlling the end format of the metrics users receive.
message OutputMetric {
  // Metric name.
  string name = 1;

  oneof value_source {
    // A metric representing one stat with aggregation type sum.
    SumOptions sum = 2;

    // A metric representing a ratio between metrics with aggregation
    // type average.
    AverageOptions average = 3;

    // A metric that is not aggregated by the MetricReportAggregator or
    // metrics_loader. This includes metrics like 'num_server_updates' that are
    // aggregated in TensorFlow.
    NoneOptions none = 4;

    // A metric representing one stat with aggregation type only sample.
    // Samples at most 101 clients' values.
    OnlySampleOptions only_sample = 5;
  }
  // Iff True, the metric will be plotted in the default view of the
  // task level Colab automatically.
  oneof visualization_info {
    bool auto_plot = 6 [deprecated = true];
    VisualizationSpec plot_spec = 7;
  }
}

message VisualizationSpec {
  // Different allowable plot types.
  enum VisualizationType {
    NONE = 0;
    DEFAULT_PLOT_FOR_TASK_TYPE = 1;
    LINE_PLOT = 2;
    LINE_PLOT_WITH_PERCENTILES = 3;
    HISTOGRAM = 4;
  }

  // Defines the plot type to provide downstream.
  VisualizationType plot_type = 1;

  // The x-axis to provide for the given metric. Must be the name of a
  // metric or counter. Recommended x_axis options are source_round, round,
  // or time.
  string x_axis = 2;

  // Iff True, metric will be displayed on a population level dashboard.
  bool plot_on_population_dashboard = 3;
}

// A metric representing one stat with aggregation type sum.
message SumOptions {
  // Name for corresponding Metric stat_name field.
  string stat_name = 1;

  // Iff True, a cumulative sum over rounds will be provided in addition to a
  // sum per round for the value metric.
  bool include_cumulative_sum = 2;

  // Iff True, include a sample of at most 101 clients' values.
  // Used to calculate quantiles in downstream visualization pipeline.
  bool include_client_samples = 3;
}

// A metric representing a ratio between metrics with aggregation type average.
// Represents: numerator stat / denominator stat.
message AverageOptions {
  // Numerator stat name pointing to corresponding Metric stat_name.
  string numerator_stat_name = 1;

  // Denominator stat name pointing to corresponding Metric stat_name.
  string denominator_stat_name = 2;

  // Name for corresponding Metric stat_name that is the ratio of the
  // numerator stat / denominator stat.
  string average_stat_name = 3;

  // Iff True, include a sample of at most 101 clients' values.
  // Used to calculate quantiles in downstream visualization pipeline.
  bool include_client_samples = 4;
}
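
// Example (illustrative only): an OutputMetric exposing the weighted-average
// accuracy from the Metric example above, plotted as a line plot over rounds:
//
//   output_metrics {
//     name: "Accuracy"
//     average {
//       numerator_stat_name: "Accuracy"
//       denominator_stat_name: "NumExamples"
//       average_stat_name: "AverageAccuracy"
//     }
//     plot_spec { plot_type: LINE_PLOT x_axis: "round" }
//   }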

// A metric representing one stat with aggregation type none.
message NoneOptions {
  // Name for corresponding Metric stat_name field.
  string stat_name = 1;
}

// A metric representing one stat with aggregation type only sample.
message OnlySampleOptions {
  // Name for corresponding Metric stat_name field.
  string stat_name = 1;
}

// Represents a data set. This is used for testing.
message Dataset {
  // Represents the data set for one client.
  message ClientDataset {
    // A string identifying the client.
    string client_id = 1;

    // A list of serialized tf.Example protos.
    repeated bytes example = 2;

    // Represents a dataset whose examples are selected by an ExampleSelector.
    message SelectedExample {
      ExampleSelector selector = 1;
      repeated bytes example = 2;
    }

    // A list of (selector, dataset) pairs. Used in testing some *TFF-based
    // tasks* that require multiple datasets as client input, e.g., a TFF-based
    // personalization eval task requires each client to provide at least two
    // datasets: one for train, and the other for test.
    repeated SelectedExample selected_example = 3;
  }

  // A list of client data.
  repeated ClientDataset client_data = 1;
}
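
// Example (illustrative only): a test Dataset with two clients, each
// supplying a handful of serialized tf.Example protos:
//
//   client_data {
//     client_id: "client_a"
//     example: "<serialized tf.Example bytes>"
//     example: "<serialized tf.Example bytes>"
//   }
//   client_data {
//     client_id: "client_b"
//     example: "<serialized tf.Example bytes>"
//   }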

// Represents predicates over metrics - i.e., expectations. This is used in
// training/eval tests to encode metric names and values expected to be reported
// by a client execution.
message MetricTestPredicates {
  // The value must lie in [lower_bound; upper_bound]. Can also be used for
  // approximate matching (lower == value - epsilon; upper = value + epsilon).
  message Interval {
    double lower_bound = 1;
    double upper_bound = 2;
  }

  // The value must be a real value as long as the value of the weight_name
  // metric is non-zero. If the weight metric is zero, then it is acceptable for
  // the value to be non-real.
  message RealIfNonzeroWeight {
    string weight_name = 1;
  }

  message MetricCriterion {
    // Name of the metric.
    string name = 1;

    // FL training round this metric is expected to appear in.
    int32 training_round_index = 2;

    // If none of the following is set, no matching is performed; but the
    // metric is still expected to be present (with whatever value).
    oneof Criterion {
      // The reported metric must be < lt.
      float lt = 3;
      // The reported metric must be > gt.
      float gt = 4;
      // The reported metric must be <= le.
      float le = 5;
      // The reported metric must be >= ge.
      float ge = 6;
      // The reported metric must be == eq.
      float eq = 7;
      // The reported metric must lie in the interval.
      Interval interval = 8;
      // The reported metric is not NaN or +/- infinity.
      bool real = 9;
      // The reported metric is real (i.e., not NaN or +/- infinity) if the
      // value of an associated weight is not 0.
      RealIfNonzeroWeight real_if_nonzero_weight = 10;
    }
  }

  repeated MetricCriterion metric_criterion = 1;

  reserved 2;
}
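
// Example (illustrative only): expectations that round 0 reports a "Loss"
// metric within an interval and a finite "Accuracy" whenever examples were
// processed:
//
//   metric_criterion {
//     name: "Loss"
//     training_round_index: 0
//     interval { lower_bound: 0.0 upper_bound: 5.0 }
//   }
//   metric_criterion {
//     name: "Accuracy"
//     training_round_index: 0
//     real_if_nonzero_weight { weight_name: "NumExamples" }
//   }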

// Client Phase
// ============

// A `TensorflowSpec` that is executed on the client in a single `tf.Session`.
// In federated optimization, this will correspond to one `ServerPhase`.
message ClientPhase {
  // A short CamelCase name for the ClientPhase.
  string name = 2;

  // Minimum number of clients in aggregation.
  // In secure aggregation mode this is used to configure the protocol instance
  // so that the server can't learn aggregated values when the number of
  // participants is lower than this number.
  // Without secure aggregation, the server still respects this parameter,
  // ensuring that aggregated values never leave server RAM unless they include
  // data from at least the specified number of participants.
  int32 minimum_number_of_participants = 3;

  // If populated, `io_router` must be specified.
  oneof spec {
    // A functional interface for the TensorFlow logic the client should
    // perform.
    TensorflowSpec tensorflow_spec = 4 [lazy = true];
    // Spec for client plans that issue example queries and send the query
    // results directly to an aggregator with little or no additional
    // processing.
    ExampleQuerySpec example_query_spec = 9 [lazy = true];
  }

  // The specification of the inputs coming either from customer apps
  // (Local Compute) or the federated protocol (Federated Compute).
  oneof io_router {
    FederatedComputeIORouter federated_compute = 5 [lazy = true];
    LocalComputeIORouter local_compute = 6 [lazy = true];
    FederatedComputeEligibilityIORouter federated_compute_eligibility = 7
        [lazy = true];
    FederatedExampleQueryIORouter federated_example_query = 8 [lazy = true];
  }

  reserved 1;
}

// The TensorflowSpec message describes a single call into TensorFlow, including
// the expected input tensors that must be fed when making that call, the
// output tensors to be fetched, and any operations that have no output but must
// be run. The TensorFlow session will then use the input tensors to do some
// computation, generally reading from one or more datasets, and provide some
// outputs.
//
// Conceptually, client or server code uses this proto along with an IORouter
// to build maps of names to input tensors, vectors of output tensor names,
// and vectors of target nodes:
//
//    CreateTensorflowArguments(
//        TensorflowSpec& spec,
//        IORouter& io_router,
//        const vector<pair<string, Tensor>>* input_tensors,
//        const vector<string>* output_tensor_names,
//        const vector<string>* target_node_names);
//
// Here `input_tensors`, `output_tensor_names` and `target_node_names`
// correspond to the arguments of the TensorFlow C++ API for
// `tensorflow::Session::Run()`, and the client executes only a single
// invocation.
//
// Note: the execution engine never sees any concepts related to the federated
// protocol, e.g. input checkpoints or aggregation protocols. This is a "tensors
// in, tensors out" interface. New aggregation methods can be added without
// having to modify the execution engine / TensorflowSpec message; instead,
// they should modify the IORouter messages.
//
// Note: both `input_tensor_specs` and `output_tensor_specs` are full
// `tensorflow.TensorSpecProto` messages, though TensorFlow technically
// only requires the names to feed the values into the session. The additional
// dtype/shape information must always be included in case the runtime
// executing this TensorflowSpec wants to perform additional, optional static
// assertions. The runtimes however are free to ignore the dtypes/shapes and
// only rely on the names if so desired.
//
// Assertions:
//   - all names in `input_tensor_specs`, `output_tensor_specs`, and
//     `target_node_names` must appear in the serialized GraphDef where
//     the TF execution will be invoked.
//   - `output_tensor_specs` or `target_node_names` must be non-empty, otherwise
//     there is nothing to execute in the graph.
message TensorflowSpec {
  // The name of a tensor into which a unique token for the current session
  // should be written. The corresponding tensor is a scalar string tensor and
  // is separate from `input_tensors` as there is only one.
  //
  // A session token allows TensorFlow ops such as `ServeSlices` or
  // `ExternalDataset` to refer to callbacks and other session-global objects
  // registered before running the session. In the `ExternalDataset` case, a
  // single dataset_token is valid for multiple `tf.data.Dataset` objects as
  // the token can be thought of as a handle to a dataset factory.
  string dataset_token_tensor_name = 1;

  // TensorSpecs of inputs which will be passed to TF.
  //
  // Corresponds to the `feed_dict` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, excluding the dataset_token listed above.
  //
  // Assertions:
  //   - All the tensor names designated as inputs in the corresponding IORouter
  //     must be listed (otherwise the IORouter input work is unused).
  //   - All placeholders in the TF graph must be listed here, with the
  //     exception of the dataset_token which is explicitly set above (otherwise
  //     TensorFlow will fail to execute).
  repeated tensorflow.TensorSpecProto input_tensor_specs = 2;

  // TensorSpecs that should be fetched from TF after execution.
  //
  // Corresponds to the `fetches` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, and the `output_tensor_names` in TensorFlow's C++
  // API.
  //
  // Assertions:
  //   - The set of tensor names here must strictly match the tensor names
  //     designated as outputs in the corresponding IORouter (if any exist).
  repeated tensorflow.TensorSpecProto output_tensor_specs = 3;

  // Node names in the graph that should be executed, but whose output is not
  // returned.
  //
  // Corresponds to the `fetches` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, and the `target_node_names` in TensorFlow's C++
  // API.
  //
  // This is intended for use with operations that do not produce tensors, but
  // nonetheless are required to run (e.g. serializing checkpoints).
  repeated string target_node_names = 4;

  // Map of Tensor names to constant inputs.
  // Note: tensors specified via this message should not be included in
  // input_tensor_specs.
  map<string, tensorflow.TensorProto> constant_inputs = 5;

  // The fields below are added by the OnDevicePersonalization module.
  // Specifies an example selection procedure.
  ExampleSelector example_selector = 999;
}
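
// Example (illustrative sketch only; all tensor and node names are
// hypothetical): a client TensorflowSpec that feeds two scalar filepath
// placeholders and runs a training node whose outputs are written as a side
// effect:
//
//   dataset_token_tensor_name: "dataset_token:0"
//   input_tensor_specs {
//     name: "input_filepath:0"
//     dtype: DT_STRING
//     shape {}
//   }
//   input_tensor_specs {
//     name: "output_filepath:0"
//     dtype: DT_STRING
//     shape {}
//   }
//   target_node_names: "run_training_and_save"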

// The ExampleQuerySpec message describes client execution that issues example
// queries and sends the query results directly to an aggregator with little or
// no additional processing.
// This message describes one or more example store queries that perform the
// client side analytics computation in C++. The corresponding output vectors
// will be converted into the expected federated protocol output format.
// This must be used in conjunction with the `FederatedExampleQueryIORouter`.
message ExampleQuerySpec {
  message OutputVectorSpec {
    // The output vector name.
    string vector_name = 1;

    // Supported data types for the vector of information.
    enum DataType {
      UNSPECIFIED = 0;
      INT32 = 1;
      INT64 = 2;
      BOOL = 3;
      FLOAT = 4;
      DOUBLE = 5;
      BYTES = 6;
      STRING = 7;
    }

    // The data type for each entry in the vector.
    DataType data_type = 2;
  }

  message ExampleQuery {
    // The `ExampleSelector` to issue the query with.
    ExampleSelector example_selector = 1;

    // Indicates that the query returns vector data and must return a single
    // ExampleQueryResult result containing a VectorData entry matching each
    // OutputVectorSpec.vector_name.
    //
    // If the query instead returns no result, then it will be treated as if
    // an error was returned. In that case, or if the query explicitly returns
    // an error, then the client will abort its session.
    //
    // The keys in the map are the names the vectors should be aggregated under,
    // and must match the keys in FederatedExampleQueryIORouter.aggregations.
    map<string, OutputVectorSpec> output_vector_specs = 2;
  }

  // The queries to run.
  repeated ExampleQuery example_queries = 1;
}
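
// Example (illustrative only; the collection URI is hypothetical): a single
// query whose "event_count" output vector is aggregated under the key
// "event_count":
//
//   example_queries {
//     example_selector {
//       collection_uri: "app:/analytics_events"
//     }
//     output_vector_specs {
//       key: "event_count"
//       value { vector_name: "event_count" data_type: INT64 }
//     }
//   }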

// The input and output router for Federated Compute plans.
//
// This proto is the glue between the federated protocol and the TensorFlow
// execution engine. This message describes how to prepare data coming from the
// incoming `CheckinResponse` (defined in
// fcp/protos/federated_api.proto) for the `TensorflowSpec`, and what
// to do with outputs from `TensorflowSpec` (e.g. how to aggregate them back on
// the server).
//
// TODO(team) we could replace `input_checkpoint_file_tensor_name` with
// an `input_tensors` field, which would then be a tensor that contains the
// input TensorProtos directly, skipping disk I/O rather than referring to a
// checkpoint file path.
message FederatedComputeIORouter {
  // ===========================================================================
  //   Inputs
  // ===========================================================================
  // The name of the scalar string tensor that is fed the file path to the
  // initial checkpoint (e.g. as provided via AcceptanceInfo.init_checkpoint).
  //
  // The federated protocol code would copy the `CheckinResponse`'s initial
  // checkpoint to a temporary file and then pass that file path through this
  // tensor.
  //
  // Ops may be added to the client graph that take this tensor as input and
  // read the path.
  //
  // This field is optional. It may be omitted if the client graph does not use
  // an initial checkpoint.
  string input_filepath_tensor_name = 1;

  // The name of the scalar string tensor that is fed the file path to which
  // client work should serialize the bytes to send back to the server.
  //
  // The federated protocol code generates a temporary file and passes the file
  // path through this tensor.
  //
  // Ops may be added to the client graph that use this tensor as an argument
  // to write files (e.g. writing checkpoints to disk).
  //
  // This field is optional. It must be omitted if the client graph does not
  // generate any output files (e.g. when all output tensors of `TensorflowSpec`
  // use Secure Aggregation). If this field is not set, then the `ReportRequest`
  // message in the federated protocol will not have the
  // `Report.update_checkpoint` field set. This absence of a value here can be
  // used to validate that the plan only uses Secure Aggregation.
  //
  // Conversely, if this field is set but executing the associated
  // TensorflowSpec does not write to the path, that is an indication of an
  // internal framework error. The runtime should notify the caller that the
  // computation was set up incorrectly.
  string output_filepath_tensor_name = 2;

  // ===========================================================================
  //   Outputs
  // ===========================================================================
  // Describes which output tensors should be aggregated using an aggregation
  // protocol, and the configuration for those protocols.
  //
  // Assertions:
  //   - All keys must exist in the associated `TensorflowSpec` as
  //     `output_tensor_specs.name` values.
  map<string, AggregationConfig> aggregations = 3;
}
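
// Example (illustrative sketch only; tensor names are hypothetical): a router
// that feeds the input and output checkpoint paths and marks one output
// tensor for Secure Aggregation:
//
//   input_filepath_tensor_name: "input_filepath:0"
//   output_filepath_tensor_name: "output_filepath:0"
//   aggregations {
//     key: "secagg_update:0"
//     value { secure_aggregation {} }
//   }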

// The input and output router for client plans that do not use TensorFlow.
//
// This proto is the glue between the federated protocol and the example query
// execution engine, describing how the query results should ultimately be
// aggregated.
message FederatedExampleQueryIORouter {
  // Describes how each output vector should be aggregated using an aggregation
  // protocol, and the configuration for those protocols.
  // Keys must match the keys in ExampleQuerySpec.output_vector_specs.
  // Note that currently only the TFV1CheckpointAggregation config is supported.
  map<string, AggregationConfig> aggregations = 1;
}

// The specification for how to aggregate the associated tensor across clients
// on the server.
message AggregationConfig {
  oneof protocol_config {
    // Indicates that the given output tensor should be processed using Secure
    // Aggregation, using the specified config options.
    SecureAggregationConfig secure_aggregation = 2;

    // Note: in the future we could add a `SimpleAggregationConfig` to add
    // support for simple aggregation without writing to an intermediate
    // checkpoint file first.

    // Indicates that the given output tensor or vector (e.g. as produced by an
    // ExampleQuerySpec) should be placed in an output TF v1 checkpoint.
    //
    // Currently only ExampleQuerySpec output vectors are supported by this
    // aggregation type (i.e. it cannot be used with TensorflowSpec output
    // tensors). The vectors will be stored in the checkpoint as 1-D Tensors of
    // their corresponding data type.
    TFV1CheckpointAggregation tf_v1_checkpoint_aggregation = 3;
  }
}

// Parameters for the SecAgg protocol (go/secagg).
//
// Currently only the server uses the SecAgg parameters, so we only use this
// message to signify usage of SecAgg.
message SecureAggregationConfig {}

// Parameters for the TFV1 Checkpoint Aggregation protocol.
//
// Currently only ExampleQuerySpec output vectors are supported by this
// aggregation type (i.e. it cannot be used with TensorflowSpec output
// tensors). The vectors will be stored in the checkpoint as 1-D Tensors of
// their corresponding data type.
message TFV1CheckpointAggregation {}

// The input and output router for eligibility-computing plans. These plans
// compute which other plans a client is eligible to run, and are returned to
// clients via an `EligibilityEvalCheckinResponse` (defined in
// fcp/protos/federated_api.proto).
message FederatedComputeEligibilityIORouter {
  // The name of the scalar string tensor that is fed the file path to the
  // initial checkpoint (e.g. as provided via
  // `EligibilityEvalPayload.init_checkpoint`).
  //
  // For more detail see the
  // `FederatedComputeIORouter.input_filepath_tensor_name`, which has the same
  // semantics.
  //
  // This field is optional. It may be omitted if the client graph does not use
  // an initial checkpoint.
  //
  // This tensor name must exist in the associated
  // `TensorflowSpec.input_tensor_specs` list.
  string input_filepath_tensor_name = 1;

  // Name of the output tensor (a string scalar) containing the serialized
  // `google.internal.federatedml.v2.TaskEligibilityInfo` proto output. The
  // client code will parse this proto and place it in the
  // `task_eligibility_info` field of the subsequent `CheckinRequest`.
  //
  // This tensor name must exist in the associated
  // `TensorflowSpec.output_tensor_specs` list.
  string task_eligibility_info_tensor_name = 2;
}

// The input and output router for Local Compute plans.
//
// This proto is the glue between the customer's app and the TensorFlow
// execution engine. This message describes how to prepare data coming from the
// customer app (e.g. the input directory the app set up), and the temporary
// scratch output directory that the customer app will be notified of upon
// completion of `TensorflowSpec`.
message LocalComputeIORouter {
  // ===========================================================================
  //   Inputs
  // ===========================================================================
  // The name of the placeholder tensor representing the input resource path(s).
  // It can be a single input directory or file path (in this case the
  // `input_dir_tensor_name` is populated) or multiple input resources
  // represented as a map from names to input directories or file paths (in this
  // case the `multiple_input_resources` is populated).
  //
  // In the multiple input resources case, the placeholder tensors are
  // represented as a map: the keys are the input resource names defined by the
  // users when constructing the `LocalComputation` Python object, and the
  // values are the corresponding placeholder tensor names created by the local
  // computation plan builder.
  //
  // Apps will have the ability to create contracts between their Android code
  // and `LocalComputation` toolkit code to place files inside the input
  // resource paths with known names (Android code) and create graphs with ops
  // to read from these paths (file names can be specified in toolkit code).
  oneof input_resource {
    string input_dir_tensor_name = 1;
    // Directly using the `map` field is not allowed in `oneof`, so we have to
    // wrap it in a new message.
    MultipleInputResources multiple_input_resources = 3;
  }

  // Scalar string tensor name that will contain the output directory path.
  //
  // The provided directory should be considered temporary scratch that will be
  // deleted, not persisted. It is the responsibility of the calling app to
  // move the desired files to a permanent location once the client returns this
  // directory back to the calling app.
  string output_dir_tensor_name = 2;

  // ===========================================================================
  //   Outputs
  // ===========================================================================
  // NOTE: LocalCompute has no outputs other than what the client graph writes
  // to `output_dir` specified above.
}

// Describes the multiple input resources in `LocalComputeIORouter`.
message MultipleInputResources {
  // The keys are the input resource names (defined by the users when
  // constructing the `LocalComputation` Python object), and the values are the
  // corresponding placeholder tensor names created by the local computation
  // plan builder.
  map<string, string> input_resource_tensor_name_map = 1;
}

// Describes a queue to which input is fed.
message AsyncInputFeed {
  // The op for enqueuing an example input.
  string enqueue_op = 1;

  // The input placeholders for the enqueue op.
  repeated string enqueue_params = 2;

  // The op for closing the input queue.
  string close_op = 3;

  // Whether the work that should be fed asynchronously is the data itself
  // or a description of where that data lives.
  bool feed_values_are_data = 4;
}

message DatasetInput {
  // Initializer of the iterator corresponding to the tf.data.Dataset object
  // which handles the input data. Stores the name of an op in the graph.
  string initializer = 1;

  // Placeholders necessary to initialize the dataset.
  DatasetInputPlaceholders placeholders = 2;

  // Batch size to be used in tf.data.Dataset.
  int32 batch_size = 3;
}

message DatasetInputPlaceholders {
  // Name of placeholder corresponding to filename(s) of SSTable(s) to read data
  // from.
  string filename = 1;

  // Name of placeholder corresponding to key_prefix initializing the
  // SSTableDataset. Note the value fed should be a unique user id, not a
  // prefix.
  string key_prefix = 2;

  // Name of placeholder corresponding to number of rounds the local training
  // should be run for.
  string num_epochs = 3;

  // Name of placeholder corresponding to batch size.
  string batch_size = 4;
}

// Specifies an example selection procedure.
message ExampleSelector {
  // Selection criteria following a contract agreed upon between client and
  // model designers.
  google.protobuf.Any criteria = 1;

  // A URI identifying the example collection to read from. Format should adhere
  // to "${COLLECTION}://${APP_NAME}${COLLECTION_NAME}". The URI segments
  // should adhere to the following rules:
  // - The scheme ${COLLECTION} should be one of:
  //   - "app" for app-hosted example collections
  //   - "simulation" for collections not connected to an app (e.g., if used
  //     purely for simulation)
  // - The authority ${APP_NAME} identifies the owner of the example
  //   collection and should be either the app's package name, or be left empty
  //   (which means "the current app package name").
  // - The path ${COLLECTION_NAME} can be any valid URI path. NB it starts with
  //   a forward slash ("/").
  // - The query and fragment are currently not used, but they may become used
  //   for something in the future. To keep open that possibility they must
  //   currently be left empty.
  //
  // Example: "app://com.google.some.app/someCollection/name"
  // identifies the collection "/someCollection/name" owned and hosted by the
  // app with package name "com.google.some.app".
  //
  // Example: "app:/someCollection/name" or "app:///someCollection/name"
  // both identify the collection "/someCollection/name" owned and hosted by the
  // app associated with the training job in which this URI appears.
  //
  // The path will not be interpreted by the runtime, and will be passed to the
  // example collection implementation for interpretation. Thus, in the case of
  // app-hosted example stores, the path segment's interpretation is a contract
  // between the app's example store developers and the app's model designers.
  //
  // If an `app://` URI is set, then the `TrainerOptions` collection name must
  // not be set.
  string collection_uri = 2;

  // Resumption token following a contract agreed upon between client and
  // model designers.
  google.protobuf.Any resumption_token = 3;
}
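
// Example (illustrative only; the criteria message type is hypothetical):
// selecting training examples from the app-hosted collection named in the
// comment above, with selection criteria packed as an Any:
//
//   collection_uri: "app://com.google.some.app/someCollection/name"
//   criteria {
//     type_url: "type.googleapis.com/some.app.SelectionCriteria"
//     value: "<serialized criteria proto>"
//   }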

// Selector for slices to fetch as part of a `federated_select` operation.
message SlicesSelector {
  // The string ID under which the slices are served.
  //
  // This value must have been returned by a previous call to the `serve_slices`
  // op run during the `write_client_init` operation.
  string served_at_id = 1;

  // The indices of slices to fetch.
  repeated int32 keys = 2;
}

// Represents slice data to be served as part of a `federated_select` operation.
// This is used for testing.
message SlicesTestDataset {
  // The test data to use. The keys map to the `SlicesSelector.served_at_id`
  // field.  E.g. test slice data for a slice with `served_at_id`="foo" and
  // `keys`=2 would be stored in `dataset["foo"].slice_data[2]`.
  map<string, SlicesTestData> dataset = 1;
}

message SlicesTestData {
  // The test slice data to serve. Each entry's index corresponds to the slice
  // key it is the test data for.
  repeated bytes slice_data = 2;
}
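
// Example (illustrative only): test data matching the comment above, where the
// slice served under "foo" with key 2 resolves to the third slice_data entry:
//
//   dataset {
//     key: "foo"
//     value {
//       slice_data: "<slice 0 bytes>"
//       slice_data: "<slice 1 bytes>"
//       slice_data: "<slice 2 bytes>"
//     }
//   }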

// Server Phase V2
// ===============

// Represents a server phase with three distinct components: pre-broadcast,
// aggregation, and post-aggregation.
//
// The pre-broadcast and post-aggregation components are described with
// the tensorflow_spec_prepare and tensorflow_spec_result TensorflowSpec
// messages, respectively. These messages in combination with the server
// IORouter messages specify how to set up a single TF sess.run call for each
// component.
//
// The pre-broadcast logic is obtained by transforming the server_prepare TFF
// computation in the DistributeAggregateForm. It takes the server state as
// input, and it generates the checkpoint to broadcast to the clients and
// potentially an intermediate server state. The intermediate server state may
// be used by the aggregation and post-aggregation logic.
//
// The aggregation logic represents the aggregation of client results at the
// server and is described using a list of ServerAggregationConfig messages.
// Each ServerAggregationConfig message describes a single aggregation operation
// on a set of input/output tensors. The input tensors may represent parts of
// either the client results or the intermediate server state. These messages
// are obtained by transforming the client_to_server_aggregation TFF computation
// in the DistributeAggregateForm.
//
// The post-aggregation logic is obtained by transforming the server_result TFF
// computation in the DistributeAggregateForm. It takes the intermediate server
// state and the aggregated client results as input, and it generates the new
// server state and potentially other server-side output.
//
// Note that while a ServerPhaseV2 message can be generated for all types of
// intrinsics, it is currently only compatible with the ClientPhase message if
// the aggregations being used are exclusively federated_sum (not SecAgg). If
// this compatibility requirement is satisfied, it is also valid to run the
// aggregation portion of this ServerPhaseV2 message alongside the pre- and
// post-aggregation logic from the original ServerPhase message. Ultimately,
// we expect the full ServerPhaseV2 message to be run and the ServerPhase
// message to be deprecated.
message ServerPhaseV2 {
  // A short CamelCase name for the ServerPhaseV2.
  string name = 1;

  // A functional interface for the TensorFlow logic the server should perform
  // prior to the server-to-client broadcast. This should be used with the
  // TensorFlow graph defined in server_graph_prepare_bytes.
  TensorflowSpec tensorflow_spec_prepare = 3;

  // The specification of inputs needed by the server_prepare TF logic.
  oneof server_prepare_io_router {
    ServerPrepareIORouter prepare_router = 4;
  }

  // A list of client-to-server aggregations to perform.
  repeated ServerAggregationConfig aggregations = 2;

  // A functional interface for the TensorFlow logic the server should perform
  // post-aggregation. This should be used with the TensorFlow graph defined
  // in server_graph_result_bytes.
  TensorflowSpec tensorflow_spec_result = 5;

  // The specification of inputs and outputs needed by the server_result TF
  // logic.
  oneof server_result_io_router {
    ServerResultIORouter result_router = 6;
  }
}

// Routing for server_prepare graph
message ServerPrepareIORouter {
  // The name of the scalar string tensor in the server_prepare TF graph that
  // is fed the filepath to the initial server state checkpoint. The
  // server_prepare logic reads from this filepath.
  string prepare_server_state_input_filepath_tensor_name = 1;

  // The name of the scalar string tensor in the server_prepare TF graph that
  // is fed the filepath where the client checkpoint should be stored. The
  // server_prepare logic writes to this filepath.
  string prepare_output_filepath_tensor_name = 2;

  // The name of the scalar string tensor in the server_prepare TF graph that
  // is fed the filepath where the intermediate state checkpoint should be
  // stored. The server_prepare logic writes to this filepath. The intermediate
  // state checkpoint will be consumed by both the logic used to set parameters
  // for aggregation and the post-aggregation logic.
  string prepare_intermediate_state_output_filepath_tensor_name = 3;
}

// Routing for server_result graph
message ServerResultIORouter {
  // The name of the scalar string tensor in the server_result TF graph that is
  // fed the filepath to the intermediate state checkpoint. The server_result
  // logic reads from this filepath.
  string result_intermediate_state_input_filepath_tensor_name = 1;

  // The name of the scalar string tensor in the server_result TF graph that is
  // fed the filepath to the aggregated client result checkpoint. The
  // server_result logic reads from this filepath.
  string result_aggregate_result_input_filepath_tensor_name = 2;

  // The name of the scalar string tensor in the server_result TF graph that is
  // fed the filepath where the updated server state should be stored. The
  // server_result logic writes to this filepath.
  string result_server_state_output_filepath_tensor_name = 3;
}

// Represents a single aggregation operation, combining one or more input
// tensors from a collection of clients into one or more output tensors on the
// server.
message ServerAggregationConfig {
  // The uri of the aggregation intrinsic (e.g. 'federated_sum').
  string intrinsic_uri = 1;

  // Describes an argument to the aggregation operation.
  message IntrinsicArg {
    oneof arg {
      // Refers to a tensor within the checkpoint provided by each client.
      tensorflow.TensorSpecProto input_tensor = 2;

      // Refers to a tensor within the intermediate server state checkpoint.
      tensorflow.TensorSpecProto state_tensor = 3;
    }
  }

  // List of arguments for the aggregation operation. The arguments can be
  // dependent on client data (in which case they must be retrieved from
  // clients) or they can be independent of client data (in which case they
  // can be configured server-side). For now we assume all client-independent
  // arguments are constants. The arguments must be in the order expected by
  // the server.
  repeated IntrinsicArg intrinsic_args = 4;

  // List of server-side outputs produced by the aggregation operation.
  repeated tensorflow.TensorSpecProto output_tensors = 5;

  // List of inner aggregation intrinsics. This can be used to delegate parts
  // of the aggregation logic (e.g. a groupby intrinsic may want to delegate
  // a sum operation to a sum intrinsic).
  repeated ServerAggregationConfig inner_aggregations = 6;
}
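
// Example (illustrative only; tensor names are hypothetical): summing a
// per-client update tensor into a server-side aggregate via the federated_sum
// intrinsic:
//
//   intrinsic_uri: "federated_sum"
//   intrinsic_args {
//     input_tensor {
//       name: "update/weights"
//       dtype: DT_FLOAT
//       shape { dim { size: 10 } }
//     }
//   }
//   output_tensors {
//     name: "aggregate/weights"
//     dtype: DT_FLOAT
//     shape { dim { size: 10 } }
//   }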

// Server Phase
// ============

// Represents a server phase which implements TF-based aggregation of multiple
// client updates.
//
// There are two different modes of aggregation that are described
// by the values in this message. The first is aggregation that is
// coming from coordinated sets of clients. This includes aggregation
// done via checkpoints from clients or aggregation done over a set
// of clients by a process like secure aggregation. The results of
// this first aggregation are saved to intermediate aggregation
// checkpoints. The second aggregation then comes from taking
// these intermediate checkpoints and aggregating over them.
//
// These two different modes of aggregation are done on different
// servers, the first in the 'L1' servers and the second in the
// 'L2' servers, so we use this nomenclature to describe these
// phases below.
//
// The ServerPhase message is currently in the process of being replaced by the
// ServerPhaseV2 message as we switch the plan building pipeline to use
// DistributeAggregateForm instead of MapReduceForm. During the migration
// process, we may generate both messages and use components from either
// message during execution.
//
message ServerPhase {
  // A short CamelCase name for the ServerPhase.
  string name = 8;

  // ===========================================================================
  // L1 "Intermediate" Aggregation.
  //
  // This is the initial aggregation that creates partial aggregates from client
  // results. L1 Aggregation may be run on many different instances.
  //
  // Pre-condition:
  //   The execution environment has loaded the graph from `server_graph_bytes`.

  // 1. Initialize the phase.
  //
  // Operation to run before the first aggregation happens.
  // For instance, clears the accumulators so that a new aggregation can begin.
  string phase_init_op = 1;

  // 2. For each client in set of clients:
  //   a. Restore variables from the client checkpoint.
  //
  // Loads a checkpoint from a single client written via
  // `FederatedComputeIORouter.output_filepath_tensor_name`.  This is done once
  // for every client checkpoint in a round.
  CheckpointOp read_update = 3;
  //   b. Aggregate the data coming from the client checkpoint.
  //
  // An operation that aggregates the data from read_update.
  // Generally this will add to accumulators and it may leverage internal data
  // inside the graph to adjust the weights of the Tensors.
  //
  // Executed once for each `read_update`, to (for example) update accumulator
  // variables using the values loaded during `read_update`.
  string aggregate_into_accumulators_op = 4;

  // 3. After all clients have been aggregated, possibly restore
  //    variables that have been aggregated via a separate process.
  //
  // Optionally restores variables where aggregation is done across
  // an entire round of client data updates. In contrast to `read_update`,
  // which restores once per client, this occurs after all clients
  // in a round have been processed. This allows, for example, side
  // channels where aggregation is done by a separate process (such
  // as in secure aggregation), in which the side channel aggregated
  // tensor is passed to the `before_restore_op`, which ensures the
  // variables are restored properly. The `after_restore_op` will then
  // be responsible for performing the accumulation.
  //
  // Note that in current use this should not have a SaverDef, but
  // should only be used for side channels.
  CheckpointOp read_aggregated_update = 10;

  // 4. Write the aggregated variables to an intermediate checkpoint.
  //
  // We require that `aggregate_into_accumulators_op` is associative and
  // commutative, so that the aggregates can be computed across
  // multiple TensorFlow sessions.
  // As an example, say we are computing the sum of 5 client updates:
  //    A = X1 + X2 + X3 + X4 + X5
  // We can always do this in one session by calling `read_update` and
  // `aggregate_into_accumulators_op` once for each client checkpoint.
  //
  // Alternatively, we could compute:
  //    A1 = X1 + X2 in one TensorFlow session, and
  //    A2 = X3 + X4 + X5 in a different session.
  // Each of these sessions can then write their accumulator state
  // with the `write_intermediate_update` CheckpointOp, and yet a third
  // session can then call `read_intermediate_update` and
  // `aggregate_into_accumulators_op` on each of these checkpoints to compute:
  //    A = A1 + A2 = (X1 + X2) + (X3 + X4 + X5).
  CheckpointOp write_intermediate_update = 7;
  // End L1 "Intermediate" Aggregation.
  // ===========================================================================

  // ===========================================================================
  // L2 Aggregation and Coordinator.
  //
  // This aggregates intermediate checkpoints from L1 Aggregation and performs
  // the finalizing of the update. Unlike L1, there will only be one instance
  // that does this aggregation.

  // Pre-condition:
  //   The execution environment has loaded the graph from `server_graph_bytes`
  //   and restored the global model using `server_savepoint` from the parent
  //   `Plan` message.

  // 1. Initialize the phase.
  //
  // This currently re-uses the `phase_init_op` from L1 aggregation above.

  // 2. Write a checkpoint that can be sent to the client.
  //
  // Generates a checkpoint to be sent to the client, to be read by
  // `FederatedComputeIORouter.input_filepath_tensor_name`.

  CheckpointOp write_client_init = 2;

  // 3. For each intermediate checkpoint:
  //   a. Restore variables from the intermediate checkpoint.
  //
  // The corresponding read checkpoint op to the write_intermediate_update.
  // This is used instead of read_update for intermediate checkpoints because
  // the format of these updates may be different than those used in updates
  // from clients (which may, for example, be compressed).
  CheckpointOp read_intermediate_update = 9;
  //   b. Aggregate the data coming from the intermediate checkpoint.
  //
  // An operation that aggregates the data from `read_intermediate_update`.
  // Generally this will add to accumulators and it may leverage internal data
  // inside the graph to adjust the weights of the Tensors.
  string intermediate_aggregate_into_accumulators_op = 11;

  // 4. Write the aggregated intermediate variables to a checkpoint.
  //
  // This is used for downstream, cross-round aggregation of metrics.
  // These variables will be read back into a session with
  // read_intermediate_update.
  //
  // Tasks which do not use FL metrics may unset the CheckpointOp.saver_def
  // to disable writing accumulator checkpoints.
  CheckpointOp write_accumulators = 12;

  // 5. Finalize the round.
  //
  // This can include:
  // - Applying the update aggregated from the intermediate checkpoints to the
  //   global model and other updates to cross-round state variables.
  // - Computing final round metric values (e.g. the `report` of a
  //   `tff.federated_aggregate`).
  string apply_aggregrated_updates_op = 5;

  // 6. Fetch the server aggregated metrics.
  //
  // A list of names of metric variables to fetch from the TensorFlow session.
  repeated Metric metrics = 6;

  // 7. Serialize the updated server state (e.g. the coefficients of the global
  //    model in FL) using `server_savepoint` in the parent `Plan` message.

  // End L2 Aggregation.
  // ===========================================================================
}

// Represents the server phase in an eligibility computation.
//
// This phase produces a checkpoint to be sent to clients. This checkpoint is
// then used as an input to the clients' task eligibility computations.
// This phase *does not include any aggregation.*
message ServerEligibilityComputationPhase {
  // A short CamelCase name for the ServerEligibilityComputationPhase.
  string name = 1;

  // The names of the TensorFlow nodes to run in order to produce output.
  repeated string target_node_names = 2;

  // The specification of inputs and outputs to the TensorFlow graph.
  oneof server_eligibility_io_router {
    TEContextServerEligibilityIORouter task_eligibility = 3 [lazy = true];
  }
}

// Represents the inputs and outputs of a `ServerEligibilityComputationPhase`
// which takes a single `TaskEligibilityContext` as input.
message TEContextServerEligibilityIORouter {
  // The name of the scalar string tensor that must be fed a serialized
  // `TaskEligibilityContext`.
  string context_proto_input_tensor_name = 1;

  // The name of the scalar string tensor that must be fed the path to which
  // the server graph should write the checkpoint file to be sent to the client.
  string output_filepath_tensor_name = 2;
}

// Plan
// =====

// Represents the overall plan for performing federated optimization or
// personalization, as handed over to the production system. This will
// typically be split into individual pieces for different production
// parts, e.g. server and client side.
// NEXT_TAG: 15
message Plan {
  reserved 1, 3, 5;

  // The actual type of the server_*_graph_bytes fields below is expected to be
  // tensorflow.GraphDef. The TensorFlow graphs are stored in serialized form
  // for two reasons.
  // 1) We may use execution engines other than TensorFlow.
  // 2) We wish to avoid the cost of deserializing and re-serializing large
  //    graphs in the Federated Learning service.

  // While we migrate from ServerPhase to ServerPhaseV2, server_graph_bytes,
  // server_graph_prepare_bytes, and server_graph_result_bytes may all be set.
  // If we're using a MapReduceForm-based server implementation, only
  // server_graph_bytes will be used. If we're using a DistributeAggregateForm-
  // based server implementation, only server_graph_prepare_bytes and
  // server_graph_result_bytes will be used.

  // Optional. The TensorFlow graph used for all server processing described by
  // ServerPhase. For personalization, this will not be set.
  google.protobuf.Any server_graph_bytes = 7;

  // Optional. The TensorFlow graph used for all server processing described by
  // ServerPhaseV2.tensorflow_spec_prepare.
  google.protobuf.Any server_graph_prepare_bytes = 13;

  // Optional. The TensorFlow graph used for all server processing described by
  // ServerPhaseV2.tensorflow_spec_result.
  google.protobuf.Any server_graph_result_bytes = 14;

  // A savepoint to sync the server checkpoint with a persistent
  // storage system.  The storage initially holds a seeded checkpoint
  // which can subsequently be read and updated by this savepoint.
  // Optional; not present in eligibility computation plans (those with a
  // ServerEligibilityComputationPhase). This is used in conjunction with
  // ServerPhase only.
  CheckpointOp server_savepoint = 2;

  // Required. The TensorFlow graph that describes the TensorFlow logic a client
  // should perform. It should be consistent with the `TensorflowSpec` field in
  // the `client_phase`. The actual type is expected to be tensorflow.GraphDef.
  // The TensorFlow graph is stored in serialized form for two reasons.
  // 1) We may use execution engines other than TensorFlow.
  // 2) We wish to avoid the cost of deserializing and re-serializing large
  //    graphs in the Federated Learning service.
  google.protobuf.Any client_graph_bytes = 8;

  // Optional. The FlatBuffer used for TFLite training.
  // It contains the same model information as the client_graph_bytes, but with
  // a different format.
  bytes client_tflite_graph_bytes = 12;

  // A pair of client phase and server phase which are processed in
  // sync. The server execution defines how the results of a client
  // phase are aggregated, and how the checkpoints for clients are
  // generated.
  message Phase {
    // Required. The client phase.
    ClientPhase client_phase = 1;

    // Optional. Server phase for TF-based aggregation; not provided for
    // personalization or eligibility tasks.
    ServerPhase server_phase = 2;

    // Optional. Server phase for native aggregation; only provided for tasks
    // that have enabled the corresponding flag.
    ServerPhaseV2 server_phase_v2 = 4;

    // Optional. Only provided for eligibility tasks.
    ServerEligibilityComputationPhase server_eligibility_phase = 3;
  }

  // A pair of client and server computations to run.
  repeated Phase phase = 4;

  // Metrics that are persistent across different phases. This
  // includes, for example, counters that track how much work of
  // different kinds has been done.
  repeated Metric metrics = 6;

  // Describes how metrics in both the client and server phases should be
  // aggregated.
  repeated OutputMetric output_metrics = 10;

  // Version of the plan:
  //   version == 0 - Old plan without version field, containing b/65131070
  //   version >= 1 - plan supports multi-shard aggregation mode (L1/L2)
  int32 version = 9;

  // A TensorFlow ConfigProto packed in an Any.
  //
  // If this field is unset, if the Any proto is set but empty, or if the Any
  // proto is populated with an empty ConfigProto (i.e. its `type_url` field is
  // set, but the `value` field is empty), then the client implementation may
  // choose a set of configuration parameters to provide to TensorFlow by
  // default.
  //
  // In all other cases this field must contain a valid packed ConfigProto
  // (invalid values will result in an error at execution time), and in this
  // case the client will not provide any other configuration parameters by
  // default.
  google.protobuf.Any tensorflow_config_proto = 11;
}

// Represents a client part of the plan of federated optimization.
// This is also used to describe a client-only plan for standalone on-device
// training, known as personalization.
// NEXT_TAG: 6
message ClientOnlyPlan {
  reserved 3;

  // The graph to use for training, in binary form.
  bytes graph = 1;

  // Optional. The FlatBuffer used for TFLite training.
  // Whether "graph" or "tflite_graph" is used for training is up to the client
  // code to allow for a flag-controlled a/b rollout.
  bytes tflite_graph = 5;

  // The client phase to execute.
  ClientPhase phase = 2;

  // A TensorFlow ConfigProto.
  google.protobuf.Any tensorflow_config_proto = 4;
}

// Represents the cross-round aggregation portion for user-defined measurements.
// This is used by tools that process / analyze accumulator checkpoints
// after a round of computation, to achieve aggregation beyond a round.
message CrossRoundAggregationExecution {
  // Operation to run before reading accumulator checkpoint.
  string init_op = 1;

  // Reads accumulator checkpoint.
  CheckpointOp read_aggregated_update = 2;

  // Operation to merge loaded checkpoint into accumulator.
  string merge_op = 3;

  // Reads and writes the final aggregated accumulator vars.
  CheckpointOp read_write_final_accumulators = 6;

  // Metadata for mapping the TensorFlow `name` attribute of the `tf.Variable`
  // to the user-defined name of the signal.
  repeated Measurement measurements = 4;

  // The `tf.Graph` used for aggregating accumulator checkpoints when
  // loading metrics.
  google.protobuf.Any cross_round_aggregation_graph_bytes = 5;
}

message Measurement {
  // Name of a TensorFlow op to run to read/fetch the value of this measurement.
  string read_op_name = 1;

  // A human-readable name for the measurement. Names are usually
  // camel case by convention, e.g., 'Loss', 'AbsLoss', or 'Accuracy'.
  string name = 2;

  reserved 3;

  // A serialized `tff.Type` for the measurement.
  bytes tff_type = 4;
}