xref: /aosp_15_r20/external/federated-compute/fcp/protos/plan.proto (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker// Copyright 2021 Google LLC
2*14675a02SAndroid Build Coastguard Worker//
3*14675a02SAndroid Build Coastguard Worker// Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker// you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker// You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker//
7*14675a02SAndroid Build Coastguard Worker//      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker//
9*14675a02SAndroid Build Coastguard Worker// Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker// distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker// See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker// limitations under the License.
14*14675a02SAndroid Build Coastguard Worker
15*14675a02SAndroid Build Coastguard Workersyntax = "proto3";
16*14675a02SAndroid Build Coastguard Worker
17*14675a02SAndroid Build Coastguard Workerpackage google.internal.federated.plan;
18*14675a02SAndroid Build Coastguard Worker
19*14675a02SAndroid Build Coastguard Workerimport "google/protobuf/any.proto";
20*14675a02SAndroid Build Coastguard Workerimport "tensorflow/core/framework/tensor.proto";
21*14675a02SAndroid Build Coastguard Workerimport "tensorflow/core/framework/types.proto";
22*14675a02SAndroid Build Coastguard Workerimport "tensorflow/core/protobuf/saver.proto";
23*14675a02SAndroid Build Coastguard Workerimport "tensorflow/core/protobuf/struct.proto";
24*14675a02SAndroid Build Coastguard Worker
25*14675a02SAndroid Build Coastguard Workeroption java_package = "com.google.internal.federated.plan";
26*14675a02SAndroid Build Coastguard Workeroption java_multiple_files = true;
27*14675a02SAndroid Build Coastguard Workeroption java_outer_classname = "PlanProto";
28*14675a02SAndroid Build Coastguard Worker
29*14675a02SAndroid Build Coastguard Worker// Primitives
30*14675a02SAndroid Build Coastguard Worker// ===========
31*14675a02SAndroid Build Coastguard Worker
// Represents an operation to save or restore from a checkpoint. Some
// instances of this message may only be used either for restore or for
// save, others for both directions. This is documented together with
// their usage.
//
// This op has four essential uses:
//   1. read and apply a checkpoint.
//   2. write a checkpoint.
//   3. read and apply from an aggregated side channel.
//   4. write to a side channel (grouped with write a checkpoint).
// We should consider splitting this into four separate messages.
message CheckpointOp {
  // An optional standard saver def. If not provided, only the
  // op(s) below will be executed. This must be a version 1 SaverDef.
  tensorflow.SaverDef saver_def = 1;

  // An optional operation to run before the saver_def is executed for
  // restore.
  string before_restore_op = 2;

  // An optional operation to run after the saver_def has been
  // executed for restore. If side_channel_tensors are provided, then
  // they should be provided in a feed_dict to this op.
  //
  // NOTE(review): the comment on `side_channel_tensors` below says the
  // side-channel values are fed to `before_restore_op` instead. Confirm
  // which op actually receives the feed_dict and make both comments agree.
  string after_restore_op = 3;

  // An optional operation to run before the saver_def will be
  // executed for save.
  string before_save_op = 4;

  // An optional operation to run after the saver_def has been
  // executed for save. If there are side_channel_tensors, this op
  // should be run after the side_channel_tensors have been fetched.
  string after_save_op = 5;

  // In addition to being saved to and restored from a checkpoint, tensors
  // can also be saved and restored via a side channel. The keys in this map
  // are the names of the tensors transmitted by the side channel. These (key)
  // tensors should be read off just before saving a SaverDef and used
  // by the code that handles the side channel. Any variables provided this
  // way should NOT be saved in the SaverDef.
  //
  // For restoring, the variables that are provided by the side channel
  // are restored differently than those from a checkpoint. Those from
  // the side channel should be restored by calling the before_restore_op
  // with a feed dict whose keys are the restore_names in the SideChannel and
  // whose values are the values to be restored.
  map<string, SideChannel> side_channel_tensors = 6;

  // An optional name of a tensor into which a unique token for the current
  // session should be written.
  //
  // This session identifier allows TensorFlow ops such as `ServeSlices` or
  // `ExternalDataset` to refer to callbacks and other session-global objects
  // registered before running the session.
  string session_token_tensor_name = 7;
}
88*14675a02SAndroid Build Coastguard Worker
// Describes how a single tensor is transported via a side channel instead of
// a checkpoint. Used as the value type of `CheckpointOp.side_channel_tensors`.
message SideChannel {
  // A side channel whose variables are processed via SecureAggregation.
  // This side channel implements aggregation via sum over a set of
  // clients, so the restored tensor will be a sum of multiple clients'
  // inputs into the side channel. Hence this will restore during the
  // read_aggregate_update restore, not the per-client read_update restore.
  message SecureAggregand {
    // The size of one dimension of the aggregand.
    message Dimension {
      int64 size = 1;
    }

    // Dimensions of the aggregand. This is used by the secure aggregation
    // protocol in its early rounds, not as redundant info which could be
    // obtained by reading the dimensions of the tensor itself.
    repeated Dimension dimension = 3;

    // The data type anticipated by the server-side graph.
    tensorflow.DataType dtype = 4;

    // SecureAggregation will compute sum modulo this modulus.
    message FixedModulus {
      uint64 modulus = 1;
    }

    // SecureAggregation will, for each shard, compute the sum modulo some m
    // with m at least (1 + shard_size * (base_modulus - 1)), then aggregate
    // shard results with non-modular addition. Here, shard_size is the number
    // of clients in the shard.
    //
    // Note that the modulus for each shard will be greater than the largest
    // possible (non-modular) sum of the inputs to that shard. That is,
    // assuming each client has input in the range [0, base_modulus), the
    // result will be identical to non-modular addition (i.e. federated_sum).
    //
    // While any m >= (1 + shard_size * (base_modulus - 1)) would be valid, the
    // current implementation takes
    // m = 2**ceil(log_2(1 + shard_size * (base_modulus - 1))), which is the
    // smallest possible value of m that is also a power of 2. This choice is
    // made because (a) it uses the same number of bits per vector entry as
    // valid smaller m, using the current on-the-wire encoding scheme, and (b)
    // it enables the underlying mask-generation PRNG to run in its most
    // computationally efficient mode, which can be up to 2x faster.
    message ModulusTimesShardSize {
      uint64 base_modulus = 1;
    }

    // How the SecureAggregation modulus is determined.
    oneof modulus_scheme {
      // Bitwidth of the aggregand.
      //
      // This is the bitwidth of an input value (i.e. the bitwidth that
      // quantization should target). The Secure Aggregation bitwidth (i.e.,
      // the bitwidth of the *sum* of the input values) will be a function of
      // this bitwidth and the number of participating clients, as negotiated
      // with the server when the protocol is initiated.
      //
      // Deprecated; prefer fixed_modulus instead.
      int32 quantized_input_bitwidth = 2 [deprecated = true];

      // Sum is computed modulo a single fixed modulus (see FixedModulus).
      FixedModulus fixed_modulus = 5;

      // Per-shard modulus derived from base_modulus and the shard size (see
      // ModulusTimesShardSize).
      ModulusTimesShardSize modulus_times_shard_size = 6;
    }

    // Field number 1 belonged to a removed field; it must not be reused.
    reserved 1;
  }

  // What type of side channel is used.
  oneof type {
    SecureAggregand secure_aggregand = 1;
  }

  // When restoring, the name of the tensor to restore to. This is the name
  // (key) supplied in the feed_dict in the before_restore_op in order to
  // restore the tensor provided by the side channel (which will be the
  // value in the feed_dict).
  //
  // NOTE(review): the comment on `CheckpointOp.after_restore_op` says
  // side-channel tensors are fed to `after_restore_op` rather than
  // `before_restore_op`. Confirm which is correct.
  string restore_name = 2;
}
165*14675a02SAndroid Build Coastguard Worker
// Container for a metric used by the internal toolkit.
message Metric {
  // Name of an Op to run to read the value.
  string variable_name = 1;

  // A human-readable name for the statistic. Metric names are usually
  // camel case by convention, e.g., 'Loss', 'AbsLoss', or 'Accuracy'.
  // Must be 7-bit ASCII and under 122 characters.
  string stat_name = 2;

  // The human-readable name of another metric by which this metric should be
  // normalized, if any. If empty, this Metric should be aggregated with simple
  // summation. If not empty, the Metric is aggregated according to:
  //
  //   weighted_metric_sum  = sum_i (metric_i * weight_i)
  //   weight_sum           = sum_i weight_i
  //   average_metric_value = weighted_metric_sum / weight_sum
  string weight_name = 3;
}
184*14675a02SAndroid Build Coastguard Worker
// Controls the format of output metrics users receive. Represents instructions
// for how metrics are to be output to users, controlling the end format of
// the metric users receive.
message OutputMetric {
  // Metric name.
  string name = 1;

  // How the metric value is produced/aggregated.
  oneof value_source {
    // A metric representing one stat with aggregation type sum.
    SumOptions sum = 2;

    // A metric representing a ratio between metrics with aggregation
    // type average.
    AverageOptions average = 3;

    // A metric that is not aggregated by the MetricReportAggregator or
    // metrics_loader. This includes metrics like 'num_server_updates' that are
    // aggregated in TensorFlow.
    NoneOptions none = 4;

    // A metric representing one stat with aggregation type only sample.
    // Samples at most 101 clients' values.
    OnlySampleOptions only_sample = 5;
  }

  // How the metric should be visualized.
  oneof visualization_info {
    // Iff True, the metric will be plotted in the default view of the
    // task level Colab automatically.
    //
    // Deprecated; prefer plot_spec.
    bool auto_plot = 6 [deprecated = true];

    // Detailed plotting instructions for this metric.
    VisualizationSpec plot_spec = 7;
  }
}
216*14675a02SAndroid Build Coastguard Worker
// Instructions for how a metric should be plotted downstream.
message VisualizationSpec {
  // Different allowable plot types.
  enum VisualizationType {
    NONE = 0;
    DEFAULT_PLOT_FOR_TASK_TYPE = 1;
    LINE_PLOT = 2;
    LINE_PLOT_WITH_PERCENTILES = 3;
    HISTOGRAM = 4;
  }

  // Defines the plot type to provide downstream.
  VisualizationType plot_type = 1;

  // The x-axis to use when plotting the given metric. Must be the name of a
  // metric or counter. Recommended x_axis options are source_round, round,
  // or time.
  string x_axis = 2;

  // Iff True, the metric will be displayed on a population level dashboard.
  bool plot_on_population_dashboard = 3;
}
238*14675a02SAndroid Build Coastguard Worker
// A metric representing one stat with aggregation type sum.
message SumOptions {
  // Name for corresponding Metric stat_name field.
  string stat_name = 1;

  // Iff True, a cumulative sum over rounds will be provided in addition to a
  // sum per round for the value metric.
  bool include_cumulative_sum = 2;

  // Iff True, a sample of at most 101 clients' values is included.
  // Used to calculate quantiles in the downstream visualization pipeline.
  bool include_client_samples = 3;
}
252*14675a02SAndroid Build Coastguard Worker
// A metric representing a ratio between metrics with aggregation type average.
// Represents: numerator stat / denominator stat.
message AverageOptions {
  // Numerator stat name pointing to corresponding Metric stat_name.
  string numerator_stat_name = 1;

  // Denominator stat name pointing to corresponding Metric stat_name.
  string denominator_stat_name = 2;

  // Name for corresponding Metric stat_name that is the ratio of the
  // numerator stat / denominator stat.
  string average_stat_name = 3;

  // Iff True, a sample of at most 101 clients' values is included.
  // Used to calculate quantiles in the downstream visualization pipeline.
  bool include_client_samples = 4;
}
270*14675a02SAndroid Build Coastguard Worker
// A metric representing one stat with aggregation type none.
message NoneOptions {
  // Name for the corresponding Metric stat_name field.
  string stat_name = 1;
}
276*14675a02SAndroid Build Coastguard Worker
// A metric representing one stat with aggregation type only sample.
message OnlySampleOptions {
  // Name for the corresponding Metric stat_name field.
  string stat_name = 1;
}
282*14675a02SAndroid Build Coastguard Worker
// Represents a data set. This is used for testing.
message Dataset {
  // Represents the data set for one client.
  message ClientDataset {
    // A string identifying the client.
    string client_id = 1;

    // A list of serialized tf.Example protos.
    repeated bytes example = 2;

    // Represents a dataset whose examples are selected by an ExampleSelector.
    message SelectedExample {
      // The selector associated with `example`.
      ExampleSelector selector = 1;

      // The serialized examples for this selector.
      // NOTE(review): presumably serialized tf.Example protos, like
      // ClientDataset.example; confirm.
      repeated bytes example = 2;
    }

    // A list of (selector, dataset) pairs. Used in testing some *TFF-based
    // tasks* that require multiple datasets as client input, e.g., a TFF-based
    // personalization eval task requires each client to provide at least two
    // datasets: one for train, and the other for test.
    repeated SelectedExample selected_example = 3;
  }

  // A list of client data.
  repeated ClientDataset client_data = 1;
}
309*14675a02SAndroid Build Coastguard Worker
// Represents predicates over metrics - i.e., expectations. This is used in
// training/eval tests to encode metric names and values expected to be reported
// by a client execution.
message MetricTestPredicates {
  // The value must lie in the closed interval [lower_bound, upper_bound]. Can
  // also be used for approximate matching
  // (lower_bound = value - epsilon; upper_bound = value + epsilon).
  message Interval {
    double lower_bound = 1;
    double upper_bound = 2;
  }

  // The value must be a real value as long as the value of the weight_name
  // metric is non-zero. If the weight metric is zero, then it is acceptable for
  // the value to be non-real.
  message RealIfNonzeroWeight {
    string weight_name = 1;
  }

  // A single expectation about a named metric reported in a given round.
  message MetricCriterion {
    // Name of the metric.
    string name = 1;

    // FL training round this metric is expected to appear in.
    int32 training_round_index = 2;

    // If none of the following is set, no matching is performed; but the
    // metric is still expected to be present (with whatever value).
    //
    // NOTE(review): this oneof name is not lower_snake_case as the proto
    // style guide recommends; it is kept as-is because renaming it would
    // change the generated accessor names.
    oneof Criterion {
      // The reported metric must be < lt.
      float lt = 3;
      // The reported metric must be > gt.
      float gt = 4;
      // The reported metric must be <= le.
      float le = 5;
      // The reported metric must be >= ge.
      float ge = 6;
      // The reported metric must be == eq.
      float eq = 7;
      // The reported metric must lie in the interval.
      Interval interval = 8;
      // The reported metric is not NaN or +/- infinity.
      bool real = 9;
      // The reported metric is real (i.e., not NaN or +/- infinity) if the
      // value of an associated weight is not 0.
      RealIfNonzeroWeight real_if_nonzero_weight = 10;
    }
  }

  repeated MetricCriterion metric_criterion = 1;

  // Field number 2 belonged to a removed field; it must not be reused.
  reserved 2;
}
362*14675a02SAndroid Build Coastguard Worker
363*14675a02SAndroid Build Coastguard Worker// Client Phase
364*14675a02SAndroid Build Coastguard Worker// ============
365*14675a02SAndroid Build Coastguard Worker
// A `TensorflowSpec` that is executed on the client in a single `tf.Session`.
// In federated optimization, this will correspond to one `ServerPhase`.
message ClientPhase {
  // A short CamelCase name for the ClientPhase.
  string name = 2;

  // Minimum number of clients in aggregation.
  // In secure aggregation mode this is used to configure the protocol instance
  // so that the server cannot learn aggregated values computed from fewer
  // participants than this number.
  // Without secure aggregation the server still respects this parameter,
  // ensuring that aggregated values never leave server RAM unless they include
  // data from (at least) the specified number of participants.
  int32 minimum_number_of_participants = 3;

  // If populated, `io_router` must be specified.
  oneof spec {
    // A functional interface for the TensorFlow logic the client should
    // perform.
    TensorflowSpec tensorflow_spec = 4 [lazy = true];
    // Spec for client plans that issue example queries and send the query
    // results directly to an aggregator with no or little additional
    // processing.
    ExampleQuerySpec example_query_spec = 9 [lazy = true];
  }

  // The specification of the inputs coming either from customer apps
  // (Local Compute) or the federated protocol (Federated Compute).
  oneof io_router {
    FederatedComputeIORouter federated_compute = 5 [lazy = true];
    LocalComputeIORouter local_compute = 6 [lazy = true];
    FederatedComputeEligibilityIORouter federated_compute_eligibility = 7
        [lazy = true];
    FederatedExampleQueryIORouter federated_example_query = 8 [lazy = true];
  }

  // Field number 1 belonged to a removed field; it must not be reused.
  reserved 1;
}
404*14675a02SAndroid Build Coastguard Worker
405*14675a02SAndroid Build Coastguard Worker// TensorflowSpec message describes a single call into TensorFlow, including the
406*14675a02SAndroid Build Coastguard Worker// expected input tensors that must be fed when making that call, which
407*14675a02SAndroid Build Coastguard Worker// output tensors to be fetched, and any operations that have no output but must
408*14675a02SAndroid Build Coastguard Worker// be run. The TensorFlow session will then use the input tensors to do some
409*14675a02SAndroid Build Coastguard Worker// computation, generally reading from one or more datasets, and provide some
410*14675a02SAndroid Build Coastguard Worker// outputs.
411*14675a02SAndroid Build Coastguard Worker//
412*14675a02SAndroid Build Coastguard Worker// Conceptually, client or server code uses this proto along with an IORouter
413*14675a02SAndroid Build Coastguard Worker// to build maps of names to input tensors, vectors of output tensor names,
414*14675a02SAndroid Build Coastguard Worker// and vectors of target nodes:
415*14675a02SAndroid Build Coastguard Worker//
416*14675a02SAndroid Build Coastguard Worker//    CreateTensorflowArguments(
417*14675a02SAndroid Build Coastguard Worker//        const TensorflowSpec& spec,
418*14675a02SAndroid Build Coastguard Worker//        const IORouter& io_router,
419*14675a02SAndroid Build Coastguard Worker//        vector<pair<string, Tensor>>* input_tensors,
420*14675a02SAndroid Build Coastguard Worker//        vector<string>* output_tensor_names,
421*14675a02SAndroid Build Coastguard Worker//        vector<string>* target_node_names);
422*14675a02SAndroid Build Coastguard Worker//
423*14675a02SAndroid Build Coastguard Worker// Where `input_tensors`, `output_tensor_names` and `target_node_names`
424*14675a02SAndroid Build Coastguard Worker// correspond to the arguments of the TensorFlow C++ API for
425*14675a02SAndroid Build Coastguard Worker// `tensorflow::Session::Run()`, and the client executes only a single
426*14675a02SAndroid Build Coastguard Worker// invocation.
427*14675a02SAndroid Build Coastguard Worker//
428*14675a02SAndroid Build Coastguard Worker// Note: the execution engine never sees any concepts related to the federated
429*14675a02SAndroid Build Coastguard Worker// protocol, e.g. input checkpoints or aggregation protocols. This is a "tensors
430*14675a02SAndroid Build Coastguard Worker// in, tensors out" interface. New aggregation methods can be added without
431*14675a02SAndroid Build Coastguard Worker// having to modify the execution engine / TensorflowSpec message; instead,
432*14675a02SAndroid Build Coastguard Worker// they should modify the IORouter messages.
433*14675a02SAndroid Build Coastguard Worker//
434*14675a02SAndroid Build Coastguard Worker// Note: both `input_tensor_specs` and `output_tensor_specs` are full
435*14675a02SAndroid Build Coastguard Worker// `tensorflow.TensorSpecProto` messages, though TensorFlow technically
436*14675a02SAndroid Build Coastguard Worker// only requires the names to feed the values into the session. The additional
437*14675a02SAndroid Build Coastguard Worker// dtype/shape information must always be included in case the runtime
438*14675a02SAndroid Build Coastguard Worker// executing this TensorflowSpec wants to perform additional, optional static
439*14675a02SAndroid Build Coastguard Worker// assertions. Runtimes are, however, free to ignore the dtypes/shapes and only
440*14675a02SAndroid Build Coastguard Worker// rely on the names if so desired.
441*14675a02SAndroid Build Coastguard Worker//
442*14675a02SAndroid Build Coastguard Worker// Assertions:
443*14675a02SAndroid Build Coastguard Worker//   - all names in `input_tensor_specs`, `output_tensor_specs`, and
444*14675a02SAndroid Build Coastguard Worker//     `target_node_names` must appear in the serialized GraphDef where
445*14675a02SAndroid Build Coastguard Worker//     the TF execution will be invoked.
446*14675a02SAndroid Build Coastguard Worker//   - `output_tensor_specs` or `target_node_names` must be non-empty, otherwise
447*14675a02SAndroid Build Coastguard Worker//     there is nothing to execute in the graph.
448*14675a02SAndroid Build Coastguard Workermessage TensorflowSpec {
449*14675a02SAndroid Build Coastguard Worker  // The name of a tensor into which a unique token for the current session
450*14675a02SAndroid Build Coastguard Worker  // should be written. The corresponding tensor is a scalar string tensor and
451*14675a02SAndroid Build Coastguard Worker  // is separate from `input_tensor_specs` as there is only one.
452*14675a02SAndroid Build Coastguard Worker  //
453*14675a02SAndroid Build Coastguard Worker  // A session token allows TensorFlow ops such as `ServeSlices` or
454*14675a02SAndroid Build Coastguard Worker  // `ExternalDataset` to refer to callbacks and other session-global objects
455*14675a02SAndroid Build Coastguard Worker  // registered before running the session. In the `ExternalDataset` case, a
456*14675a02SAndroid Build Coastguard Worker  // single dataset_token is valid for multiple `tf.data.Dataset` objects as
457*14675a02SAndroid Build Coastguard Worker  // the token can be thought of as a handle to a dataset factory.
458*14675a02SAndroid Build Coastguard Worker  string dataset_token_tensor_name = 1;
459*14675a02SAndroid Build Coastguard Worker
460*14675a02SAndroid Build Coastguard Worker  // TensorSpecs of inputs which will be passed to TF.
461*14675a02SAndroid Build Coastguard Worker  //
462*14675a02SAndroid Build Coastguard Worker  // Corresponds to the `feed_dict` parameter of `tf.Session.run()` in
463*14675a02SAndroid Build Coastguard Worker  // TensorFlow's Python API, excluding the dataset_token listed above.
464*14675a02SAndroid Build Coastguard Worker  //
465*14675a02SAndroid Build Coastguard Worker  // Assertions:
466*14675a02SAndroid Build Coastguard Worker  //   - All the tensor names designated as inputs in the corresponding IORouter
467*14675a02SAndroid Build Coastguard Worker  //     must be listed (otherwise the IORouter input would go unused).
468*14675a02SAndroid Build Coastguard Worker  //   - All placeholders in the TF graph must be listed here, with the
469*14675a02SAndroid Build Coastguard Worker  //     exception of the dataset_token which is explicitly set above (otherwise
470*14675a02SAndroid Build Coastguard Worker  //     TensorFlow will fail to execute).
471*14675a02SAndroid Build Coastguard Worker  repeated tensorflow.TensorSpecProto input_tensor_specs = 2;
472*14675a02SAndroid Build Coastguard Worker
473*14675a02SAndroid Build Coastguard Worker  // TensorSpecs that should be fetched from TF after execution.
474*14675a02SAndroid Build Coastguard Worker  //
475*14675a02SAndroid Build Coastguard Worker  // Corresponds to the `fetches` parameter of `tf.Session.run()` in
476*14675a02SAndroid Build Coastguard Worker  // TensorFlow's Python API, and the `output_tensor_names` in TensorFlow's C++
477*14675a02SAndroid Build Coastguard Worker  // API.
478*14675a02SAndroid Build Coastguard Worker  //
479*14675a02SAndroid Build Coastguard Worker  // Assertions:
480*14675a02SAndroid Build Coastguard Worker  //   - The set of tensor names here must strictly match the tensor names
481*14675a02SAndroid Build Coastguard Worker  //     designated as outputs in the corresponding IORouter (if any exist).
482*14675a02SAndroid Build Coastguard Worker  repeated tensorflow.TensorSpecProto output_tensor_specs = 3;
483*14675a02SAndroid Build Coastguard Worker
484*14675a02SAndroid Build Coastguard Worker  // Node names in the graph that should be executed, but whose output is not
485*14675a02SAndroid Build Coastguard Worker  // returned.
486*14675a02SAndroid Build Coastguard Worker  //
487*14675a02SAndroid Build Coastguard Worker  // Corresponds to the `fetches` parameter of `tf.Session.run()` in
488*14675a02SAndroid Build Coastguard Worker  // TensorFlow's Python API, and the `target_node_names` in TensorFlow's C++
489*14675a02SAndroid Build Coastguard Worker  // API.
490*14675a02SAndroid Build Coastguard Worker  //
491*14675a02SAndroid Build Coastguard Worker  // This is intended for use with operations that do not produce tensors, but
492*14675a02SAndroid Build Coastguard Worker  // nonetheless are required to run (e.g. serializing checkpoints).
493*14675a02SAndroid Build Coastguard Worker  repeated string target_node_names = 4;
494*14675a02SAndroid Build Coastguard Worker
495*14675a02SAndroid Build Coastguard Worker  // Map of tensor names to constant input values.
496*14675a02SAndroid Build Coastguard Worker  // Note: tensors specified via this map should not also be included in
497*14675a02SAndroid Build Coastguard Worker  // `input_tensor_specs`.
498*14675a02SAndroid Build Coastguard Worker  map<string, tensorflow.TensorProto> constant_inputs = 5;
499*14675a02SAndroid Build Coastguard Worker
500*14675a02SAndroid Build Coastguard Worker  // The field below is added by the OnDevicePersonalization module.
501*14675a02SAndroid Build Coastguard Worker  // Specifies an example selection procedure.
502*14675a02SAndroid Build Coastguard Worker  ExampleSelector example_selector = 999;
503*14675a02SAndroid Build Coastguard Worker}
504*14675a02SAndroid Build Coastguard Worker
505*14675a02SAndroid Build Coastguard Worker// ExampleQuerySpec message describes client execution that issues example
506*14675a02SAndroid Build Coastguard Worker// queries and sends the query results directly to an aggregator with no or
507*14675a02SAndroid Build Coastguard Worker// little additional processing.
508*14675a02SAndroid Build Coastguard Worker// This message describes one or more example store queries that perform the
509*14675a02SAndroid Build Coastguard Worker// client side analytics computation in C++. The corresponding output vectors
510*14675a02SAndroid Build Coastguard Worker// will be converted into the expected federated protocol output format.
511*14675a02SAndroid Build Coastguard Worker// This must be used in conjunction with the `FederatedExampleQueryIORouter`.
512*14675a02SAndroid Build Coastguard Workermessage ExampleQuerySpec {
513*14675a02SAndroid Build Coastguard Worker  message OutputVectorSpec {
514*14675a02SAndroid Build Coastguard Worker    // The output vector name.
515*14675a02SAndroid Build Coastguard Worker    string vector_name = 1;
516*14675a02SAndroid Build Coastguard Worker
517*14675a02SAndroid Build Coastguard Worker    // Supported data types for the entries of an output vector.
518*14675a02SAndroid Build Coastguard Worker    enum DataType {
519*14675a02SAndroid Build Coastguard Worker      UNSPECIFIED = 0;  // proto3 default (unset) value.
520*14675a02SAndroid Build Coastguard Worker      INT32 = 1;
521*14675a02SAndroid Build Coastguard Worker      INT64 = 2;
522*14675a02SAndroid Build Coastguard Worker      BOOL = 3;
523*14675a02SAndroid Build Coastguard Worker      FLOAT = 4;
524*14675a02SAndroid Build Coastguard Worker      DOUBLE = 5;
525*14675a02SAndroid Build Coastguard Worker      BYTES = 6;
526*14675a02SAndroid Build Coastguard Worker      STRING = 7;
527*14675a02SAndroid Build Coastguard Worker    }
528*14675a02SAndroid Build Coastguard Worker
529*14675a02SAndroid Build Coastguard Worker    // The data type for each entry in the vector.
530*14675a02SAndroid Build Coastguard Worker    DataType data_type = 2;
531*14675a02SAndroid Build Coastguard Worker  }
532*14675a02SAndroid Build Coastguard Worker
533*14675a02SAndroid Build Coastguard Worker  message ExampleQuery {
534*14675a02SAndroid Build Coastguard Worker    // The `ExampleSelector` to issue the query with.
535*14675a02SAndroid Build Coastguard Worker    ExampleSelector example_selector = 1;
536*14675a02SAndroid Build Coastguard Worker
537*14675a02SAndroid Build Coastguard Worker    // Indicates that the query returns vector data and must return a single
538*14675a02SAndroid Build Coastguard Worker    // ExampleQueryResult containing a VectorData entry matching each
539*14675a02SAndroid Build Coastguard Worker    // OutputVectorSpec.vector_name.
540*14675a02SAndroid Build Coastguard Worker    //
541*14675a02SAndroid Build Coastguard Worker    // If the query instead returns no result, then it will be treated as if
542*14675a02SAndroid Build Coastguard Worker    // an error was returned. In that case, or if the query explicitly returns
543*14675a02SAndroid Build Coastguard Worker    // an error, then the client will abort its session.
544*14675a02SAndroid Build Coastguard Worker    //
545*14675a02SAndroid Build Coastguard Worker    // The keys in the map are the names the vectors should be aggregated under,
546*14675a02SAndroid Build Coastguard Worker    // and must match the keys in FederatedExampleQueryIORouter.aggregations.
547*14675a02SAndroid Build Coastguard Worker    map<string, OutputVectorSpec> output_vector_specs = 2;
548*14675a02SAndroid Build Coastguard Worker  }
549*14675a02SAndroid Build Coastguard Worker
550*14675a02SAndroid Build Coastguard Worker  // The queries to run.
551*14675a02SAndroid Build Coastguard Worker  repeated ExampleQuery example_queries = 1;
552*14675a02SAndroid Build Coastguard Worker}
553*14675a02SAndroid Build Coastguard Worker
554*14675a02SAndroid Build Coastguard Worker// The input and output router for Federated Compute plans.
555*14675a02SAndroid Build Coastguard Worker//
556*14675a02SAndroid Build Coastguard Worker// This proto is the glue between the federated protocol and the TensorFlow
557*14675a02SAndroid Build Coastguard Worker// execution engine. This message describes how to prepare data coming from the
558*14675a02SAndroid Build Coastguard Worker// incoming `CheckinResponse` (defined in
559*14675a02SAndroid Build Coastguard Worker// fcp/protos/federated_api.proto) for the `TensorflowSpec`, and what
560*14675a02SAndroid Build Coastguard Worker// to do with outputs from `TensorflowSpec` (e.g. how to aggregate them back on
561*14675a02SAndroid Build Coastguard Worker// the server).
562*14675a02SAndroid Build Coastguard Worker//
563*14675a02SAndroid Build Coastguard Worker// TODO(team): we could replace `input_filepath_tensor_name` with
564*14675a02SAndroid Build Coastguard Worker// an `input_tensors` field, which would then be a tensor that contains the
565*14675a02SAndroid Build Coastguard Worker// input TensorProtos directly, skipping disk I/O, rather than referring to a
566*14675a02SAndroid Build Coastguard Worker// checkpoint file path.
567*14675a02SAndroid Build Coastguard Workermessage FederatedComputeIORouter {
568*14675a02SAndroid Build Coastguard Worker  // ===========================================================================
569*14675a02SAndroid Build Coastguard Worker  //   Inputs
570*14675a02SAndroid Build Coastguard Worker  // ===========================================================================
571*14675a02SAndroid Build Coastguard Worker  // The name of the scalar string tensor that is fed the file path to the
572*14675a02SAndroid Build Coastguard Worker  // initial checkpoint (e.g. as provided via AcceptanceInfo.init_checkpoint).
573*14675a02SAndroid Build Coastguard Worker  //
574*14675a02SAndroid Build Coastguard Worker  // The federated protocol code copies the `CheckinResponse`'s initial
575*14675a02SAndroid Build Coastguard Worker  // checkpoint to a temporary file and then passes that file path through this
576*14675a02SAndroid Build Coastguard Worker  // tensor.
577*14675a02SAndroid Build Coastguard Worker  //
578*14675a02SAndroid Build Coastguard Worker  // Ops may be added to the client graph that take this tensor as input and
579*14675a02SAndroid Build Coastguard Worker  // read the path.
580*14675a02SAndroid Build Coastguard Worker  //
581*14675a02SAndroid Build Coastguard Worker  // This field is optional. It may be omitted if the client graph does not use
582*14675a02SAndroid Build Coastguard Worker  // an initial checkpoint.
583*14675a02SAndroid Build Coastguard Worker  string input_filepath_tensor_name = 1;
584*14675a02SAndroid Build Coastguard Worker
585*14675a02SAndroid Build Coastguard Worker  // The name of the scalar string tensor that is fed the file path to which
586*14675a02SAndroid Build Coastguard Worker  // client work should serialize the bytes to send back to the server.
587*14675a02SAndroid Build Coastguard Worker  //
588*14675a02SAndroid Build Coastguard Worker  // The federated protocol code generates a temporary file and passes the file
589*14675a02SAndroid Build Coastguard Worker  // path through this tensor.
590*14675a02SAndroid Build Coastguard Worker  //
591*14675a02SAndroid Build Coastguard Worker  // Ops may be added to the client graph that use this tensor as an argument
592*14675a02SAndroid Build Coastguard Worker  // to write files (e.g. writing checkpoints to disk).
593*14675a02SAndroid Build Coastguard Worker  //
594*14675a02SAndroid Build Coastguard Worker  // This field is optional. It must be omitted if the client graph does not
595*14675a02SAndroid Build Coastguard Worker  // generate any output files (e.g. when all output tensors of `TensorflowSpec`
596*14675a02SAndroid Build Coastguard Worker  // use Secure Aggregation). If this field is not set, then the `ReportRequest`
597*14675a02SAndroid Build Coastguard Worker  // message in the federated protocol will not have the
598*14675a02SAndroid Build Coastguard Worker  // `Report.update_checkpoint` field set. The absence of a value here can be
599*14675a02SAndroid Build Coastguard Worker  // used to validate that the plan only uses Secure Aggregation.
600*14675a02SAndroid Build Coastguard Worker  //
601*14675a02SAndroid Build Coastguard Worker  // Conversely, if this field is set but executing the associated
602*14675a02SAndroid Build Coastguard Worker  // TensorflowSpec does not write to the path, that indicates an internal
603*14675a02SAndroid Build Coastguard Worker  // framework error. The runtime should notify the caller that the computation
604*14675a02SAndroid Build Coastguard Worker  // was set up incorrectly.
605*14675a02SAndroid Build Coastguard Worker  string output_filepath_tensor_name = 2;
606*14675a02SAndroid Build Coastguard Worker
607*14675a02SAndroid Build Coastguard Worker  // ===========================================================================
608*14675a02SAndroid Build Coastguard Worker  //   Outputs
609*14675a02SAndroid Build Coastguard Worker  // ===========================================================================
610*14675a02SAndroid Build Coastguard Worker  // Describes which output tensors should be aggregated using an aggregation
611*14675a02SAndroid Build Coastguard Worker  // protocol, and the configuration for those protocols.
612*14675a02SAndroid Build Coastguard Worker  //
613*14675a02SAndroid Build Coastguard Worker  // Assertions:
614*14675a02SAndroid Build Coastguard Worker  //   - All keys must exist in the associated `TensorflowSpec` as
615*14675a02SAndroid Build Coastguard Worker  //     `output_tensor_specs.name` values.
616*14675a02SAndroid Build Coastguard Worker  map<string, AggregationConfig> aggregations = 3;
617*14675a02SAndroid Build Coastguard Worker}
618*14675a02SAndroid Build Coastguard Worker
619*14675a02SAndroid Build Coastguard Worker// The input and output router for client plans that do not use TensorFlow.
620*14675a02SAndroid Build Coastguard Worker//
621*14675a02SAndroid Build Coastguard Worker// This proto is the glue between the federated protocol and the example query
622*14675a02SAndroid Build Coastguard Worker// execution engine, describing how the query results should ultimately be
623*14675a02SAndroid Build Coastguard Worker// aggregated.
624*14675a02SAndroid Build Coastguard Workermessage FederatedExampleQueryIORouter {
625*14675a02SAndroid Build Coastguard Worker  // Describes how each output vector should be aggregated using an aggregation
626*14675a02SAndroid Build Coastguard Worker  // protocol, and the configuration for those protocols.
627*14675a02SAndroid Build Coastguard Worker  // Keys must match ExampleQuerySpec.ExampleQuery.output_vector_specs keys.
628*14675a02SAndroid Build Coastguard Worker  // Note that currently only the TFV1CheckpointAggregation config is supported.
629*14675a02SAndroid Build Coastguard Worker  map<string, AggregationConfig> aggregations = 1;
630*14675a02SAndroid Build Coastguard Worker}
631*14675a02SAndroid Build Coastguard Worker
632*14675a02SAndroid Build Coastguard Worker// The specification for how to aggregate the associated tensor across clients
633*14675a02SAndroid Build Coastguard Worker// on the server.
634*14675a02SAndroid Build Coastguard Workermessage AggregationConfig {
635*14675a02SAndroid Build Coastguard Worker  oneof protocol_config {  // NOTE(review): tag 1 is unused and not reserved.
636*14675a02SAndroid Build Coastguard Worker    // Indicates that the given output tensor should be processed with Secure
637*14675a02SAndroid Build Coastguard Worker    // Aggregation, using the specified config options.
638*14675a02SAndroid Build Coastguard Worker    SecureAggregationConfig secure_aggregation = 2;
639*14675a02SAndroid Build Coastguard Worker
640*14675a02SAndroid Build Coastguard Worker    // Note: in the future we could add a `SimpleAggregationConfig` to
641*14675a02SAndroid Build Coastguard Worker    // support simple aggregation without writing to an intermediate
642*14675a02SAndroid Build Coastguard Worker    // checkpoint file first.
643*14675a02SAndroid Build Coastguard Worker
644*14675a02SAndroid Build Coastguard Worker    // Indicates that the given output tensor or vector (e.g. as produced by an
645*14675a02SAndroid Build Coastguard Worker    // ExampleQuerySpec) should be placed in an output TF v1 checkpoint.
646*14675a02SAndroid Build Coastguard Worker    //
647*14675a02SAndroid Build Coastguard Worker    // Currently only ExampleQuerySpec output vectors are supported by this
648*14675a02SAndroid Build Coastguard Worker    // aggregation type (i.e. it cannot be used with TensorflowSpec output
649*14675a02SAndroid Build Coastguard Worker    // tensors). The vectors will be stored in the checkpoint as 1-D tensors of
650*14675a02SAndroid Build Coastguard Worker    // their corresponding data types.
651*14675a02SAndroid Build Coastguard Worker    TFV1CheckpointAggregation tf_v1_checkpoint_aggregation = 3;
652*14675a02SAndroid Build Coastguard Worker  }
653*14675a02SAndroid Build Coastguard Worker}
654*14675a02SAndroid Build Coastguard Worker
655*14675a02SAndroid Build Coastguard Worker// Parameters for the SecAgg protocol (go/secagg).
656*14675a02SAndroid Build Coastguard Worker//
657*14675a02SAndroid Build Coastguard Worker// Currently only the server uses the SecAgg parameters, so this message has
658*14675a02SAndroid Build Coastguard Worker// no fields and only signals that SecAgg should be used.
659*14675a02SAndroid Build Coastguard Workermessage SecureAggregationConfig {}
660*14675a02SAndroid Build Coastguard Worker
661*14675a02SAndroid Build Coastguard Worker// Parameters for the TFV1 Checkpoint Aggregation protocol (currently none).
662*14675a02SAndroid Build Coastguard Worker//
663*14675a02SAndroid Build Coastguard Worker// Currently only ExampleQuerySpec output vectors are supported by this
664*14675a02SAndroid Build Coastguard Worker// aggregation type (i.e. it cannot be used with TensorflowSpec output
665*14675a02SAndroid Build Coastguard Worker// tensors). The vectors will be stored in the checkpoint as 1-D tensors of
666*14675a02SAndroid Build Coastguard Worker// their corresponding data types.
667*14675a02SAndroid Build Coastguard Workermessage TFV1CheckpointAggregation {}
668*14675a02SAndroid Build Coastguard Worker
669*14675a02SAndroid Build Coastguard Worker// The input and output router for eligibility-computing plans. These plans
670*14675a02SAndroid Build Coastguard Worker// compute which other plans a client is eligible to run, and are delivered to
671*14675a02SAndroid Build Coastguard Worker// clients via an `EligibilityEvalCheckinResponse` (defined in
672*14675a02SAndroid Build Coastguard Worker// fcp/protos/federated_api.proto).
673*14675a02SAndroid Build Coastguard Workermessage FederatedComputeEligibilityIORouter {
674*14675a02SAndroid Build Coastguard Worker  // The name of the scalar string tensor that is fed the file path to the
675*14675a02SAndroid Build Coastguard Worker  // initial checkpoint (e.g. as provided via
676*14675a02SAndroid Build Coastguard Worker  // `EligibilityEvalPayload.init_checkpoint`).
677*14675a02SAndroid Build Coastguard Worker  //
678*14675a02SAndroid Build Coastguard Worker  // For more detail see the
679*14675a02SAndroid Build Coastguard Worker  // `FederatedComputeIORouter.input_filepath_tensor_name`, which has the same
680*14675a02SAndroid Build Coastguard Worker  // semantics.
681*14675a02SAndroid Build Coastguard Worker  //
682*14675a02SAndroid Build Coastguard Worker  // This field is optional. It may be omitted if the client graph does not use
683*14675a02SAndroid Build Coastguard Worker  // an initial checkpoint.
684*14675a02SAndroid Build Coastguard Worker  //
685*14675a02SAndroid Build Coastguard Worker  // If set, this tensor name must exist in the associated
686*14675a02SAndroid Build Coastguard Worker  // `TensorflowSpec.input_tensor_specs` list.
687*14675a02SAndroid Build Coastguard Worker  string input_filepath_tensor_name = 1;
688*14675a02SAndroid Build Coastguard Worker
689*14675a02SAndroid Build Coastguard Worker  // Name of the output tensor (a string scalar) containing the serialized
690*14675a02SAndroid Build Coastguard Worker  // `google.internal.federatedml.v2.TaskEligibilityInfo` proto output. The
691*14675a02SAndroid Build Coastguard Worker  // client code will parse this proto and place it in the
692*14675a02SAndroid Build Coastguard Worker  // `task_eligibility_info` field of the subsequent `CheckinRequest`.
693*14675a02SAndroid Build Coastguard Worker  //
694*14675a02SAndroid Build Coastguard Worker  // This tensor name must exist in the associated
695*14675a02SAndroid Build Coastguard Worker  // `TensorflowSpec.output_tensor_specs` list.
696*14675a02SAndroid Build Coastguard Worker  string task_eligibility_info_tensor_name = 2;
697*14675a02SAndroid Build Coastguard Worker}
698*14675a02SAndroid Build Coastguard Worker
699*14675a02SAndroid Build Coastguard Worker// The input and output router for Local Compute plans.
700*14675a02SAndroid Build Coastguard Worker//
701*14675a02SAndroid Build Coastguard Worker// This proto is the glue between the customer's app and the TensorFlow
702*14675a02SAndroid Build Coastguard Worker// execution engine. This message describes how to prepare data coming from the
703*14675a02SAndroid Build Coastguard Worker// customer app (e.g. the input directory the app set up), and the temporary,
704*14675a02SAndroid Build Coastguard Worker// scratch output directory that will be reported to the customer app upon
705*14675a02SAndroid Build Coastguard Worker// completion of `TensorflowSpec`.
706*14675a02SAndroid Build Coastguard Workermessage LocalComputeIORouter {
707*14675a02SAndroid Build Coastguard Worker  // ===========================================================================
708*14675a02SAndroid Build Coastguard Worker  //   Inputs
709*14675a02SAndroid Build Coastguard Worker  // ===========================================================================
710*14675a02SAndroid Build Coastguard Worker  // The name(s) of the placeholder tensor(s) representing the input resource(s).
711*14675a02SAndroid Build Coastguard Worker  // It can be a single input directory or file path (in this case
712*14675a02SAndroid Build Coastguard Worker  // `input_dir_tensor_name` is populated) or multiple input resources
713*14675a02SAndroid Build Coastguard Worker  // represented as a map from names to input directories or file paths (in this
714*14675a02SAndroid Build Coastguard Worker  // case `multiple_input_resources` is populated).
715*14675a02SAndroid Build Coastguard Worker  //
716*14675a02SAndroid Build Coastguard Worker  // In the multiple input resources case, the placeholder tensors are
717*14675a02SAndroid Build Coastguard Worker  // represented as a map: the keys are the input resource names defined by the
718*14675a02SAndroid Build Coastguard Worker  // users when constructing the `LocalComputation` Python object, and the
719*14675a02SAndroid Build Coastguard Worker  // values are the corresponding placeholder tensor names created by the local
720*14675a02SAndroid Build Coastguard Worker  // computation plan builder.
721*14675a02SAndroid Build Coastguard Worker  //
722*14675a02SAndroid Build Coastguard Worker  // Apps will have the ability to create contracts between their Android code
723*14675a02SAndroid Build Coastguard Worker  // and `LocalComputation` toolkit code to place files inside the input
724*14675a02SAndroid Build Coastguard Worker  // resource paths with known names (Android code) and create graphs with ops
725*14675a02SAndroid Build Coastguard Worker  // to read from these paths (file names can be specified in toolkit code).
726*14675a02SAndroid Build Coastguard Worker  oneof input_resource {
727*14675a02SAndroid Build Coastguard Worker    string input_dir_tensor_name = 1;
728*14675a02SAndroid Build Coastguard Worker    // A `map` field cannot be placed directly inside a `oneof`, so we have to
729*14675a02SAndroid Build Coastguard Worker    // wrap it in a new message.
730*14675a02SAndroid Build Coastguard Worker    MultipleInputResources multiple_input_resources = 3;
731*14675a02SAndroid Build Coastguard Worker  }
732*14675a02SAndroid Build Coastguard Worker
733*14675a02SAndroid Build Coastguard Worker  // Scalar string tensor name that will contain the output directory path.
734*14675a02SAndroid Build Coastguard Worker  //
735*14675a02SAndroid Build Coastguard Worker  // The provided directory should be considered temporary scratch that will be
736*14675a02SAndroid Build Coastguard Worker  // deleted, not persisted. It is the responsibility of the calling app to
737*14675a02SAndroid Build Coastguard Worker  // move the desired files to a permanent location once the client returns this
738*14675a02SAndroid Build Coastguard Worker  // directory back to the calling app.
739*14675a02SAndroid Build Coastguard Worker  string output_dir_tensor_name = 2;
740*14675a02SAndroid Build Coastguard Worker
741*14675a02SAndroid Build Coastguard Worker  // ===========================================================================
742*14675a02SAndroid Build Coastguard Worker  //   Outputs
743*14675a02SAndroid Build Coastguard Worker  // ===========================================================================
744*14675a02SAndroid Build Coastguard Worker  // NOTE: LocalCompute has no outputs other than what the client graph writes
745*14675a02SAndroid Build Coastguard Worker  // to the output directory fed via `output_dir_tensor_name` above.
746*14675a02SAndroid Build Coastguard Worker}
747*14675a02SAndroid Build Coastguard Worker
748*14675a02SAndroid Build Coastguard Worker// Wraps the input resource map for `LocalComputeIORouter`'s `oneof`.
749*14675a02SAndroid Build Coastguard Workermessage MultipleInputResources {
750*14675a02SAndroid Build Coastguard Worker  // The keys are the input resource names (defined by the users when
751*14675a02SAndroid Build Coastguard Worker  // constructing the `LocalComputation` Python object), and the values are the
752*14675a02SAndroid Build Coastguard Worker  // corresponding placeholder tensor names created by the local computation
753*14675a02SAndroid Build Coastguard Worker  // plan builder.
754*14675a02SAndroid Build Coastguard Worker  map<string, string> input_resource_tensor_name_map = 1;
755*14675a02SAndroid Build Coastguard Worker}
756*14675a02SAndroid Build Coastguard Worker
757*14675a02SAndroid Build Coastguard Worker// Describes a queue to which input is fed asynchronously.
758*14675a02SAndroid Build Coastguard Workermessage AsyncInputFeed {
759*14675a02SAndroid Build Coastguard Worker  // Name of the op for enqueuing an example input.
760*14675a02SAndroid Build Coastguard Worker  string enqueue_op = 1;
761*14675a02SAndroid Build Coastguard Worker
762*14675a02SAndroid Build Coastguard Worker  // Names of the input placeholders for the enqueue op.
763*14675a02SAndroid Build Coastguard Worker  repeated string enqueue_params = 2;
764*14675a02SAndroid Build Coastguard Worker
765*14675a02SAndroid Build Coastguard Worker  // Name of the op for closing the input queue.
766*14675a02SAndroid Build Coastguard Worker  string close_op = 3;
767*14675a02SAndroid Build Coastguard Worker
768*14675a02SAndroid Build Coastguard Worker  // Whether the values fed asynchronously are the data itself (true) or a
769*14675a02SAndroid Build Coastguard Worker  // description of where that data lives (false).
770*14675a02SAndroid Build Coastguard Worker  bool feed_values_are_data = 4;
771*14675a02SAndroid Build Coastguard Worker}
772*14675a02SAndroid Build Coastguard Worker
// Describes input provided through a tf.data.Dataset whose iterator is
// initialized by the op named in `initializer`.
773*14675a02SAndroid Build Coastguard Workermessage DatasetInput {
774*14675a02SAndroid Build Coastguard Worker  // Initializer of iterator corresponding to tf.data.Dataset object which
775*14675a02SAndroid Build Coastguard Worker  // handles the input data. Stores the name of an op in the graph.
776*14675a02SAndroid Build Coastguard Worker  string initializer = 1;
777*14675a02SAndroid Build Coastguard Worker
778*14675a02SAndroid Build Coastguard Worker  // Names of the placeholders necessary to initialize the dataset.
779*14675a02SAndroid Build Coastguard Worker  DatasetInputPlaceholders placeholders = 2;
780*14675a02SAndroid Build Coastguard Worker
781*14675a02SAndroid Build Coastguard Worker  // Batch size to be used in tf.data.Dataset.
  // NOTE(review): presumably this value is fed to the placeholder named by
  // `placeholders.batch_size` — confirm against the plan consumer.
782*14675a02SAndroid Build Coastguard Worker  int32 batch_size = 3;
783*14675a02SAndroid Build Coastguard Worker}
784*14675a02SAndroid Build Coastguard Worker
// Names of the placeholder tensors that are necessary to initialize the
// dataset described by a `DatasetInput`.
785*14675a02SAndroid Build Coastguard Workermessage DatasetInputPlaceholders {
786*14675a02SAndroid Build Coastguard Worker  // Name of placeholder corresponding to filename(s) of SSTable(s) to read data
787*14675a02SAndroid Build Coastguard Worker  // from.
788*14675a02SAndroid Build Coastguard Worker  string filename = 1;
789*14675a02SAndroid Build Coastguard Worker
790*14675a02SAndroid Build Coastguard Worker  // Name of placeholder corresponding to key_prefix initializing the
791*14675a02SAndroid Build Coastguard Worker  // SSTableDataset. Note that the value fed should be a unique user ID, not a prefix.
792*14675a02SAndroid Build Coastguard Worker  string key_prefix = 2;
793*14675a02SAndroid Build Coastguard Worker
794*14675a02SAndroid Build Coastguard Worker  // Name of placeholder corresponding to the number of epochs the local
795*14675a02SAndroid Build Coastguard Worker  // training should be run for.
796*14675a02SAndroid Build Coastguard Worker  string num_epochs = 3;
797*14675a02SAndroid Build Coastguard Worker
798*14675a02SAndroid Build Coastguard Worker  // Name of placeholder corresponding to batch size.
799*14675a02SAndroid Build Coastguard Worker  string batch_size = 4;
800*14675a02SAndroid Build Coastguard Worker}
801*14675a02SAndroid Build Coastguard Worker
802*14675a02SAndroid Build Coastguard Worker// Specifies an example selection procedure.
803*14675a02SAndroid Build Coastguard Workermessage ExampleSelector {
804*14675a02SAndroid Build Coastguard Worker  // Selection criteria following a contract agreed upon between client and
805*14675a02SAndroid Build Coastguard Worker  // model designers.
806*14675a02SAndroid Build Coastguard Worker  google.protobuf.Any criteria = 1;
807*14675a02SAndroid Build Coastguard Worker
808*14675a02SAndroid Build Coastguard Worker  // A URI identifying the example collection to read from. Format should adhere
809*14675a02SAndroid Build Coastguard Worker  // to "${COLLECTION}://${APP_NAME}${COLLECTION_NAME}". The URI segments
810*14675a02SAndroid Build Coastguard Worker  // should adhere to the following rules:
811*14675a02SAndroid Build Coastguard Worker  // - The scheme ${COLLECTION} should be one of:
812*14675a02SAndroid Build Coastguard Worker  //   - "app" for app-hosted example stores
813*14675a02SAndroid Build Coastguard Worker  //   - "simulation" for collections not connected to an app (e.g., if used
814*14675a02SAndroid Build Coastguard Worker  //     purely for simulation)
815*14675a02SAndroid Build Coastguard Worker  // - The authority ${APP_NAME} identifies the owner of the example
816*14675a02SAndroid Build Coastguard Worker  //   collection and should be either the app's package name, or be left empty
817*14675a02SAndroid Build Coastguard Worker  //   (which means "the current app package name").
818*14675a02SAndroid Build Coastguard Worker  // - The path ${COLLECTION_NAME} can be any valid URI path. Note that it starts
819*14675a02SAndroid Build Coastguard Worker  //   with a forward slash ("/").
820*14675a02SAndroid Build Coastguard Worker  // - The query and fragment are currently not used, but they may become used
821*14675a02SAndroid Build Coastguard Worker  //   for something in the future. To keep open that possibility they must
822*14675a02SAndroid Build Coastguard Worker  //   currently be left empty.
823*14675a02SAndroid Build Coastguard Worker  //
824*14675a02SAndroid Build Coastguard Worker  // Example: "app://com.google.some.app/someCollection/name"
825*14675a02SAndroid Build Coastguard Worker  // identifies the collection "/someCollection/name" owned and hosted by the
826*14675a02SAndroid Build Coastguard Worker  // app with package name "com.google.some.app".
827*14675a02SAndroid Build Coastguard Worker  //
828*14675a02SAndroid Build Coastguard Worker  // Example: "app:/someCollection/name" or "app:///someCollection/name"
829*14675a02SAndroid Build Coastguard Worker  // both identify the collection "/someCollection/name" owned and hosted by the
830*14675a02SAndroid Build Coastguard Worker  // app associated with the training job in which this URI appears.
831*14675a02SAndroid Build Coastguard Worker  //
832*14675a02SAndroid Build Coastguard Worker  // The path will not be interpreted by the runtime, and will be passed to the
833*14675a02SAndroid Build Coastguard Worker  // example collection implementation for interpretation. Thus, in the case of
834*14675a02SAndroid Build Coastguard Worker  // app-hosted example stores, the path segment's interpretation is a contract
835*14675a02SAndroid Build Coastguard Worker  // between the app's example store developers and the app's model designers.
836*14675a02SAndroid Build Coastguard Worker  //
837*14675a02SAndroid Build Coastguard Worker  // If an `app://` URI is set, then the `TrainerOptions` collection name must
838*14675a02SAndroid Build Coastguard Worker  // not be set.
839*14675a02SAndroid Build Coastguard Worker  string collection_uri = 2;
840*14675a02SAndroid Build Coastguard Worker
841*14675a02SAndroid Build Coastguard Worker  // Resumption token following a contract agreed upon between client and
842*14675a02SAndroid Build Coastguard Worker  // model designers.
843*14675a02SAndroid Build Coastguard Worker  google.protobuf.Any resumption_token = 3;
844*14675a02SAndroid Build Coastguard Worker}
845*14675a02SAndroid Build Coastguard Worker
846*14675a02SAndroid Build Coastguard Worker// Selector for slices to fetch as part of a `federated_select` operation.
847*14675a02SAndroid Build Coastguard Workermessage SlicesSelector {
848*14675a02SAndroid Build Coastguard Worker  // The string ID under which the slices are served.
849*14675a02SAndroid Build Coastguard Worker  //
850*14675a02SAndroid Build Coastguard Worker  // This value must have been returned by a previous call to the `serve_slices`
851*14675a02SAndroid Build Coastguard Worker  // op run during the `write_client_init` operation.
852*14675a02SAndroid Build Coastguard Worker  string served_at_id = 1;
853*14675a02SAndroid Build Coastguard Worker
854*14675a02SAndroid Build Coastguard Worker  // The keys (indices) of the slices to fetch.
855*14675a02SAndroid Build Coastguard Worker  repeated int32 keys = 2;
856*14675a02SAndroid Build Coastguard Worker}
857*14675a02SAndroid Build Coastguard Worker
858*14675a02SAndroid Build Coastguard Worker// Represents slice data to be served as part of a `federated_select` operation.
859*14675a02SAndroid Build Coastguard Worker// This is used for testing.
860*14675a02SAndroid Build Coastguard Workermessage SlicesTestDataset {
861*14675a02SAndroid Build Coastguard Worker  // The test data to use. The keys map to the `SlicesSelector.served_at_id`
862*14675a02SAndroid Build Coastguard Worker  // field. E.g. test data for the slice with `served_at_id`="foo" and key 2
863*14675a02SAndroid Build Coastguard Worker  // would be stored in `dataset["foo"].slice_data[2]`.
864*14675a02SAndroid Build Coastguard Worker  map<string, SlicesTestData> dataset = 1;
865*14675a02SAndroid Build Coastguard Worker}
// Test slice data served under a single `served_at_id`; used as the value type
// of `SlicesTestDataset.dataset`.
866*14675a02SAndroid Build Coastguard Workermessage SlicesTestData {
867*14675a02SAndroid Build Coastguard Worker  // The test slice data to serve. Each entry's index corresponds to the slice
868*14675a02SAndroid Build Coastguard Worker  // key it is the test data for.
869*14675a02SAndroid Build Coastguard Worker  repeated bytes slice_data = 2;
870*14675a02SAndroid Build Coastguard Worker}
871*14675a02SAndroid Build Coastguard Worker
872*14675a02SAndroid Build Coastguard Worker// Server Phase V2
873*14675a02SAndroid Build Coastguard Worker// ===============
874*14675a02SAndroid Build Coastguard Worker
875*14675a02SAndroid Build Coastguard Worker// Represents a server phase with three distinct components: pre-broadcast,
876*14675a02SAndroid Build Coastguard Worker// aggregation, and post-aggregation.
877*14675a02SAndroid Build Coastguard Worker//
878*14675a02SAndroid Build Coastguard Worker// The pre-broadcast and post-aggregation components are described with
879*14675a02SAndroid Build Coastguard Worker// the tensorflow_spec_prepare and tensorflow_spec_result TensorflowSpec
880*14675a02SAndroid Build Coastguard Worker// messages, respectively. These messages in combination with the server
881*14675a02SAndroid Build Coastguard Worker// IORouter messages specify how to set up a single TF sess.run call for each
882*14675a02SAndroid Build Coastguard Worker// component.
883*14675a02SAndroid Build Coastguard Worker//
884*14675a02SAndroid Build Coastguard Worker// The pre-broadcast logic is obtained by transforming the server_prepare TFF
885*14675a02SAndroid Build Coastguard Worker// computation in the DistributeAggregateForm. It takes the server state as
886*14675a02SAndroid Build Coastguard Worker// input, and it generates the checkpoint to broadcast to the clients and
887*14675a02SAndroid Build Coastguard Worker// potentially an intermediate server state. The intermediate server state may
888*14675a02SAndroid Build Coastguard Worker// be used by the aggregation and post-aggregation logic.
889*14675a02SAndroid Build Coastguard Worker//
890*14675a02SAndroid Build Coastguard Worker// The aggregation logic represents the aggregation of client results at the
891*14675a02SAndroid Build Coastguard Worker// server and is described using a list of ServerAggregationConfig messages.
892*14675a02SAndroid Build Coastguard Worker// Each ServerAggregationConfig message describes a single aggregation operation
893*14675a02SAndroid Build Coastguard Worker// on a set of input/output tensors. The input tensors may represent parts of
894*14675a02SAndroid Build Coastguard Worker// either the client results or the intermediate server state. These messages
895*14675a02SAndroid Build Coastguard Worker// are obtained by transforming the client_to_server_aggregation TFF computation
896*14675a02SAndroid Build Coastguard Worker// in the DistributeAggregateForm.
897*14675a02SAndroid Build Coastguard Worker//
898*14675a02SAndroid Build Coastguard Worker// The post-aggregation logic is obtained by transforming the server_result TFF
899*14675a02SAndroid Build Coastguard Worker// computation in the DistributeAggregateForm. It takes the intermediate server
900*14675a02SAndroid Build Coastguard Worker// state and the aggregated client results as input, and it generates the new
901*14675a02SAndroid Build Coastguard Worker// server state and potentially other server-side output.
902*14675a02SAndroid Build Coastguard Worker//
903*14675a02SAndroid Build Coastguard Worker// Note that while a ServerPhaseV2 message can be generated for all types of
904*14675a02SAndroid Build Coastguard Worker// intrinsics, it is currently only compatible with the ClientPhase message if
905*14675a02SAndroid Build Coastguard Worker// the aggregations being used are exclusively federated_sum (not SecAgg). If
906*14675a02SAndroid Build Coastguard Worker// this compatibility requirement is satisfied, it is also valid to run the
907*14675a02SAndroid Build Coastguard Worker// aggregation portion of this ServerPhaseV2 message alongside the pre- and
908*14675a02SAndroid Build Coastguard Worker// post-aggregation logic from the original ServerPhase message. Ultimately,
909*14675a02SAndroid Build Coastguard Worker// we expect the full ServerPhaseV2 message to be run and the ServerPhase
910*14675a02SAndroid Build Coastguard Worker// message to be deprecated.
911*14675a02SAndroid Build Coastguard Workermessage ServerPhaseV2 {
912*14675a02SAndroid Build Coastguard Worker  // A short CamelCase name for the ServerPhaseV2.
913*14675a02SAndroid Build Coastguard Worker  string name = 1;
914*14675a02SAndroid Build Coastguard Worker
915*14675a02SAndroid Build Coastguard Worker  // A functional interface for the TensorFlow logic the server should perform
916*14675a02SAndroid Build Coastguard Worker  // prior to the server-to-client broadcast. This should be used with the
917*14675a02SAndroid Build Coastguard Worker  // TensorFlow graph defined in server_graph_prepare_bytes.
918*14675a02SAndroid Build Coastguard Worker  TensorflowSpec tensorflow_spec_prepare = 3;
919*14675a02SAndroid Build Coastguard Worker
920*14675a02SAndroid Build Coastguard Worker  // The specification of inputs and outputs of the server_prepare TF logic.
921*14675a02SAndroid Build Coastguard Worker  oneof server_prepare_io_router {
922*14675a02SAndroid Build Coastguard Worker    ServerPrepareIORouter prepare_router = 4;
923*14675a02SAndroid Build Coastguard Worker  }
924*14675a02SAndroid Build Coastguard Worker
925*14675a02SAndroid Build Coastguard Worker  // A list of client-to-server aggregations to perform.
926*14675a02SAndroid Build Coastguard Worker  repeated ServerAggregationConfig aggregations = 2;
927*14675a02SAndroid Build Coastguard Worker
928*14675a02SAndroid Build Coastguard Worker  // A functional interface for the TensorFlow logic the server should perform
929*14675a02SAndroid Build Coastguard Worker  // post-aggregation. This should be used with the TensorFlow graph defined
930*14675a02SAndroid Build Coastguard Worker  // in server_graph_result_bytes.
931*14675a02SAndroid Build Coastguard Worker  TensorflowSpec tensorflow_spec_result = 5;
932*14675a02SAndroid Build Coastguard Worker
933*14675a02SAndroid Build Coastguard Worker  // The specification of inputs and outputs needed by the server_result TF
934*14675a02SAndroid Build Coastguard Worker  // logic.
935*14675a02SAndroid Build Coastguard Worker  oneof server_result_io_router {
936*14675a02SAndroid Build Coastguard Worker    ServerResultIORouter result_router = 6;
937*14675a02SAndroid Build Coastguard Worker  }
938*14675a02SAndroid Build Coastguard Worker}
939*14675a02SAndroid Build Coastguard Worker
940*14675a02SAndroid Build Coastguard Worker// Routing of checkpoint filepaths into and out of the server_prepare graph.
941*14675a02SAndroid Build Coastguard Workermessage ServerPrepareIORouter {
942*14675a02SAndroid Build Coastguard Worker  // The name of the scalar string tensor in the server_prepare TF graph that
943*14675a02SAndroid Build Coastguard Worker  // is fed the filepath to the initial server state checkpoint. The
944*14675a02SAndroid Build Coastguard Worker  // server_prepare logic reads from this filepath.
945*14675a02SAndroid Build Coastguard Worker  string prepare_server_state_input_filepath_tensor_name = 1;
946*14675a02SAndroid Build Coastguard Worker
947*14675a02SAndroid Build Coastguard Worker  // The name of the scalar string tensor in the server_prepare TF graph that
948*14675a02SAndroid Build Coastguard Worker  // is fed the filepath where the client checkpoint should be stored. The
949*14675a02SAndroid Build Coastguard Worker  // server_prepare logic writes to this filepath.
950*14675a02SAndroid Build Coastguard Worker  string prepare_output_filepath_tensor_name = 2;
951*14675a02SAndroid Build Coastguard Worker
952*14675a02SAndroid Build Coastguard Worker  // The name of the scalar string tensor in the server_prepare TF graph that
953*14675a02SAndroid Build Coastguard Worker  // is fed the filepath where the intermediate state checkpoint should be
954*14675a02SAndroid Build Coastguard Worker  // stored. The server_prepare logic writes to this filepath. The intermediate
955*14675a02SAndroid Build Coastguard Worker  // state checkpoint will be consumed by both the logic used to set parameters
956*14675a02SAndroid Build Coastguard Worker  // for aggregation and the post-aggregation logic.
957*14675a02SAndroid Build Coastguard Worker  string prepare_intermediate_state_output_filepath_tensor_name = 3;
958*14675a02SAndroid Build Coastguard Worker}
959*14675a02SAndroid Build Coastguard Worker
960*14675a02SAndroid Build Coastguard Worker// Routing of checkpoint filepaths into and out of the server_result graph.
961*14675a02SAndroid Build Coastguard Workermessage ServerResultIORouter {
962*14675a02SAndroid Build Coastguard Worker  // The name of the scalar string tensor in the server_result TF graph that is
963*14675a02SAndroid Build Coastguard Worker  // fed the filepath to the intermediate state checkpoint. The server_result
964*14675a02SAndroid Build Coastguard Worker  // logic reads from this filepath.
965*14675a02SAndroid Build Coastguard Worker  string result_intermediate_state_input_filepath_tensor_name = 1;
966*14675a02SAndroid Build Coastguard Worker
967*14675a02SAndroid Build Coastguard Worker  // The name of the scalar string tensor in the server_result TF graph that is
968*14675a02SAndroid Build Coastguard Worker  // fed the filepath to the aggregated client result checkpoint. The
969*14675a02SAndroid Build Coastguard Worker  // server_result logic reads from this filepath.
970*14675a02SAndroid Build Coastguard Worker  string result_aggregate_result_input_filepath_tensor_name = 2;
971*14675a02SAndroid Build Coastguard Worker
972*14675a02SAndroid Build Coastguard Worker  // The name of the scalar string tensor in the server_result TF graph that is
973*14675a02SAndroid Build Coastguard Worker  // fed the filepath where the updated server state should be stored. The
974*14675a02SAndroid Build Coastguard Worker  // server_result logic writes to this filepath.
975*14675a02SAndroid Build Coastguard Worker  string result_server_state_output_filepath_tensor_name = 3;
976*14675a02SAndroid Build Coastguard Worker}
977*14675a02SAndroid Build Coastguard Worker
978*14675a02SAndroid Build Coastguard Worker// Represents a single aggregation operation, combining one or more input
979*14675a02SAndroid Build Coastguard Worker// tensors from a collection of clients into one or more output tensors on the
980*14675a02SAndroid Build Coastguard Worker// server.
981*14675a02SAndroid Build Coastguard Workermessage ServerAggregationConfig {
982*14675a02SAndroid Build Coastguard Worker  // The URI of the aggregation intrinsic (e.g. 'federated_sum').
983*14675a02SAndroid Build Coastguard Worker  string intrinsic_uri = 1;
984*14675a02SAndroid Build Coastguard Worker
985*14675a02SAndroid Build Coastguard Worker  // Describes one argument to the aggregation operation, which refers to either
  // a tensor in each client's checkpoint or a tensor in the intermediate server
  // state checkpoint.
986*14675a02SAndroid Build Coastguard Worker  message IntrinsicArg {
987*14675a02SAndroid Build Coastguard Worker    oneof arg {
988*14675a02SAndroid Build Coastguard Worker      // Refers to a tensor within the checkpoint provided by each client.
989*14675a02SAndroid Build Coastguard Worker      tensorflow.TensorSpecProto input_tensor = 2;
990*14675a02SAndroid Build Coastguard Worker
991*14675a02SAndroid Build Coastguard Worker      // Refers to a tensor within the intermediate server state checkpoint.
992*14675a02SAndroid Build Coastguard Worker      tensorflow.TensorSpecProto state_tensor = 3;
993*14675a02SAndroid Build Coastguard Worker    }
994*14675a02SAndroid Build Coastguard Worker  }
995*14675a02SAndroid Build Coastguard Worker
996*14675a02SAndroid Build Coastguard Worker  // List of arguments for the aggregation operation. The arguments can be
997*14675a02SAndroid Build Coastguard Worker  // dependent on client data (in which case they must be retrieved from
998*14675a02SAndroid Build Coastguard Worker  // clients) or they can be independent of client data (in which case they
999*14675a02SAndroid Build Coastguard Worker  // can be configured server-side). For now we assume all client-independent
1000*14675a02SAndroid Build Coastguard Worker  // arguments are constants. The arguments must be in the order expected by
1001*14675a02SAndroid Build Coastguard Worker  // the server.
1002*14675a02SAndroid Build Coastguard Worker  repeated IntrinsicArg intrinsic_args = 4;
1003*14675a02SAndroid Build Coastguard Worker
1004*14675a02SAndroid Build Coastguard Worker  // List of server-side outputs produced by the aggregation operation.
1005*14675a02SAndroid Build Coastguard Worker  repeated tensorflow.TensorSpecProto output_tensors = 5;
1006*14675a02SAndroid Build Coastguard Worker
1007*14675a02SAndroid Build Coastguard Worker  // List of inner aggregation intrinsics. This can be used to delegate parts
1008*14675a02SAndroid Build Coastguard Worker  // of the aggregation logic (e.g. a groupby intrinsic may want to delegate
1009*14675a02SAndroid Build Coastguard Worker  // a sum operation to a sum intrinsic).
1010*14675a02SAndroid Build Coastguard Worker  repeated ServerAggregationConfig inner_aggregations = 6;
1011*14675a02SAndroid Build Coastguard Worker}
1012*14675a02SAndroid Build Coastguard Worker
1013*14675a02SAndroid Build Coastguard Worker// Server Phase
1014*14675a02SAndroid Build Coastguard Worker// ============
1015*14675a02SAndroid Build Coastguard Worker
1016*14675a02SAndroid Build Coastguard Worker// Represents a server phase which implements TF-based aggregation of multiple
1017*14675a02SAndroid Build Coastguard Worker// client updates.
1018*14675a02SAndroid Build Coastguard Worker//
1019*14675a02SAndroid Build Coastguard Worker// There are two different modes of aggregation that are described
1020*14675a02SAndroid Build Coastguard Worker// by the values in this message. The first is aggregation that is
1021*14675a02SAndroid Build Coastguard Worker// coming from coordinated sets of clients. This includes aggregation
1022*14675a02SAndroid Build Coastguard Worker// done via checkpoints from clients or aggregation done over a set
1023*14675a02SAndroid Build Coastguard Worker// of clients by a process like secure aggregation. The results of
1024*14675a02SAndroid Build Coastguard Worker// this first aggregation are saved to intermediate aggregation
1025*14675a02SAndroid Build Coastguard Worker// checkpoints. The second aggregation then comes from taking
1026*14675a02SAndroid Build Coastguard Worker// these intermediate checkpoints and aggregating over them.
1027*14675a02SAndroid Build Coastguard Worker//
1028*14675a02SAndroid Build Coastguard Worker// These two different modes of aggregation are done on different
1029*14675a02SAndroid Build Coastguard Worker// servers, the first in the 'L1' servers and the second in the
1030*14675a02SAndroid Build Coastguard Worker// 'L2' servers, so we use this nomenclature to describe these
1031*14675a02SAndroid Build Coastguard Worker// phases below.
1032*14675a02SAndroid Build Coastguard Worker//
1033*14675a02SAndroid Build Coastguard Worker// The ServerPhase message is currently in the process of being replaced by the
1034*14675a02SAndroid Build Coastguard Worker// ServerPhaseV2 message as we switch the plan building pipeline to use
1035*14675a02SAndroid Build Coastguard Worker// DistributeAggregateForm instead of MapReduceForm. During the migration
1036*14675a02SAndroid Build Coastguard Worker// process, we may generate both messages and use components from either
1037*14675a02SAndroid Build Coastguard Worker// message during execution.
1038*14675a02SAndroid Build Coastguard Worker//
1039*14675a02SAndroid Build Coastguard Workermessage ServerPhase {
1040*14675a02SAndroid Build Coastguard Worker  // A short CamelCase name for the ServerPhase.
1041*14675a02SAndroid Build Coastguard Worker  string name = 8;
1042*14675a02SAndroid Build Coastguard Worker
1043*14675a02SAndroid Build Coastguard Worker  // ===========================================================================
1044*14675a02SAndroid Build Coastguard Worker  // L1 "Intermediate" Aggregation.
1045*14675a02SAndroid Build Coastguard Worker  //
1046*14675a02SAndroid Build Coastguard Worker  // This is the initial aggregation that creates partial aggregates from client
1047*14675a02SAndroid Build Coastguard Worker  // results. L1 Aggregation may be run on many different instances.
1048*14675a02SAndroid Build Coastguard Worker  //
1049*14675a02SAndroid Build Coastguard Worker  // Pre-condition:
1050*14675a02SAndroid Build Coastguard Worker  //   The execution environment has loaded the graph from `server_graph_bytes`.
1051*14675a02SAndroid Build Coastguard Worker
1052*14675a02SAndroid Build Coastguard Worker  // 1. Initialize the phase.
1053*14675a02SAndroid Build Coastguard Worker  //
1054*14675a02SAndroid Build Coastguard Worker  // Operation to run before the first aggregation happens.
1055*14675a02SAndroid Build Coastguard Worker  // For instance, clears the accumulators so that a new aggregation can begin.
1056*14675a02SAndroid Build Coastguard Worker  string phase_init_op = 1;
1057*14675a02SAndroid Build Coastguard Worker
1058*14675a02SAndroid Build Coastguard Worker  // 2. For each client in set of clients:
1059*14675a02SAndroid Build Coastguard Worker  //   a. Restore variables from the client checkpoint.
1060*14675a02SAndroid Build Coastguard Worker  //
1061*14675a02SAndroid Build Coastguard Worker  // Loads a checkpoint from a single client written via
1062*14675a02SAndroid Build Coastguard Worker  // `FederatedComputeIORouter.output_filepath_tensor_name`.  This is done once
1063*14675a02SAndroid Build Coastguard Worker  // for every client checkpoint in a round.
1064*14675a02SAndroid Build Coastguard Worker  CheckpointOp read_update = 3;
1065*14675a02SAndroid Build Coastguard Worker  //   b. Aggregate the data coming from the client checkpoint.
1066*14675a02SAndroid Build Coastguard Worker  //
1067*14675a02SAndroid Build Coastguard Worker  // An operation that aggregates the data from read_update.
1068*14675a02SAndroid Build Coastguard Worker  // Generally this will add to accumulators and it may leverage internal data
1069*14675a02SAndroid Build Coastguard Worker  // inside the graph to adjust the weights of the Tensors.
1070*14675a02SAndroid Build Coastguard Worker  //
1071*14675a02SAndroid Build Coastguard Worker  // Executed once for each `read_update`, to (for example) update accumulator
1072*14675a02SAndroid Build Coastguard Worker  // variables using the values loaded during `read_update`.
1073*14675a02SAndroid Build Coastguard Worker  string aggregate_into_accumulators_op = 4;
1074*14675a02SAndroid Build Coastguard Worker
1075*14675a02SAndroid Build Coastguard Worker  // 3. After all clients have been aggregated, possibly restore
1076*14675a02SAndroid Build Coastguard Worker  //    variables that have been aggregated via a separate process.
1077*14675a02SAndroid Build Coastguard Worker  //
1078*14675a02SAndroid Build Coastguard Worker  // Optionally restores variables where aggregation is done across
1079*14675a02SAndroid Build Coastguard Worker  // an entire round of client data updates. In contrast to `read_update`,
1080*14675a02SAndroid Build Coastguard Worker  // which restores once per client, this occurs after all clients
1081*14675a02SAndroid Build Coastguard Worker  // in a round have been processed. This allows, for example, side
1082*14675a02SAndroid Build Coastguard Worker  // channels where aggregation is done by a separate process (such
1083*14675a02SAndroid Build Coastguard Worker  // as in secure aggregation), in which the side channel aggregated
1084*14675a02SAndroid Build Coastguard Worker  // tensor is passed to the `before_restore_op` which ensure the
1085*14675a02SAndroid Build Coastguard Worker  // variables are restored properly. The `after_restore_op` will then
1086*14675a02SAndroid Build Coastguard Worker  // be responsible for performing the accumulation.
1087*14675a02SAndroid Build Coastguard Worker  //
1088*14675a02SAndroid Build Coastguard Worker  // Note that in current use this should not have a SaverDef, but
1089*14675a02SAndroid Build Coastguard Worker  // should only be used for side channels.
1090*14675a02SAndroid Build Coastguard Worker  CheckpointOp read_aggregated_update = 10;
1091*14675a02SAndroid Build Coastguard Worker
1092*14675a02SAndroid Build Coastguard Worker  // 4. Write the aggregated variables to an intermediate checkpoint.
1093*14675a02SAndroid Build Coastguard Worker  //
1094*14675a02SAndroid Build Coastguard Worker  // We require that `aggregate_into_accumulators_op` is associative and
1095*14675a02SAndroid Build Coastguard Worker  // commutative, so that the aggregates can be computed across
1096*14675a02SAndroid Build Coastguard Worker  // multiple TensorFlow sessions.
1097*14675a02SAndroid Build Coastguard Worker  // As an example, say we are computing the sum of 5 client updates:
1098*14675a02SAndroid Build Coastguard Worker  //    A = X1 + X2 + X3 + X4 + X5
1099*14675a02SAndroid Build Coastguard Worker  // We can always do this in one session by calling `read_update`j and
1100*14675a02SAndroid Build Coastguard Worker  // `aggregate_into_accumulators_op` once for each client checkpoint.
1101*14675a02SAndroid Build Coastguard Worker  //
1102*14675a02SAndroid Build Coastguard Worker  // Alternatively, we could compute:
1103*14675a02SAndroid Build Coastguard Worker  //    A1 = X1 + X2 in one TensorFlow session, and
1104*14675a02SAndroid Build Coastguard Worker  //    A2 = X3 + X4 + X5 in a different session.
1105*14675a02SAndroid Build Coastguard Worker  // Each of these sessions can then write their accumulator state
1106*14675a02SAndroid Build Coastguard Worker  // with the `write_intermediate_update` CheckpointOp, and a yet another third
1107*14675a02SAndroid Build Coastguard Worker  // session can then call `read_intermediate_update` and
1108*14675a02SAndroid Build Coastguard Worker  // `aggregate_into_accumulators_op` on each of these checkpoints to compute:
1109*14675a02SAndroid Build Coastguard Worker  //    A = A1 + A2 = (X1 + X2) + (X3 + X4 + X5).
1110*14675a02SAndroid Build Coastguard Worker  CheckpointOp write_intermediate_update = 7;
1111*14675a02SAndroid Build Coastguard Worker  // End L1 "Intermediate" Aggregation.
1112*14675a02SAndroid Build Coastguard Worker  // ===========================================================================
1113*14675a02SAndroid Build Coastguard Worker
1114*14675a02SAndroid Build Coastguard Worker  // ===========================================================================
1115*14675a02SAndroid Build Coastguard Worker  // L2 Aggregation and Coordinator.
1116*14675a02SAndroid Build Coastguard Worker  //
1117*14675a02SAndroid Build Coastguard Worker  // This aggregates intermediate checkpoints from L1 Aggregation and performs
1118*14675a02SAndroid Build Coastguard Worker  // the finalizing of the update. Unlike L1, there will only be one instance
1119*14675a02SAndroid Build Coastguard Worker  // that does this aggregation.
1120*14675a02SAndroid Build Coastguard Worker
1121*14675a02SAndroid Build Coastguard Worker  // Pre-condition:
1122*14675a02SAndroid Build Coastguard Worker  //   The execution environment has loaded the graph from `server_graph_bytes`
1123*14675a02SAndroid Build Coastguard Worker  //   and restored the global model using `server_savepoint` from the parent
1124*14675a02SAndroid Build Coastguard Worker  //   `Plan` message.
1125*14675a02SAndroid Build Coastguard Worker
1126*14675a02SAndroid Build Coastguard Worker  // 1. Initialize the phase.
1127*14675a02SAndroid Build Coastguard Worker  //
1128*14675a02SAndroid Build Coastguard Worker  // This currently re-uses the `phase_init_op` from L1 aggregation above.
1129*14675a02SAndroid Build Coastguard Worker
1130*14675a02SAndroid Build Coastguard Worker  // 2. Write a checkpoint that can be sent to the client.
1131*14675a02SAndroid Build Coastguard Worker  //
1132*14675a02SAndroid Build Coastguard Worker  // Generates a checkpoint to be sent to the client, to be read by
1133*14675a02SAndroid Build Coastguard Worker  // `FederatedComputeIORouter.input_filepath_tensor_name`.
1134*14675a02SAndroid Build Coastguard Worker
1135*14675a02SAndroid Build Coastguard Worker  CheckpointOp write_client_init = 2;
1136*14675a02SAndroid Build Coastguard Worker
1137*14675a02SAndroid Build Coastguard Worker  // 3. For each intermediate checkpoint:
1138*14675a02SAndroid Build Coastguard Worker  //   a. Restore variables from the intermediate checkpoint.
1139*14675a02SAndroid Build Coastguard Worker  //
1140*14675a02SAndroid Build Coastguard Worker  // The read CheckpointOp corresponding to `write_intermediate_update`.
1141*14675a02SAndroid Build Coastguard Worker  // This is used instead of `read_update` for intermediate checkpoints because
1142*14675a02SAndroid Build Coastguard Worker  // the format of these updates may differ from that of the updates
1143*14675a02SAndroid Build Coastguard Worker  // from clients (which may, for example, be compressed).
1144*14675a02SAndroid Build Coastguard Worker  CheckpointOp read_intermediate_update = 9;
1145*14675a02SAndroid Build Coastguard Worker  //   b. Aggregate the data coming from the intermediate checkpoint.
1146*14675a02SAndroid Build Coastguard Worker  //
1147*14675a02SAndroid Build Coastguard Worker  // An operation that aggregates the data from `read_intermediate_update`.
1148*14675a02SAndroid Build Coastguard Worker  // Generally this will add to accumulators and it may leverage internal data
1149*14675a02SAndroid Build Coastguard Worker  // inside the graph to adjust the weights of the Tensors.
1150*14675a02SAndroid Build Coastguard Worker  string intermediate_aggregate_into_accumulators_op = 11;
1151*14675a02SAndroid Build Coastguard Worker
1152*14675a02SAndroid Build Coastguard Worker  // 4. Write the aggregated intermediate variables to a checkpoint.
1153*14675a02SAndroid Build Coastguard Worker  //
1154*14675a02SAndroid Build Coastguard Worker  // This is used for downstream, cross-round aggregation of metrics.
1155*14675a02SAndroid Build Coastguard Worker  // These variables will be read back into a session with
1156*14675a02SAndroid Build Coastguard Worker  // read_intermediate_update.
1157*14675a02SAndroid Build Coastguard Worker  //
1158*14675a02SAndroid Build Coastguard Worker  // Tasks which do not use FL metrics may unset the CheckpointOp.saver_def
1159*14675a02SAndroid Build Coastguard Worker  // to disable writing accumulator checkpoints.
1160*14675a02SAndroid Build Coastguard Worker  CheckpointOp write_accumulators = 12;
1161*14675a02SAndroid Build Coastguard Worker
1162*14675a02SAndroid Build Coastguard Worker  // 5. Finalize the round.
1163*14675a02SAndroid Build Coastguard Worker  //
1164*14675a02SAndroid Build Coastguard Worker  // This can include:
1165*14675a02SAndroid Build Coastguard Worker  // - Applying the update aggregated from the intermediate checkpoints to the
1166*14675a02SAndroid Build Coastguard Worker  //   global model and other updates to cross-round state variables.
1167*14675a02SAndroid Build Coastguard Worker  // - Computing final round metric values (e.g. the `report` of a
1168*14675a02SAndroid Build Coastguard Worker  //   `tff.federated_aggregate`).
1169*14675a02SAndroid Build Coastguard Worker  string apply_aggregrated_updates_op = 5;
1170*14675a02SAndroid Build Coastguard Worker
1171*14675a02SAndroid Build Coastguard Worker  // 6. Fetch the server aggregated metrics.
1172*14675a02SAndroid Build Coastguard Worker  //
1173*14675a02SAndroid Build Coastguard Worker  // A list of names of metric variables to fetch from the TensorFlow session.
1174*14675a02SAndroid Build Coastguard Worker  repeated Metric metrics = 6;
1175*14675a02SAndroid Build Coastguard Worker
1176*14675a02SAndroid Build Coastguard Worker  // 7. Serialize the updated server state (e.g. the coefficients of the global
1177*14675a02SAndroid Build Coastguard Worker  //    model in FL) using `server_savepoint` in the parent `Plan` message.
1178*14675a02SAndroid Build Coastguard Worker
1179*14675a02SAndroid Build Coastguard Worker  // End L2 Aggregation.
1180*14675a02SAndroid Build Coastguard Worker  // ===========================================================================
1181*14675a02SAndroid Build Coastguard Worker}
1182*14675a02SAndroid Build Coastguard Worker
1183*14675a02SAndroid Build Coastguard Worker// Represents the server phase in an eligibility computation.
1184*14675a02SAndroid Build Coastguard Worker//
1185*14675a02SAndroid Build Coastguard Worker// This phase produces a checkpoint to be sent to clients. This checkpoint is
1186*14675a02SAndroid Build Coastguard Worker// then used as an input to the clients' task eligibility computations.
1187*14675a02SAndroid Build Coastguard Worker// This phase *does not include any aggregation.* It is carried in
// `Plan.Phase.server_eligibility_phase`, which is only set for eligibility
// tasks.
1188*14675a02SAndroid Build Coastguard Workermessage ServerEligibilityComputationPhase {
1189*14675a02SAndroid Build Coastguard Worker  // A short CamelCase name for the ServerEligibilityComputationPhase.
1190*14675a02SAndroid Build Coastguard Worker  string name = 1;
1191*14675a02SAndroid Build Coastguard Worker
1192*14675a02SAndroid Build Coastguard Worker  // The names of the TensorFlow nodes to run in order to produce the phase's
  // output.
1193*14675a02SAndroid Build Coastguard Worker  repeated string target_node_names = 2;
1194*14675a02SAndroid Build Coastguard Worker
1195*14675a02SAndroid Build Coastguard Worker  // The specification of inputs and outputs to the TensorFlow graph. Being a
  // `oneof`, at most one router variant can be set at a time.
1196*14675a02SAndroid Build Coastguard Worker  oneof server_eligibility_io_router {
    // `lazy = true` lets protobuf runtimes that support it defer parsing this
    // sub-message until it is first accessed; it is only a hint.
1197*14675a02SAndroid Build Coastguard Worker    TEContextServerEligibilityIORouter task_eligibility = 3 [lazy = true];
1198*14675a02SAndroid Build Coastguard Worker  }
1199*14675a02SAndroid Build Coastguard Worker}
1200*14675a02SAndroid Build Coastguard Worker
1201*14675a02SAndroid Build Coastguard Worker// Represents the inputs and outputs of a `ServerEligibilityComputationPhase`
1202*14675a02SAndroid Build Coastguard Worker// which takes a single `TaskEligibilityContext` as input. It is selected via
// the `task_eligibility` variant of that phase's `server_eligibility_io_router`
// oneof.
1203*14675a02SAndroid Build Coastguard Workermessage TEContextServerEligibilityIORouter {
1204*14675a02SAndroid Build Coastguard Worker  // The name of the scalar string tensor that must be fed a serialized
1205*14675a02SAndroid Build Coastguard Worker  // `TaskEligibilityContext`.
1206*14675a02SAndroid Build Coastguard Worker  string context_proto_input_tensor_name = 1;
1207*14675a02SAndroid Build Coastguard Worker
1208*14675a02SAndroid Build Coastguard Worker  // The name of the scalar string tensor that must be fed the path to which
1209*14675a02SAndroid Build Coastguard Worker  // the server graph should write the checkpoint file to be sent to the client.
1210*14675a02SAndroid Build Coastguard Worker  string output_filepath_tensor_name = 2;
1211*14675a02SAndroid Build Coastguard Worker}
1212*14675a02SAndroid Build Coastguard Worker
1213*14675a02SAndroid Build Coastguard Worker// Plan
1214*14675a02SAndroid Build Coastguard Worker// =====
1215*14675a02SAndroid Build Coastguard Worker
1216*14675a02SAndroid Build Coastguard Worker// Represents the overall plan for performing federated optimization or
1217*14675a02SAndroid Build Coastguard Worker// personalization, as handed over to the production system. This will
1218*14675a02SAndroid Build Coastguard Worker// typically be split down into individual pieces for different production
1219*14675a02SAndroid Build Coastguard Worker// parts, e.g. server and client side.
1220*14675a02SAndroid Build Coastguard Worker// NEXT_TAG: 15
1221*14675a02SAndroid Build Coastguard Workermessage Plan {
1222*14675a02SAndroid Build Coastguard Worker  reserved 1, 3, 5;  // Tags of removed fields; they must never be reused.
1223*14675a02SAndroid Build Coastguard Worker
1224*14675a02SAndroid Build Coastguard Worker  // The actual type of the server_*_graph_bytes fields below is expected to be
1225*14675a02SAndroid Build Coastguard Worker  // tensorflow.GraphDef. The TensorFlow graphs are stored in serialized form
1226*14675a02SAndroid Build Coastguard Worker  // for two reasons.
1227*14675a02SAndroid Build Coastguard Worker  // 1) We may use execution engines other than TensorFlow.
1228*14675a02SAndroid Build Coastguard Worker  // 2) We wish to avoid the cost of deserializing and re-serializing large
1229*14675a02SAndroid Build Coastguard Worker  // graphs in the Federated Learning service.
1230*14675a02SAndroid Build Coastguard Worker
1231*14675a02SAndroid Build Coastguard Worker  // While we migrate from ServerPhase to ServerPhaseV2, server_graph_bytes,
1232*14675a02SAndroid Build Coastguard Worker  // server_graph_prepare_bytes, and server_graph_result_bytes may all be set.
1233*14675a02SAndroid Build Coastguard Worker  // If we're using a MapReduceForm-based server implementation, only
1234*14675a02SAndroid Build Coastguard Worker  // server_graph_bytes will be used. If we're using a DistributeAggregateForm-
1235*14675a02SAndroid Build Coastguard Worker  // based server implementation, only server_graph_prepare_bytes and
1236*14675a02SAndroid Build Coastguard Worker  // server_graph_result_bytes will be used.
1237*14675a02SAndroid Build Coastguard Worker
1238*14675a02SAndroid Build Coastguard Worker  // Optional. The TensorFlow graph used for all server processing described by
1239*14675a02SAndroid Build Coastguard Worker  // ServerPhase. For personalization, this will not be set.
1240*14675a02SAndroid Build Coastguard Worker  google.protobuf.Any server_graph_bytes = 7;
1241*14675a02SAndroid Build Coastguard Worker
1242*14675a02SAndroid Build Coastguard Worker  // Optional. The TensorFlow graph used for all server processing described by
1243*14675a02SAndroid Build Coastguard Worker  // ServerPhaseV2.tensorflow_spec_prepare.
1244*14675a02SAndroid Build Coastguard Worker  google.protobuf.Any server_graph_prepare_bytes = 13;
1245*14675a02SAndroid Build Coastguard Worker
1246*14675a02SAndroid Build Coastguard Worker  // Optional. The TensorFlow graph used for all server processing described by
1247*14675a02SAndroid Build Coastguard Worker  // ServerPhaseV2.tensorflow_spec_result.
1248*14675a02SAndroid Build Coastguard Worker  google.protobuf.Any server_graph_result_bytes = 14;
1249*14675a02SAndroid Build Coastguard Worker
1250*14675a02SAndroid Build Coastguard Worker  // A savepoint to sync the server checkpoint with a persistent
1251*14675a02SAndroid Build Coastguard Worker  // storage system. The storage initially holds a seeded checkpoint
1252*14675a02SAndroid Build Coastguard Worker  // which can subsequently be read and updated by this savepoint.
1253*14675a02SAndroid Build Coastguard Worker  // Optional; not present in eligibility computation plans (those with a
1254*14675a02SAndroid Build Coastguard Worker  // ServerEligibilityComputationPhase). This is used in conjunction with
1255*14675a02SAndroid Build Coastguard Worker  // ServerPhase only.
1256*14675a02SAndroid Build Coastguard Worker  CheckpointOp server_savepoint = 2;
1257*14675a02SAndroid Build Coastguard Worker
1258*14675a02SAndroid Build Coastguard Worker  // Required. The TensorFlow graph that describes the TensorFlow logic a client
1259*14675a02SAndroid Build Coastguard Worker  // should perform. It should be consistent with the `TensorflowSpec` field in
1260*14675a02SAndroid Build Coastguard Worker  // the `client_phase`. The actual type is expected to be tensorflow.GraphDef.
1261*14675a02SAndroid Build Coastguard Worker  // The TensorFlow graph is stored in serialized form for two reasons.
1262*14675a02SAndroid Build Coastguard Worker  // 1) We may use execution engines other than TensorFlow.
1263*14675a02SAndroid Build Coastguard Worker  // 2) We wish to avoid the cost of deserializing and re-serializing large
1264*14675a02SAndroid Build Coastguard Worker  // graphs in the Federated Learning service.
1265*14675a02SAndroid Build Coastguard Worker  google.protobuf.Any client_graph_bytes = 8;
1266*14675a02SAndroid Build Coastguard Worker
1267*14675a02SAndroid Build Coastguard Worker  // Optional. The FlatBuffer used for TFLite training.
1268*14675a02SAndroid Build Coastguard Worker  // It contains the same model information as the client_graph_bytes, but with
1269*14675a02SAndroid Build Coastguard Worker  // a different format.
1270*14675a02SAndroid Build Coastguard Worker  bytes client_tflite_graph_bytes = 12;
1271*14675a02SAndroid Build Coastguard Worker
1272*14675a02SAndroid Build Coastguard Worker  // A client phase paired with the server phase(s) that are processed in
1273*14675a02SAndroid Build Coastguard Worker  // sync with it. The server execution defines how the results of a client
1274*14675a02SAndroid Build Coastguard Worker  // phase are aggregated, and how the checkpoints for clients are
1275*14675a02SAndroid Build Coastguard Worker  // generated.
1276*14675a02SAndroid Build Coastguard Worker  message Phase {
1277*14675a02SAndroid Build Coastguard Worker    // Required. The client phase.
1278*14675a02SAndroid Build Coastguard Worker    ClientPhase client_phase = 1;
1279*14675a02SAndroid Build Coastguard Worker
1280*14675a02SAndroid Build Coastguard Worker    // Optional. Server phase for TF-based aggregation; not provided for
1281*14675a02SAndroid Build Coastguard Worker    // personalization or eligibility tasks.
1282*14675a02SAndroid Build Coastguard Worker    ServerPhase server_phase = 2;
1283*14675a02SAndroid Build Coastguard Worker
1284*14675a02SAndroid Build Coastguard Worker    // Optional. Server phase for native aggregation; only provided for tasks
1285*14675a02SAndroid Build Coastguard Worker    // that have enabled the corresponding flag.
1286*14675a02SAndroid Build Coastguard Worker    ServerPhaseV2 server_phase_v2 = 4;
1287*14675a02SAndroid Build Coastguard Worker
1288*14675a02SAndroid Build Coastguard Worker    // Optional. Only provided for eligibility tasks.
1289*14675a02SAndroid Build Coastguard Worker    ServerEligibilityComputationPhase server_eligibility_phase = 3;
1290*14675a02SAndroid Build Coastguard Worker  }
1291*14675a02SAndroid Build Coastguard Worker
1292*14675a02SAndroid Build Coastguard Worker  // The pairs of client and server computations to run.
1293*14675a02SAndroid Build Coastguard Worker  repeated Phase phase = 4;
1294*14675a02SAndroid Build Coastguard Worker
1295*14675a02SAndroid Build Coastguard Worker  // Metrics that are persistent across different phases. This
1296*14675a02SAndroid Build Coastguard Worker  // includes, for example, counters that track how much work of
1297*14675a02SAndroid Build Coastguard Worker  // different kinds has been done.
1298*14675a02SAndroid Build Coastguard Worker  repeated Metric metrics = 6;
1299*14675a02SAndroid Build Coastguard Worker
1300*14675a02SAndroid Build Coastguard Worker  // Describes how metrics in both the client and server phases should be
1301*14675a02SAndroid Build Coastguard Worker  // aggregated.
1302*14675a02SAndroid Build Coastguard Worker  repeated OutputMetric output_metrics = 10;
1303*14675a02SAndroid Build Coastguard Worker
1304*14675a02SAndroid Build Coastguard Worker  // Version of the plan:
1305*14675a02SAndroid Build Coastguard Worker  // version == 0 - Old plan without version field, containing b/65131070
1306*14675a02SAndroid Build Coastguard Worker  // version >= 1 - plan supports multi-shard aggregation mode (L1/L2)
1307*14675a02SAndroid Build Coastguard Worker  int32 version = 9;
1308*14675a02SAndroid Build Coastguard Worker
1309*14675a02SAndroid Build Coastguard Worker  // A TensorFlow ConfigProto packed in an Any.
1310*14675a02SAndroid Build Coastguard Worker  //
1311*14675a02SAndroid Build Coastguard Worker  // If this field is unset, if the Any proto is set but empty, or if the Any
1312*14675a02SAndroid Build Coastguard Worker  // proto is populated with an empty ConfigProto (i.e. its `type_url` field is
1313*14675a02SAndroid Build Coastguard Worker  // set, but the `value` field is empty) then the client implementation may
1314*14675a02SAndroid Build Coastguard Worker  // choose a set of configuration parameters to provide to TensorFlow by
1315*14675a02SAndroid Build Coastguard Worker  // default.
1316*14675a02SAndroid Build Coastguard Worker  //
1317*14675a02SAndroid Build Coastguard Worker  // In all other cases this field must contain a valid packed ConfigProto
1318*14675a02SAndroid Build Coastguard Worker  // (invalid values will result in an error at execution time), and in this
1319*14675a02SAndroid Build Coastguard Worker  // case the client will not provide any other configuration parameters by
1320*14675a02SAndroid Build Coastguard Worker  // default.
1321*14675a02SAndroid Build Coastguard Worker  google.protobuf.Any tensorflow_config_proto = 11;
1322*14675a02SAndroid Build Coastguard Worker}
1323*14675a02SAndroid Build Coastguard Worker
1324*14675a02SAndroid Build Coastguard Worker// Represents the client part of a federated optimization plan.
1325*14675a02SAndroid Build Coastguard Worker// This is also used to describe a client-only plan for standalone on-device
1326*14675a02SAndroid Build Coastguard Worker// training, known as personalization.
1327*14675a02SAndroid Build Coastguard Worker// NEXT_TAG: 6
1328*14675a02SAndroid Build Coastguard Workermessage ClientOnlyPlan {
1329*14675a02SAndroid Build Coastguard Worker  reserved 3;  // Tag of a removed field; it must never be reused.
1330*14675a02SAndroid Build Coastguard Worker
1331*14675a02SAndroid Build Coastguard Worker  // The graph to use for training, in binary form.
  // NOTE(review): presumably a serialized tensorflow.GraphDef, as described for
  // `Plan.client_graph_bytes`; confirm.
1332*14675a02SAndroid Build Coastguard Worker  bytes graph = 1;
1333*14675a02SAndroid Build Coastguard Worker
1334*14675a02SAndroid Build Coastguard Worker  // Optional. The flatbuffer used for TFLite training.
1335*14675a02SAndroid Build Coastguard Worker  // Whether `graph` or `tflite_graph` is used for training is up to the client
1336*14675a02SAndroid Build Coastguard Worker  // code, to allow for a flag-controlled A/B rollout.
1337*14675a02SAndroid Build Coastguard Worker  bytes tflite_graph = 5;
1338*14675a02SAndroid Build Coastguard Worker
1339*14675a02SAndroid Build Coastguard Worker  // The client phase to execute.
1340*14675a02SAndroid Build Coastguard Worker  ClientPhase phase = 2;
1341*14675a02SAndroid Build Coastguard Worker
1342*14675a02SAndroid Build Coastguard Worker  // A TensorFlow ConfigProto packed in an Any.
1343*14675a02SAndroid Build Coastguard Worker  google.protobuf.Any tensorflow_config_proto = 4;
1344*14675a02SAndroid Build Coastguard Worker}
1345*14675a02SAndroid Build Coastguard Worker
1346*14675a02SAndroid Build Coastguard Worker// Represents the cross-round aggregation portion for user-defined measurements.
1347*14675a02SAndroid Build Coastguard Worker// This is used by tools that process and analyze accumulator checkpoints
1348*14675a02SAndroid Build Coastguard Worker// after a round of computation, to achieve aggregation beyond a round.
// NOTE(review): the op fields appear to be listed in execution order
// (init_op, read_aggregated_update, merge_op, read_write_final_accumulators);
// confirm against the consuming tool.
1349*14675a02SAndroid Build Coastguard Workermessage CrossRoundAggregationExecution {
1350*14675a02SAndroid Build Coastguard Worker  // Operation to run before reading an accumulator checkpoint.
1351*14675a02SAndroid Build Coastguard Worker  string init_op = 1;
1352*14675a02SAndroid Build Coastguard Worker
1353*14675a02SAndroid Build Coastguard Worker  // Reads an accumulator checkpoint.
1354*14675a02SAndroid Build Coastguard Worker  CheckpointOp read_aggregated_update = 2;
1355*14675a02SAndroid Build Coastguard Worker
1356*14675a02SAndroid Build Coastguard Worker  // Operation to merge the loaded checkpoint into the accumulators.
1357*14675a02SAndroid Build Coastguard Worker  string merge_op = 3;
1358*14675a02SAndroid Build Coastguard Worker
1359*14675a02SAndroid Build Coastguard Worker  // Reads and writes the final aggregated accumulator variables.
1360*14675a02SAndroid Build Coastguard Worker  CheckpointOp read_write_final_accumulators = 6;
1361*14675a02SAndroid Build Coastguard Worker
1362*14675a02SAndroid Build Coastguard Worker  // Metadata for mapping the TensorFlow `name` attribute of the `tf.Variable`
1363*14675a02SAndroid Build Coastguard Worker  // to the user-defined name of the signal.
1364*14675a02SAndroid Build Coastguard Worker  repeated Measurement measurements = 4;
1365*14675a02SAndroid Build Coastguard Worker
1366*14675a02SAndroid Build Coastguard Worker  // The serialized `tf.Graph`, packed in an Any, used for aggregating
1367*14675a02SAndroid Build Coastguard Worker  // accumulator checkpoints when loading metrics.
1368*14675a02SAndroid Build Coastguard Worker  google.protobuf.Any cross_round_aggregation_graph_bytes = 5;
1369*14675a02SAndroid Build Coastguard Worker}
1370*14675a02SAndroid Build Coastguard Worker
// A single user-defined measurement, as listed in
// `CrossRoundAggregationExecution.measurements`.
1371*14675a02SAndroid Build Coastguard Workermessage Measurement {
1372*14675a02SAndroid Build Coastguard Worker  // Name of a TensorFlow op to run to read/fetch the value of this measurement.
1373*14675a02SAndroid Build Coastguard Worker  string read_op_name = 1;
1374*14675a02SAndroid Build Coastguard Worker
1375*14675a02SAndroid Build Coastguard Worker  // A human-readable name for the measurement. Names are usually
1376*14675a02SAndroid Build Coastguard Worker  // CamelCase by convention, e.g., 'Loss', 'AbsLoss', or 'Accuracy'.
1377*14675a02SAndroid Build Coastguard Worker  string name = 2;
1378*14675a02SAndroid Build Coastguard Worker
1379*14675a02SAndroid Build Coastguard Worker  reserved 3;  // Tag of a removed field; it must never be reused.
1380*14675a02SAndroid Build Coastguard Worker
1381*14675a02SAndroid Build Coastguard Worker  // A serialized `tff.Type` for the measurement.
1382*14675a02SAndroid Build Coastguard Worker  bytes tff_type = 4;
1383*14675a02SAndroid Build Coastguard Worker}
1384