// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";

package google.internal.federated.plan;

import "google/protobuf/any.proto";
import "tensorflow/core/framework/tensor.proto";
import "tensorflow/core/framework/types.proto";
import "tensorflow/core/protobuf/saver.proto";
import "tensorflow/core/protobuf/struct.proto";

option java_package = "com.google.internal.federated.plan";
option java_multiple_files = true;
option java_outer_classname = "PlanProto";

// Primitives
// ===========

// Represents an operation to save or restore from a checkpoint. Some
// instances of this message may only be used either for restore or for
// save, others for both directions. This is documented together with
// their usage.
//
// This op has four essential uses:
//   1. read and apply a checkpoint.
//   2. write a checkpoint.
//   3. read and apply from an aggregated side channel.
//   4. write to a side channel (grouped with write a checkpoint).
// We should consider splitting this into four separate messages.
message CheckpointOp {
  // An optional standard saver def. If not provided, only the
  // op(s) below will be executed. This must be a version 1 SaverDef.
  tensorflow.SaverDef saver_def = 1;

  // An optional operation to run before the saver_def is executed for
  // restore.
  string before_restore_op = 2;

  // An optional operation to run after the saver_def has been
  // executed for restore. If side_channel_tensors are provided, then
  // they should be provided in a feed_dict to this op.
  string after_restore_op = 3;

  // An optional operation to run before the saver_def will be
  // executed for save.
  string before_save_op = 4;

  // An optional operation to run after the saver_def has been
  // executed for save. If there are side_channel_tensors, this op
  // should be run after the side_channel_tensors have been fetched.
  string after_save_op = 5;

  // In addition to being saved and restored from a checkpoint, one can
  // also save and restore via a side channel. The keys in this map are
  // the names of the tensors transmitted by the side channel. These (key)
  // tensors should be read off just before saving a SaverDef and used
  // by the code that handles the side channel. Any variables provided this
  // way should NOT be saved in the SaverDef.
  //
  // For restoring, the variables that are provided by the side channel
  // are restored differently than those for a checkpoint. For those from
  // the side channel, these should be restored by calling the before_restore_op
  // with a feed dict whose keys are the restore_names in the SideChannel and
  // whose values are the values to be restored.
  //
  // NOTE(review): the comment on after_restore_op says the side-channel
  // feed_dict is supplied to after_restore_op, whereas this paragraph names
  // before_restore_op. Confirm which op actually receives the feed_dict.
  map<string, SideChannel> side_channel_tensors = 6;

  // An optional name of a tensor into which a unique token for the current
  // session should be written.
  //
  // This session identifier allows TensorFlow ops such as `ServeSlices` or
  // `ExternalDataset` to refer to callbacks and other session-global objects
  // registered before running the session.
  string session_token_tensor_name = 7;
}

// Describes one entry of CheckpointOp.side_channel_tensors: the kind of side
// channel used and the tensor name to feed when restoring.
message SideChannel {
  // A side channel whose variables are processed via SecureAggregation.
  // This side channel implements aggregation via sum over a set of
  // clients, so the restored tensor will be a sum of multiple clients'
  // inputs into the side channel. Hence this will restore during the
  // read_aggregate_update restore, not the per-client read_update restore.
  message SecureAggregand {
    message Dimension {
      int64 size = 1;
    }

    // Dimensions of the aggregand. This is used by the secure aggregation
    // protocol in its early rounds, not as redundant info which could be
    // obtained by reading the dimensions of the tensor itself.
    repeated Dimension dimension = 3;

    // The data type anticipated by the server-side graph.
    tensorflow.DataType dtype = 4;

    // SecureAggregation will compute sum modulo this modulus.
    message FixedModulus {
      uint64 modulus = 1;
    }

    // SecureAggregation will for each shard compute sum modulo m with m at
    // least (1 + shard_size * (base_modulus - 1)), then aggregate
    // shard results with non-modular addition. Here, shard_size is the number
    // of clients in the shard.
    //
    // Note that the modulus for each shard will be greater than the largest
    // possible (non-modular) sum of the inputs to that shard. That is,
    // assuming each client has input on range [0, base_modulus), the result
    // will be identical to non-modular addition (i.e. federated_sum).
    //
    // While any m >= (1 + shard_size * (base_modulus - 1)) satisfies this
    // requirement, the current implementation takes
    // m = 2**ceil(log_2(1 + shard_size * (base_modulus - 1))), which is the
    // smallest possible value of m that is also a power of 2. This choice is
    // made because (a) it uses the same number of bits per vector entry as
    // valid smaller m, using the current on-the-wire encoding scheme, and (b)
    // it enables the underlying mask-generation PRNG to run in its most
    // computationally efficient mode, which can be up to 2x faster.
    message ModulusTimesShardSize {
      uint64 base_modulus = 1;
    }

    oneof modulus_scheme {
      // Bitwidth of the aggregand.
      //
      // This is the bitwidth of an input value (i.e. the bitwidth that
      // quantization should target). The Secure Aggregation bitwidth (i.e.,
      // the bitwidth of the *sum* of the input values) will be a function of
      // this bitwidth and the number of participating clients, as negotiated
      // with the server when the protocol is initiated.
      //
      // Deprecated; prefer fixed_modulus instead.
      int32 quantized_input_bitwidth = 2 [deprecated = true];

      FixedModulus fixed_modulus = 5;
      ModulusTimesShardSize modulus_times_shard_size = 6;
    }

    reserved 1;
  }

  // What type of side channel is used.
  oneof type {
    SecureAggregand secure_aggregand = 1;
  }

  // When restoring, the name of the tensor to restore to. This is the name
  // (key) supplied in the feed_dict in the before_restore_op in order to
  // restore the tensor provided by the side channel (which will be the
  // value in the feed_dict).
  //
  // NOTE(review): CheckpointOp.after_restore_op's comment says the feed_dict
  // is supplied to after_restore_op instead; confirm which op is correct.
  string restore_name = 2;
}

// Container for a metric used by the internal toolkit.
message Metric {
  // Name of an Op to run to read the value.
  string variable_name = 1;

  // A human-readable name for the statistic. Metric names are usually
  // camel case by convention, e.g., 'Loss', 'AbsLoss', or 'Accuracy'.
  // Must be 7-bit ASCII and under 122 characters.
  string stat_name = 2;

  // The human-readable name of another metric by which this metric should be
  // normalized, if any. If empty, this Metric should be aggregated with simple
  // summation. If not empty, the Metric is aggregated according to
  //   weighted_metric_sum = sum_i (metric_i * weight_i)
  //   weight_sum = sum_i weight_i
  //   average_metric_value = weighted_metric_sum / weight_sum
  string weight_name = 3;
}

// Controls the format of output metrics users receive. Represents instructions
// for how metrics are to be output to users, controlling the end format of
// the metric users receive.
message OutputMetric {
  // Metric name.
  string name = 1;

  oneof value_source {
    // A metric representing one stat with aggregation type sum.
    SumOptions sum = 2;

    // A metric representing a ratio between metrics with aggregation
    // type average.
    AverageOptions average = 3;

    // A metric that is not aggregated by the MetricReportAggregator or
    // metrics_loader. This includes metrics like 'num_server_updates' that are
    // aggregated in TensorFlow.
    NoneOptions none = 4;

    // A metric representing one stat with aggregation type only sample.
    // Samples at most 101 clients' values.
    OnlySampleOptions only_sample = 5;
  }

  oneof visualization_info {
    // Iff True, the metric will be plotted in the default view of the
    // task level Colab automatically.
    bool auto_plot = 6 [deprecated = true];
    VisualizationSpec plot_spec = 7;
  }
}

message VisualizationSpec {
  // Different allowable plot types.
  enum VisualizationType {
    NONE = 0;
    DEFAULT_PLOT_FOR_TASK_TYPE = 1;
    LINE_PLOT = 2;
    LINE_PLOT_WITH_PERCENTILES = 3;
    HISTOGRAM = 4;
  }

  // Defines the plot type to provide downstream.
  VisualizationType plot_type = 1;

  // The x-axis which to provide for the given metric. Must be the name of a
  // metric or counter. Recommended x_axis options are source_round, round,
  // or time.
  string x_axis = 2;

  // Iff True, metric will be displayed on a population level dashboard.
  bool plot_on_population_dashboard = 3;
}

// A metric representing one stat with aggregation type sum.
message SumOptions {
  // Name for corresponding Metric stat_name field.
  string stat_name = 1;

  // Iff True, a cumulative sum over rounds will be provided in addition to a
  // sum per round for the value metric.
  bool include_cumulative_sum = 2;

  // Iff True, sample of at most 101 clients' values.
  // Used to calculate quantiles in downstream visualization pipeline.
  bool include_client_samples = 3;
}

// A metric representing a ratio between metrics with aggregation type average.
// Represents: numerator stat / denominator stat.
message AverageOptions {
  // Numerator stat name pointing to corresponding Metric stat_name.
  string numerator_stat_name = 1;

  // Denominator stat name pointing to corresponding Metric stat_name.
  string denominator_stat_name = 2;

  // Name for corresponding Metric stat_name that is the ratio of the
  // numerator stat / denominator stat.
  string average_stat_name = 3;

  // Iff True, sample of at most 101 clients' values.
  // Used to calculate quantiles in downstream visualization pipeline.
  bool include_client_samples = 4;
}

// A metric representing one stat with aggregation type none.
message NoneOptions {
  // Name for corresponding Metric stat_name field.
  string stat_name = 1;
}

// A metric representing one stat with aggregation type only sample.
message OnlySampleOptions {
  // Name for corresponding Metric stat_name field.
  string stat_name = 1;
}

// Represents a data set. This is used for testing.
message Dataset {
  // Represents the data set for one client.
  message ClientDataset {
    // A string identifying the client.
    string client_id = 1;

    // A list of serialized tf.Example protos.
    repeated bytes example = 2;

    // Represents a dataset whose examples are selected by an ExampleSelector.
    message SelectedExample {
      ExampleSelector selector = 1;
      repeated bytes example = 2;
    }

    // A list of (selector, dataset) pairs. Used in testing some *TFF-based
    // tasks* that require multiple datasets as client input, e.g., a TFF-based
    // personalization eval task requires each client to provide at least two
    // datasets: one for train, and the other for test.
    repeated SelectedExample selected_example = 3;
  }

  // A list of client data.
  repeated ClientDataset client_data = 1;
}

// Represents predicates over metrics - i.e., expectations.
// This is used in training/eval tests to encode metric names and values
// expected to be reported by a client execution.
message MetricTestPredicates {
  // The value must lie in [lower_bound; upper_bound]. Can also be used for
  // approximate matching (lower = value - epsilon; upper = value + epsilon).
  message Interval {
    double lower_bound = 1;
    double upper_bound = 2;
  }

  // The value must be a real value as long as the value of the weight_name
  // metric is non-zero. If the weight metric is zero, then it is acceptable for
  // the value to be non-real.
  message RealIfNonzeroWeight {
    string weight_name = 1;
  }

  // One expectation on a single named metric in a given training round.
  message MetricCriterion {
    // Name of the metric.
    string name = 1;

    // FL training round this metric is expected to appear in.
    int32 training_round_index = 2;

    // If none of the following is set, no matching is performed; but the
    // metric is still expected to be present (with whatever value).
    oneof Criterion {
      // The reported metric must be < lt.
      float lt = 3;
      // The reported metric must be > gt.
      float gt = 4;
      // The reported metric must be <= le.
      float le = 5;
      // The reported metric must be >= ge.
      float ge = 6;
      // The reported metric must be == eq.
      float eq = 7;
      // The reported metric must lie in the interval.
      Interval interval = 8;
      // The reported metric is not NaN or +/- infinity.
      bool real = 9;
      // The reported metric is real (i.e., not NaN or +/- infinity) if the
      // value of an associated weight is not 0.
      RealIfNonzeroWeight real_if_nonzero_weight = 10;
    }
  }

  repeated MetricCriterion metric_criterion = 1;

  reserved 2;
}

// Client Phase
// ============

// A `TensorflowSpec` that is executed on the client in a single `tf.Session`.
// In federated optimization, this will correspond to one `ServerPhase`.
message ClientPhase {
  // A short CamelCase name for the ClientPhase.
  string name = 2;

  // Minimum number of clients in aggregation.
  // In secure aggregation mode this is used to configure the protocol instance
  // in a way that server can't learn aggregated values with number of
  // participants lower than this number.
  // Without secure aggregation server still respects this parameter,
  // ensuring that aggregated values never leave server RAM unless they include
  // data from (at least) specified number of participants.
  int32 minimum_number_of_participants = 3;

  // If populated, `io_router` must be specified.
  oneof spec {
    // A functional interface for the TensorFlow logic the client should
    // perform.
    TensorflowSpec tensorflow_spec = 4 [lazy = true];
    // Spec for client plans that issue example queries and send the query
    // results directly to an aggregator with no or little additional
    // processing.
    ExampleQuerySpec example_query_spec = 9 [lazy = true];
  }

  // The specification of the inputs coming either from customer apps
  // (Local Compute) or the federated protocol (Federated Compute).
  oneof io_router {
    FederatedComputeIORouter federated_compute = 5 [lazy = true];
    LocalComputeIORouter local_compute = 6 [lazy = true];
    FederatedComputeEligibilityIORouter federated_compute_eligibility = 7
        [lazy = true];
    FederatedExampleQueryIORouter federated_example_query = 8 [lazy = true];
  }

  reserved 1;
}

// TensorflowSpec message describes a single call into TensorFlow, including the
// expected input tensors that must be fed when making that call, which
// output tensors to be fetched, and any operations that have no output but must
// be run. The TensorFlow session will then use the input tensors to do some
// computation, generally reading from one or more datasets, and provide some
// outputs.
//
// Conceptually, client or server code uses this proto along with an IORouter
// to build maps of names to input tensors, vectors of output tensor names,
// and vectors of target nodes:
//
//    CreateTensorflowArguments(
//        TensorflowSpec& spec,
//        IORouter& io_router,
//        const vector<pair<string, Tensor>>* input_tensors,
//        const vector<string>* output_tensor_names,
//        const vector<string>* target_node_names);
//
// Where `input_tensors`, `output_tensor_names` and `target_node_names`
// correspond to the arguments of TensorFlow C++ API for
// `tensorflow::Session::Run()`, and the client executes only a single
// invocation.
//
// Note: the execution engine never sees any concepts related to the federated
// protocol, e.g. input checkpoints or aggregation protocols. This is a "tensors
// in, tensors out" interface. New aggregation methods can be added without
// having to modify the execution engine / TensorflowSpec message; instead they
// should modify the IORouter messages.
//
// Note: both `input_tensor_specs` and `output_tensor_specs` are full
// `tensorflow.TensorSpecProto` messages, though TensorFlow technically
// only requires the names to feed the values into the session. The additional
// dtypes/shape information must always be included in case the runtime
// executing this TensorflowSpec wants to perform additional, optional static
// assertions. The runtimes however are free to ignore the dtype/shapes and only
// rely on the names if so desired.
//
// Assertions:
//   - all names in `input_tensor_specs`, `output_tensor_specs`, and
//     `target_node_names` must appear in the serialized GraphDef where
//     the TF execution will be invoked.
//   - `output_tensor_specs` or `target_node_names` must be non-empty, otherwise
//     there is nothing to execute in the graph.
message TensorflowSpec {
  // The name of a tensor into which a unique token for the current session
  // should be written. The corresponding tensor is a scalar string tensor and
  // is separate from `input_tensor_specs` as there is only one.
  //
  // A session token allows TensorFlow ops such as `ServeSlices` or
  // `ExternalDataset` to refer to callbacks and other session-global objects
  // registered before running the session. In the `ExternalDataset` case, a
  // single dataset_token is valid for multiple `tf.data.Dataset` objects as
  // the token can be thought of as a handle to a dataset factory.
  string dataset_token_tensor_name = 1;

  // TensorSpecs of inputs which will be passed to TF.
  //
  // Corresponds to the `feed_dict` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, excluding the dataset_token listed above.
  //
  // Assertions:
  //   - All the tensor names designated as inputs in the corresponding IORouter
  //     must be listed (otherwise the IORouter input work is unused).
  //   - All placeholders in the TF graph must be listed here, with the
  //     exception of the dataset_token which is explicitly set above (otherwise
  //     TensorFlow will fail to execute).
  repeated tensorflow.TensorSpecProto input_tensor_specs = 2;

  // TensorSpecs that should be fetched from TF after execution.
  //
  // Corresponds to the `fetches` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, and the `output_tensor_names` in TensorFlow's C++
  // API.
  //
  // Assertions:
  //   - The set of tensor names here must strictly match the tensor names
  //     designated as outputs in the corresponding IORouter (if any exist).
  repeated tensorflow.TensorSpecProto output_tensor_specs = 3;

  // Node names in the graph that should be executed, but the output not
  // returned.
  //
  // Corresponds to the `fetches` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, and the `target_node_names` in TensorFlow's C++
  // API.
  //
  // This is intended for use with operations that do not produce tensors, but
  // nonetheless are required to run (e.g. serializing checkpoints).
  repeated string target_node_names = 4;

  // Map of Tensor names to constant inputs.
  // Note: tensors specified via this message should not be included in
  // input_tensor_specs.
  map<string, tensorflow.TensorProto> constant_inputs = 5;

  // The fields below are added by OnDevicePersonalization module.
  // Specifies an example selection procedure.
  ExampleSelector example_selector = 999;
}

// ExampleQuerySpec message describes client execution that issues example
// queries and sends the query results directly to an aggregator with no or
// little additional processing.
// This message describes one or more example store queries that perform the
// client side analytics computation in C++. The corresponding output vectors
// will be converted into the expected federated protocol output format.
// This must be used in conjunction with the `FederatedExampleQueryIORouter`.
message ExampleQuerySpec {
  message OutputVectorSpec {
    // The output vector name.
    string vector_name = 1;

    // Supported data types for the vector of information.
    enum DataType {
      UNSPECIFIED = 0;
      INT32 = 1;
      INT64 = 2;
      BOOL = 3;
      FLOAT = 4;
      DOUBLE = 5;
      BYTES = 6;
      STRING = 7;
    }

    // The data type for each entry in the vector.
    DataType data_type = 2;
  }

  message ExampleQuery {
    // The `ExampleSelector` to issue the query with.
    ExampleSelector example_selector = 1;

    // Indicates that the query returns vector data and must return a single
    // ExampleQueryResult result containing a VectorData entry matching each
    // OutputVectorSpec.vector_name.
    //
    // If the query instead returns no result, then it will be treated as if
    // an error was returned. In that case, or if the query explicitly returns
    // an error, then the client will abort its session.
    //
    // The keys in the map are the names the vectors should be aggregated under,
    // and must match the keys in FederatedExampleQueryIORouter.aggregations.
    map<string, OutputVectorSpec> output_vector_specs = 2;
  }

  // The queries to run.
  repeated ExampleQuery example_queries = 1;
}

// The input and output router for Federated Compute plans.
//
// This proto is the glue between the federated protocol and the TensorFlow
// execution engine. This message describes how to prepare data coming from the
// incoming `CheckinResponse` (defined in
// fcp/protos/federated_api.proto) for the `TensorflowSpec`, and what
// to do with outputs from `TensorflowSpec` (e.g. how to aggregate them back on
// the server).
//
// TODO(team) we could replace `input_filepath_tensor_name` with
// an `input_tensors` field, which would then be a tensor that contains the
// input TensorProtos directly, skipping disk I/O, rather than referring to a
// checkpoint file path.
message FederatedComputeIORouter {
  // ===========================================================================
  // Inputs
  // ===========================================================================
  // The name of the scalar string tensor that is fed the file path to the
  // initial checkpoint (e.g. as provided via AcceptanceInfo.init_checkpoint).
  //
  // The federated protocol code would copy the `CheckinResponse`'s initial
  // checkpoint to a temporary file and then pass that file path through this
  // tensor.
  //
  // Ops may be added to the client graph that take this tensor as input and
  // read the path.
  //
  // This field is optional. It may be omitted if the client graph does not use
  // an initial checkpoint.
  string input_filepath_tensor_name = 1;

  // The name of the scalar string tensor that is fed the file path to which
  // client work should serialize the bytes to send back to the server.
  //
  // The federated protocol code generates a temporary file and passes the file
  // path through this tensor.
  //
  // Ops may be added to the client graph that use this tensor as an argument
  // to write files (e.g. writing checkpoints to disk).
  //
  // This field is optional. It must be omitted if the client graph does not
  // generate any output files (e.g. when all output tensors of `TensorflowSpec`
  // use Secure Aggregation). If this field is not set, then the `ReportRequest`
  // message in the federated protocol will not have the
  // `Report.update_checkpoint` field set. This absence of a value here can be
  // used to validate that the plan only uses Secure Aggregation.
  //
  // Conversely, if this field is set and executing the associated
  // TensorflowSpec does not write to the path, that is an indication of an
  // internal framework error. The runtime should notify the caller that the
  // computation was set up incorrectly.
  string output_filepath_tensor_name = 2;

  // ===========================================================================
  // Outputs
  // ===========================================================================
  // Describes which output tensors should be aggregated using an aggregation
  // protocol, and the configuration for those protocols.
  //
  // Assertions:
  //   - All keys must exist in the associated `TensorflowSpec` as
  //     `output_tensor_specs.name` values.
  map<string, AggregationConfig> aggregations = 3;
}

// The input and output router for client plans that do not use TensorFlow.
//
// This proto is the glue between the federated protocol and the example query
// execution engine, describing how the query results should ultimately be
// aggregated.
message FederatedExampleQueryIORouter {
  // Describes how each output vector should be aggregated using an aggregation
  // protocol, and the configuration for those protocols.
  // Keys must match the keys in
  // ExampleQuerySpec.ExampleQuery.output_vector_specs.
  // Note that currently only the TFV1CheckpointAggregation config is supported.
  map<string, AggregationConfig> aggregations = 1;
}

// The specification for how to aggregate the associated tensor across clients
// on the server.
message AggregationConfig {
  oneof protocol_config {
    // Indicates that the given output tensor should be processed using Secure
    // Aggregation, using the specified config options.
    SecureAggregationConfig secure_aggregation = 2;

    // Note: in the future we could add a `SimpleAggregationConfig` to add
    // support for simple aggregation without writing to an intermediate
    // checkpoint file first.

    // Indicates that the given output tensor or vector (e.g. as produced by an
    // ExampleQuerySpec) should be placed in an output TF v1 checkpoint.
    //
    // Currently only ExampleQuerySpec output vectors are supported by this
    // aggregation type (i.e. it cannot be used with TensorflowSpec output
    // tensors). Each vector will be stored in the checkpoint as a 1-D Tensor of
    // its corresponding data type.
    TFV1CheckpointAggregation tf_v1_checkpoint_aggregation = 3;
  }
}

// Parameters for the SecAgg protocol (go/secagg).
//
// Currently only the server uses the SecAgg parameters, so we only use this
// message to signify usage of SecAgg.
message SecureAggregationConfig {}

// Parameters for the TFV1 Checkpoint Aggregation protocol.
//
// Currently only ExampleQuerySpec output vectors are supported by this
// aggregation type (i.e. it cannot be used with TensorflowSpec output
// tensors). Each vector will be stored in the checkpoint as a 1-D Tensor of
// its corresponding data type.
message TFV1CheckpointAggregation {}

// The input and output router for eligibility-computing plans. These plans
// compute which other plans a client is eligible to run, and are returned by
// clients via an `EligibilityEvalCheckinResponse` (defined in
// fcp/protos/federated_api.proto).
message FederatedComputeEligibilityIORouter {
  // The name of the scalar string tensor that is fed the file path to the
  // initial checkpoint (e.g. as provided via
  // `EligibilityEvalPayload.init_checkpoint`).
  //
  // For more detail see the
  // `FederatedComputeIORouter.input_filepath_tensor_name`, which has the same
  // semantics.
  //
  // This field is optional. It may be omitted if the client graph does not use
  // an initial checkpoint.
  //
  // This tensor name must exist in the associated
  // `TensorflowSpec.input_tensor_specs` list.
  string input_filepath_tensor_name = 1;

  // Name of the output tensor (a string scalar) containing the serialized
  // `google.internal.federatedml.v2.TaskEligibilityInfo` proto output. The
  // client code will parse this proto and place it in the
  // `task_eligibility_info` field of the subsequent `CheckinRequest`.
  //
  // This tensor name must exist in the associated
  // `TensorflowSpec.output_tensor_specs` list.
  string task_eligibility_info_tensor_name = 2;
}

// The input and output router for Local Compute plans.
//
// This proto is the glue between the customer's app and the TensorFlow
// execution engine. This message describes how to prepare data coming from the
// customer app (e.g. the input directory the app set up), and the temporary,
// scratch output directory that will be reported to the customer app upon
// completion of `TensorflowSpec`.
message LocalComputeIORouter {
  // ===========================================================================
  // Inputs
  // ===========================================================================
  // The name of the placeholder tensor representing the input resource path(s).
  // It can be a single input directory or file path (in this case the
  // `input_dir_tensor_name` is populated) or multiple input resources
  // represented as a map from names to input directories or file paths (in this
  // case the `multiple_input_resources` is populated).
  //
  // In the multiple input resources case, the placeholder tensors are
  // represented as a map: the keys are the input resource names defined by the
  // users when constructing the `LocalComputation` Python object, and the
  // values are the corresponding placeholder tensor names created by the local
  // computation plan builder.
  //
  // Apps will have the ability to create contracts between their Android code
  // and `LocalComputation` toolkit code to place files inside the input
  // resource paths with known names (Android code) and create graphs with ops
  // to read from these paths (file names can be specified in toolkit code).
  oneof input_resource {
    string input_dir_tensor_name = 1;
    // Directly using the `map` field is not allowed in `oneof`, so we have to
    // wrap it in a new message.
    MultipleInputResources multiple_input_resources = 3;
  }

  // Scalar string tensor name that will contain the output directory path.
  //
  // The provided directory should be considered temporary scratch that will be
  // deleted, not persisted. It is the responsibility of the calling app to
  // move the desired files to a permanent location once the client returns this
  // directory back to the calling app.
  string output_dir_tensor_name = 2;

  // ===========================================================================
  // Outputs
  // ===========================================================================
  // NOTE: LocalCompute has no outputs other than what the client graph writes
  // to the directory fed through `output_dir_tensor_name` above.
}

// Describes the multiple input resources in `LocalComputeIORouter`.
message MultipleInputResources {
  // The keys are the input resource names (defined by the users when
  // constructing the `LocalComputation` Python object), and the values are the
  // corresponding placeholder tensor names created by the local computation
  // plan builder.
  map<string, string> input_resource_tensor_name_map = 1;
}

// Describes a queue to which input is fed.
message AsyncInputFeed {
  // The op for enqueuing an example input.
  string enqueue_op = 1;

  // The input placeholders for the enqueue op.
  repeated string enqueue_params = 2;

  // The op for closing the input queue.
  string close_op = 3;

  // Whether the work that should be fed asynchronously is the data itself
  // or a description of where that data lives.
  bool feed_values_are_data = 4;
}

// Describes an input handled by a tf.data.Dataset object: the iterator
// initializer op, the placeholders it needs, and the batch size.
message DatasetInput {
  // Initializer of iterator corresponding to tf.data.Dataset object which
  // handles the input data. Stores name of an op in the graph.
  string initializer = 1;

  // Placeholders necessary to initialize the dataset.
  DatasetInputPlaceholders placeholders = 2;

  // Batch size to be used in tf.data.Dataset.
  int32 batch_size = 3;
}

// Names of the placeholders necessary to initialize the dataset described by
// `DatasetInput`.
message DatasetInputPlaceholders {
  // Name of placeholder corresponding to filename(s) of SSTable(s) to read data
  // from.
  string filename = 1;

  // Name of placeholder corresponding to key_prefix initializing the
  // SSTableDataset. Note the value fed should be unique user id, not a prefix.
  string key_prefix = 2;

  // Name of placeholder corresponding to number of rounds the local training
  // should be run for.
  string num_epochs = 3;

  // Name of placeholder corresponding to batch size.
  string batch_size = 4;
}

// Specifies an example selection procedure.
message ExampleSelector {
  // Selection criteria following a contract agreed upon between client and
  // model designers.
  google.protobuf.Any criteria = 1;

  // A URI identifying the example collection to read from. Format should adhere
  // to "${COLLECTION}://${APP_NAME}${COLLECTION_NAME}".
The URI segments 810*14675a02SAndroid Build Coastguard Worker // should adhere to the following rules: 811*14675a02SAndroid Build Coastguard Worker // - The scheme ${COLLECTION} should be one of: 812*14675a02SAndroid Build Coastguard Worker // - "app" for app-hosted example 813*14675a02SAndroid Build Coastguard Worker // - "simulation" for collections not connected to an app (e.g., if used 814*14675a02SAndroid Build Coastguard Worker // purely for simulation) 815*14675a02SAndroid Build Coastguard Worker // - The authority ${APP_NAME} identifies the owner of the example 816*14675a02SAndroid Build Coastguard Worker // collection and should be either the app's package name, or be left empty 817*14675a02SAndroid Build Coastguard Worker // (which means "the current app package name"). 818*14675a02SAndroid Build Coastguard Worker // - The path ${COLLECTION_NAME} can be any valid URI path. NB It starts with 819*14675a02SAndroid Build Coastguard Worker // a forward slash ("/"). 820*14675a02SAndroid Build Coastguard Worker // - The query and fragment are currently not used, but they may become used 821*14675a02SAndroid Build Coastguard Worker // for something in the future. To keep open that possibility they must 822*14675a02SAndroid Build Coastguard Worker // currently be left empty. 823*14675a02SAndroid Build Coastguard Worker // 824*14675a02SAndroid Build Coastguard Worker // Example: "app://com.google.some.app/someCollection/name" 825*14675a02SAndroid Build Coastguard Worker // identifies the collection "/someCollection/name" owned and hosted by the 826*14675a02SAndroid Build Coastguard Worker // app with package name "com.google.some.app". 
827*14675a02SAndroid Build Coastguard Worker // 828*14675a02SAndroid Build Coastguard Worker // Example: "app:/someCollection/name" or "app:///someCollection/name" 829*14675a02SAndroid Build Coastguard Worker // both identify the collection "/someCollection/name" owned and hosted by the 830*14675a02SAndroid Build Coastguard Worker // app associated with the training job in which this URI appears. 831*14675a02SAndroid Build Coastguard Worker // 832*14675a02SAndroid Build Coastguard Worker // The path will not be interpreted by the runtime, and will be passed to the 833*14675a02SAndroid Build Coastguard Worker // example collection implementation for interpretation. Thus, in the case of 834*14675a02SAndroid Build Coastguard Worker // app-hosted example stores, the path segment's interpretation is a contract 835*14675a02SAndroid Build Coastguard Worker // between the app's example store developers, and the app's model designers. 836*14675a02SAndroid Build Coastguard Worker // 837*14675a02SAndroid Build Coastguard Worker // If an `app://` URI is set, then the `TrainerOptions` collection name must 838*14675a02SAndroid Build Coastguard Worker // not be set. 839*14675a02SAndroid Build Coastguard Worker string collection_uri = 2; 840*14675a02SAndroid Build Coastguard Worker 841*14675a02SAndroid Build Coastguard Worker // Resumption token following a contract agreed upon between client and 842*14675a02SAndroid Build Coastguard Worker // model designers. 843*14675a02SAndroid Build Coastguard Worker google.protobuf.Any resumption_token = 3; 844*14675a02SAndroid Build Coastguard Worker} 845*14675a02SAndroid Build Coastguard Worker 846*14675a02SAndroid Build Coastguard Worker// Selector for slices to fetch as part of a `federated_select` operation. 847*14675a02SAndroid Build Coastguard Workermessage SlicesSelector { 848*14675a02SAndroid Build Coastguard Worker // The string ID under which the slices are served. 
849*14675a02SAndroid Build Coastguard Worker // 850*14675a02SAndroid Build Coastguard Worker // This value must have been returned by a previous call to the `serve_slices` 851*14675a02SAndroid Build Coastguard Worker // op run during the `write_client_init` operation. 852*14675a02SAndroid Build Coastguard Worker string served_at_id = 1; 853*14675a02SAndroid Build Coastguard Worker 854*14675a02SAndroid Build Coastguard Worker // The indices of slices to fetch. 855*14675a02SAndroid Build Coastguard Worker repeated int32 keys = 2; 856*14675a02SAndroid Build Coastguard Worker} 857*14675a02SAndroid Build Coastguard Worker 858*14675a02SAndroid Build Coastguard Worker// Represents slice data to be served as part of a `federated_select` operation. 859*14675a02SAndroid Build Coastguard Worker// This is used for testing. 860*14675a02SAndroid Build Coastguard Workermessage SlicesTestDataset { 861*14675a02SAndroid Build Coastguard Worker // The test data to use. The keys map to the `SlicesSelector.served_at_id` 862*14675a02SAndroid Build Coastguard Worker // field. E.g. test slice data for a slice with `served_at_id`="foo" and 863*14675a02SAndroid Build Coastguard Worker // `keys`=2 would be stored in `dataset["foo"].slice_data[2]`. 864*14675a02SAndroid Build Coastguard Worker map<string, SlicesTestData> dataset = 1; 865*14675a02SAndroid Build Coastguard Worker} 866*14675a02SAndroid Build Coastguard Workermessage SlicesTestData { 867*14675a02SAndroid Build Coastguard Worker // The test slice data to serve. Each entry's index corresponds to the slice 868*14675a02SAndroid Build Coastguard Worker // key it is the test data for. 
869*14675a02SAndroid Build Coastguard Worker repeated bytes slice_data = 2; 870*14675a02SAndroid Build Coastguard Worker} 871*14675a02SAndroid Build Coastguard Worker 872*14675a02SAndroid Build Coastguard Worker// Server Phase V2 873*14675a02SAndroid Build Coastguard Worker// =============== 874*14675a02SAndroid Build Coastguard Worker 875*14675a02SAndroid Build Coastguard Worker// Represents a server phase with three distinct components: pre-broadcast, 876*14675a02SAndroid Build Coastguard Worker// aggregation, and post-aggregation. 877*14675a02SAndroid Build Coastguard Worker// 878*14675a02SAndroid Build Coastguard Worker// The pre-broadcast and post-aggregation components are described with 879*14675a02SAndroid Build Coastguard Worker// the tensorflow_spec_prepare and tensorflow_spec_result TensorflowSpec 880*14675a02SAndroid Build Coastguard Worker// messages, respectively. These messages in combination with the server 881*14675a02SAndroid Build Coastguard Worker// IORouter messages specify how to set up a single TF sess.run call for each 882*14675a02SAndroid Build Coastguard Worker// component. 883*14675a02SAndroid Build Coastguard Worker// 884*14675a02SAndroid Build Coastguard Worker// The pre-broadcast logic is obtained by transforming the server_prepare TFF 885*14675a02SAndroid Build Coastguard Worker// computation in the DistributeAggregateForm. It takes the server state as 886*14675a02SAndroid Build Coastguard Worker// input, and it generates the checkpoint to broadcast to the clients and 887*14675a02SAndroid Build Coastguard Worker// potentially an intermediate server state. The intermediate server state may 888*14675a02SAndroid Build Coastguard Worker// be used by the aggregation and post-aggregation logic. 
889*14675a02SAndroid Build Coastguard Worker// 890*14675a02SAndroid Build Coastguard Worker// The aggregation logic represents the aggregation of client results at the 891*14675a02SAndroid Build Coastguard Worker// server and is described using a list of ServerAggregationConfig messages. 892*14675a02SAndroid Build Coastguard Worker// Each ServerAggregationConfig message describes a single aggregation operation 893*14675a02SAndroid Build Coastguard Worker// on a set of input/output tensors. The input tensors may represent parts of 894*14675a02SAndroid Build Coastguard Worker// either the client results or the intermediate server state. These messages 895*14675a02SAndroid Build Coastguard Worker// are obtained by transforming the client_to_server_aggregation TFF computation 896*14675a02SAndroid Build Coastguard Worker// in the DistributeAggregateForm. 897*14675a02SAndroid Build Coastguard Worker// 898*14675a02SAndroid Build Coastguard Worker// The post-aggregation logic is obtained by transforming the server_result TFF 899*14675a02SAndroid Build Coastguard Worker// computation in the DistributeAggregateForm. It takes the intermediate server 900*14675a02SAndroid Build Coastguard Worker// state and the aggregated client results as input, and it generates the new 901*14675a02SAndroid Build Coastguard Worker// server state and potentially other server-side output. 902*14675a02SAndroid Build Coastguard Worker// 903*14675a02SAndroid Build Coastguard Worker// Note that while a ServerPhaseV2 message can be generated for all types of 904*14675a02SAndroid Build Coastguard Worker// intrinsics, it is currently only compatible with the ClientPhase message if 905*14675a02SAndroid Build Coastguard Worker// the aggregations being used are exclusively federated_sum (not SecAgg). 
If 906*14675a02SAndroid Build Coastguard Worker// this compatibility requirement is satisfied, it is also valid to run the 907*14675a02SAndroid Build Coastguard Worker// aggregation portion of this ServerPhaseV2 message alongside the pre- and 908*14675a02SAndroid Build Coastguard Worker// post-aggregation logic from the original ServerPhase message. Ultimately, 909*14675a02SAndroid Build Coastguard Worker// we expect the full ServerPhaseV2 message to be run and the ServerPhase 910*14675a02SAndroid Build Coastguard Worker// message to be deprecated. 911*14675a02SAndroid Build Coastguard Workermessage ServerPhaseV2 { 912*14675a02SAndroid Build Coastguard Worker // A short CamelCase name for the ServerPhaseV2. 913*14675a02SAndroid Build Coastguard Worker string name = 1; 914*14675a02SAndroid Build Coastguard Worker 915*14675a02SAndroid Build Coastguard Worker // A functional interface for the TensorFlow logic the server should perform 916*14675a02SAndroid Build Coastguard Worker // prior to the server-to-client broadcast. This should be used with the 917*14675a02SAndroid Build Coastguard Worker // TensorFlow graph defined in server_graph_prepare_bytes. 918*14675a02SAndroid Build Coastguard Worker TensorflowSpec tensorflow_spec_prepare = 3; 919*14675a02SAndroid Build Coastguard Worker 920*14675a02SAndroid Build Coastguard Worker // The specification of inputs needed by the server_prepare TF logic. 921*14675a02SAndroid Build Coastguard Worker oneof server_prepare_io_router { 922*14675a02SAndroid Build Coastguard Worker ServerPrepareIORouter prepare_router = 4; 923*14675a02SAndroid Build Coastguard Worker } 924*14675a02SAndroid Build Coastguard Worker 925*14675a02SAndroid Build Coastguard Worker // A list of client-to-server aggregations to perform. 
926*14675a02SAndroid Build Coastguard Worker repeated ServerAggregationConfig aggregations = 2; 927*14675a02SAndroid Build Coastguard Worker 928*14675a02SAndroid Build Coastguard Worker // A functional interface for the TensorFlow logic the server should perform 929*14675a02SAndroid Build Coastguard Worker // post-aggregation. This should be used with the TensorFlow graph defined 930*14675a02SAndroid Build Coastguard Worker // in server_graph_result_bytes. 931*14675a02SAndroid Build Coastguard Worker TensorflowSpec tensorflow_spec_result = 5; 932*14675a02SAndroid Build Coastguard Worker 933*14675a02SAndroid Build Coastguard Worker // The specification of inputs and outputs needed by the server_result TF 934*14675a02SAndroid Build Coastguard Worker // logic. 935*14675a02SAndroid Build Coastguard Worker oneof server_result_io_router { 936*14675a02SAndroid Build Coastguard Worker ServerResultIORouter result_router = 6; 937*14675a02SAndroid Build Coastguard Worker } 938*14675a02SAndroid Build Coastguard Worker} 939*14675a02SAndroid Build Coastguard Worker 940*14675a02SAndroid Build Coastguard Worker// Routing for server_prepare graph 941*14675a02SAndroid Build Coastguard Workermessage ServerPrepareIORouter { 942*14675a02SAndroid Build Coastguard Worker // The name of the scalar string tensor in the server_prepare TF graph that 943*14675a02SAndroid Build Coastguard Worker // is fed the filepath to the initial server state checkpoint. The 944*14675a02SAndroid Build Coastguard Worker // server_prepare logic reads from this filepath. 945*14675a02SAndroid Build Coastguard Worker string prepare_server_state_input_filepath_tensor_name = 1; 946*14675a02SAndroid Build Coastguard Worker 947*14675a02SAndroid Build Coastguard Worker // The name of the scalar string tensor in the server_prepare TF graph that 948*14675a02SAndroid Build Coastguard Worker // is fed the filepath where the client checkpoint should be stored. 
The 949*14675a02SAndroid Build Coastguard Worker // server_prepare logic writes to this filepath. 950*14675a02SAndroid Build Coastguard Worker string prepare_output_filepath_tensor_name = 2; 951*14675a02SAndroid Build Coastguard Worker 952*14675a02SAndroid Build Coastguard Worker // The name of the scalar string tensor in the server_prepare TF graph that 953*14675a02SAndroid Build Coastguard Worker // is fed the filepath where the intermediate state checkpoint should be 954*14675a02SAndroid Build Coastguard Worker // stored. The server_prepare logic writes to this filepath. The intermediate 955*14675a02SAndroid Build Coastguard Worker // state checkpoint will be consumed by both the logic used to set parameters 956*14675a02SAndroid Build Coastguard Worker // for aggregation and the post-aggregation logic. 957*14675a02SAndroid Build Coastguard Worker string prepare_intermediate_state_output_filepath_tensor_name = 3; 958*14675a02SAndroid Build Coastguard Worker} 959*14675a02SAndroid Build Coastguard Worker 960*14675a02SAndroid Build Coastguard Worker// Routing for server_result graph 961*14675a02SAndroid Build Coastguard Workermessage ServerResultIORouter { 962*14675a02SAndroid Build Coastguard Worker // The name of the scalar string tensor in the server_result TF graph that is 963*14675a02SAndroid Build Coastguard Worker // fed the filepath to the intermediate state checkpoint. The server_result 964*14675a02SAndroid Build Coastguard Worker // logic reads from this filepath. 965*14675a02SAndroid Build Coastguard Worker string result_intermediate_state_input_filepath_tensor_name = 1; 966*14675a02SAndroid Build Coastguard Worker 967*14675a02SAndroid Build Coastguard Worker // The name of the scalar string tensor in the server_result TF graph that is 968*14675a02SAndroid Build Coastguard Worker // fed the filepath to the aggregated client result checkpoint. The 969*14675a02SAndroid Build Coastguard Worker // server_result logic reads from this filepath. 
970*14675a02SAndroid Build Coastguard Worker string result_aggregate_result_input_filepath_tensor_name = 2; 971*14675a02SAndroid Build Coastguard Worker 972*14675a02SAndroid Build Coastguard Worker // The name of the scalar string tensor in the server_result TF graph that is 973*14675a02SAndroid Build Coastguard Worker // fed the filepath where the updated server state should be stored. The 974*14675a02SAndroid Build Coastguard Worker // server_result logic writes to this filepath. 975*14675a02SAndroid Build Coastguard Worker string result_server_state_output_filepath_tensor_name = 3; 976*14675a02SAndroid Build Coastguard Worker} 977*14675a02SAndroid Build Coastguard Worker 978*14675a02SAndroid Build Coastguard Worker// Represents a single aggregation operation, combining one or more input 979*14675a02SAndroid Build Coastguard Worker// tensors from a collection of clients into one or more output tensors on the 980*14675a02SAndroid Build Coastguard Worker// server. 981*14675a02SAndroid Build Coastguard Workermessage ServerAggregationConfig { 982*14675a02SAndroid Build Coastguard Worker // The uri of the aggregation intrinsic (e.g. 'federated_sum'). 983*14675a02SAndroid Build Coastguard Worker string intrinsic_uri = 1; 984*14675a02SAndroid Build Coastguard Worker 985*14675a02SAndroid Build Coastguard Worker // Describes an argument to the aggregation operation. 986*14675a02SAndroid Build Coastguard Worker message IntrinsicArg { 987*14675a02SAndroid Build Coastguard Worker oneof arg { 988*14675a02SAndroid Build Coastguard Worker // Refers to a tensor within the checkpoint provided by each client. 989*14675a02SAndroid Build Coastguard Worker tensorflow.TensorSpecProto input_tensor = 2; 990*14675a02SAndroid Build Coastguard Worker 991*14675a02SAndroid Build Coastguard Worker // Refers to a tensor within the intermediate server state checkpoint. 
992*14675a02SAndroid Build Coastguard Worker tensorflow.TensorSpecProto state_tensor = 3; 993*14675a02SAndroid Build Coastguard Worker } 994*14675a02SAndroid Build Coastguard Worker } 995*14675a02SAndroid Build Coastguard Worker 996*14675a02SAndroid Build Coastguard Worker // List of arguments for the aggregation operation. The arguments can be 997*14675a02SAndroid Build Coastguard Worker // dependent on client data (in which case they must be retrieved from 998*14675a02SAndroid Build Coastguard Worker // clients) or they can be independent of client data (in which case they 999*14675a02SAndroid Build Coastguard Worker // can be configured server-side). For now we assume all client-independent 1000*14675a02SAndroid Build Coastguard Worker // arguments are constants. The arguments must be in the order expected by 1001*14675a02SAndroid Build Coastguard Worker // the server. 1002*14675a02SAndroid Build Coastguard Worker repeated IntrinsicArg intrinsic_args = 4; 1003*14675a02SAndroid Build Coastguard Worker 1004*14675a02SAndroid Build Coastguard Worker // List of server-side outputs produced by the aggregation operation. 1005*14675a02SAndroid Build Coastguard Worker repeated tensorflow.TensorSpecProto output_tensors = 5; 1006*14675a02SAndroid Build Coastguard Worker 1007*14675a02SAndroid Build Coastguard Worker // List of inner aggregation intrinsics. This can be used to delegate parts 1008*14675a02SAndroid Build Coastguard Worker // of the aggregation logic (e.g. a groupby intrinsic may want to delegate 1009*14675a02SAndroid Build Coastguard Worker // a sum operation to a sum intrinsic). 
1010*14675a02SAndroid Build Coastguard Worker repeated ServerAggregationConfig inner_aggregations = 6; 1011*14675a02SAndroid Build Coastguard Worker} 1012*14675a02SAndroid Build Coastguard Worker 1013*14675a02SAndroid Build Coastguard Worker// Server Phase 1014*14675a02SAndroid Build Coastguard Worker// ============ 1015*14675a02SAndroid Build Coastguard Worker 1016*14675a02SAndroid Build Coastguard Worker// Represents a server phase which implements TF-based aggregation of multiple 1017*14675a02SAndroid Build Coastguard Worker// client updates. 1018*14675a02SAndroid Build Coastguard Worker// 1019*14675a02SAndroid Build Coastguard Worker// There are two different modes of aggregation that are described 1020*14675a02SAndroid Build Coastguard Worker// by the values in this message. The first is aggregation that is 1021*14675a02SAndroid Build Coastguard Worker// coming from coordinated sets of clients. This includes aggregation 1022*14675a02SAndroid Build Coastguard Worker// done via checkpoints from clients or aggregation done over a set 1023*14675a02SAndroid Build Coastguard Worker// of clients by a process like secure aggregation. The results of 1024*14675a02SAndroid Build Coastguard Worker// this first aggregation are saved to intermediate aggregation 1025*14675a02SAndroid Build Coastguard Worker// checkpoints. The second aggregation then comes from taking 1026*14675a02SAndroid Build Coastguard Worker// these intermediate checkpoints and aggregating over them. 1027*14675a02SAndroid Build Coastguard Worker// 1028*14675a02SAndroid Build Coastguard Worker// These two different modes of aggregation are done on different 1029*14675a02SAndroid Build Coastguard Worker// servers, the first in the 'L1' servers and the second in the 1030*14675a02SAndroid Build Coastguard Worker// 'L2' servers, so we use this nomenclature to describe these 1031*14675a02SAndroid Build Coastguard Worker// phases below. 
1032*14675a02SAndroid Build Coastguard Worker// 1033*14675a02SAndroid Build Coastguard Worker// The ServerPhase message is currently in the process of being replaced by the 1034*14675a02SAndroid Build Coastguard Worker// ServerPhaseV2 message as we switch the plan building pipeline to use 1035*14675a02SAndroid Build Coastguard Worker// DistributeAggregateForm instead of MapReduceForm. During the migration 1036*14675a02SAndroid Build Coastguard Worker// process, we may generate both messages and use components from either 1037*14675a02SAndroid Build Coastguard Worker// message during execution. 1038*14675a02SAndroid Build Coastguard Worker// 1039*14675a02SAndroid Build Coastguard Workermessage ServerPhase { 1040*14675a02SAndroid Build Coastguard Worker // A short CamelCase name for the ServerPhase. 1041*14675a02SAndroid Build Coastguard Worker string name = 8; 1042*14675a02SAndroid Build Coastguard Worker 1043*14675a02SAndroid Build Coastguard Worker // =========================================================================== 1044*14675a02SAndroid Build Coastguard Worker // L1 "Intermediate" Aggregation. 1045*14675a02SAndroid Build Coastguard Worker // 1046*14675a02SAndroid Build Coastguard Worker // This is the initial aggregation that creates partial aggregates from client 1047*14675a02SAndroid Build Coastguard Worker // results. L1 Aggregation may be run on many different instances. 1048*14675a02SAndroid Build Coastguard Worker // 1049*14675a02SAndroid Build Coastguard Worker // Pre-condition: 1050*14675a02SAndroid Build Coastguard Worker // The execution environment has loaded the graph from `server_graph_bytes`. 1051*14675a02SAndroid Build Coastguard Worker 1052*14675a02SAndroid Build Coastguard Worker // 1. Initialize the phase. 1053*14675a02SAndroid Build Coastguard Worker // 1054*14675a02SAndroid Build Coastguard Worker // Operation to run before the first aggregation happens. 
1055*14675a02SAndroid Build Coastguard Worker // For instance, clears the accumulators so that a new aggregation can begin. 1056*14675a02SAndroid Build Coastguard Worker string phase_init_op = 1; 1057*14675a02SAndroid Build Coastguard Worker 1058*14675a02SAndroid Build Coastguard Worker // 2. For each client in set of clients: 1059*14675a02SAndroid Build Coastguard Worker // a. Restore variables from the client checkpoint. 1060*14675a02SAndroid Build Coastguard Worker // 1061*14675a02SAndroid Build Coastguard Worker // Loads a checkpoint from a single client written via 1062*14675a02SAndroid Build Coastguard Worker // `FederatedComputeIORouter.output_filepath_tensor_name`. This is done once 1063*14675a02SAndroid Build Coastguard Worker // for every client checkpoint in a round. 1064*14675a02SAndroid Build Coastguard Worker CheckpointOp read_update = 3; 1065*14675a02SAndroid Build Coastguard Worker // b. Aggregate the data coming from the client checkpoint. 1066*14675a02SAndroid Build Coastguard Worker // 1067*14675a02SAndroid Build Coastguard Worker // An operation that aggregates the data from read_update. 1068*14675a02SAndroid Build Coastguard Worker // Generally this will add to accumulators and it may leverage internal data 1069*14675a02SAndroid Build Coastguard Worker // inside the graph to adjust the weights of the Tensors. 1070*14675a02SAndroid Build Coastguard Worker // 1071*14675a02SAndroid Build Coastguard Worker // Executed once for each `read_update`, to (for example) update accumulator 1072*14675a02SAndroid Build Coastguard Worker // variables using the values loaded during `read_update`. 1073*14675a02SAndroid Build Coastguard Worker string aggregate_into_accumulators_op = 4; 1074*14675a02SAndroid Build Coastguard Worker 1075*14675a02SAndroid Build Coastguard Worker // 3. After all clients have been aggregated, possibly restore 1076*14675a02SAndroid Build Coastguard Worker // variables that have been aggregated via a separate process. 
1077*14675a02SAndroid Build Coastguard Worker // 1078*14675a02SAndroid Build Coastguard Worker // Optionally restores variables where aggregation is done across 1079*14675a02SAndroid Build Coastguard Worker // an entire round of client data updates. In contrast to `read_update`, 1080*14675a02SAndroid Build Coastguard Worker // which restores once per client, this occurs after all clients 1081*14675a02SAndroid Build Coastguard Worker // in a round have been processed. This allows, for example, side 1082*14675a02SAndroid Build Coastguard Worker // channels where aggregation is done by a separate process (such 1083*14675a02SAndroid Build Coastguard Worker // as in secure aggregation), in which the side channel aggregated 1084*14675a02SAndroid Build Coastguard Worker // tensor is passed to the `before_restore_op` which ensures the 1085*14675a02SAndroid Build Coastguard Worker // variables are restored properly. The `after_restore_op` will then 1086*14675a02SAndroid Build Coastguard Worker // be responsible for performing the accumulation. 1087*14675a02SAndroid Build Coastguard Worker // 1088*14675a02SAndroid Build Coastguard Worker // Note that in current use this should not have a SaverDef, but 1089*14675a02SAndroid Build Coastguard Worker // should only be used for side channels. 1090*14675a02SAndroid Build Coastguard Worker CheckpointOp read_aggregated_update = 10; 1091*14675a02SAndroid Build Coastguard Worker 1092*14675a02SAndroid Build Coastguard Worker // 4. Write the aggregated variables to an intermediate checkpoint. 1093*14675a02SAndroid Build Coastguard Worker // 1094*14675a02SAndroid Build Coastguard Worker // We require that `aggregate_into_accumulators_op` is associative and 1095*14675a02SAndroid Build Coastguard Worker // commutative, so that the aggregates can be computed across 1096*14675a02SAndroid Build Coastguard Worker // multiple TensorFlow sessions. 
1097*14675a02SAndroid Build Coastguard Worker // As an example, say we are computing the sum of 5 client updates: 1098*14675a02SAndroid Build Coastguard Worker // A = X1 + X2 + X3 + X4 + X5 1099*14675a02SAndroid Build Coastguard Worker // We can always do this in one session by calling `read_update` and 1100*14675a02SAndroid Build Coastguard Worker // `aggregate_into_accumulators_op` once for each client checkpoint. 1101*14675a02SAndroid Build Coastguard Worker // 1102*14675a02SAndroid Build Coastguard Worker // Alternatively, we could compute: 1103*14675a02SAndroid Build Coastguard Worker // A1 = X1 + X2 in one TensorFlow session, and 1104*14675a02SAndroid Build Coastguard Worker // A2 = X3 + X4 + X5 in a different session. 1105*14675a02SAndroid Build Coastguard Worker // Each of these sessions can then write their accumulator state 1106*14675a02SAndroid Build Coastguard Worker // with the `write_intermediate_update` CheckpointOp, and a third 1107*14675a02SAndroid Build Coastguard Worker // session can then call `read_intermediate_update` and 1108*14675a02SAndroid Build Coastguard Worker // `aggregate_into_accumulators_op` on each of these checkpoints to compute: 1109*14675a02SAndroid Build Coastguard Worker // A = A1 + A2 = (X1 + X2) + (X3 + X4 + X5). 1110*14675a02SAndroid Build Coastguard Worker CheckpointOp write_intermediate_update = 7; 1111*14675a02SAndroid Build Coastguard Worker // End L1 "Intermediate" Aggregation. 1112*14675a02SAndroid Build Coastguard Worker // =========================================================================== 1113*14675a02SAndroid Build Coastguard Worker 1114*14675a02SAndroid Build Coastguard Worker // =========================================================================== 1115*14675a02SAndroid Build Coastguard Worker // L2 Aggregation and Coordinator. 
1116*14675a02SAndroid Build Coastguard Worker // 1117*14675a02SAndroid Build Coastguard Worker // This aggregates intermediate checkpoints from L1 Aggregation and performs 1118*14675a02SAndroid Build Coastguard Worker // the finalizing of the update. Unlike L1 there will only be one instance 1119*14675a02SAndroid Build Coastguard Worker // that does this aggregation. 1120*14675a02SAndroid Build Coastguard Worker 1121*14675a02SAndroid Build Coastguard Worker // Pre-condition: 1122*14675a02SAndroid Build Coastguard Worker // The execution environment has loaded the graph from `server_graph_bytes` 1123*14675a02SAndroid Build Coastguard Worker // and restored the global model using `server_savepoint` from the parent 1124*14675a02SAndroid Build Coastguard Worker // `Plan` message. 1125*14675a02SAndroid Build Coastguard Worker 1126*14675a02SAndroid Build Coastguard Worker // 1. Initialize the phase. 1127*14675a02SAndroid Build Coastguard Worker // 1128*14675a02SAndroid Build Coastguard Worker // This currently re-uses the `phase_init_op` from L1 aggregation above. 1129*14675a02SAndroid Build Coastguard Worker 1130*14675a02SAndroid Build Coastguard Worker // 2. Write a checkpoint that can be sent to the client. 1131*14675a02SAndroid Build Coastguard Worker // 1132*14675a02SAndroid Build Coastguard Worker // Generates a checkpoint to be sent to the client, to be read by 1133*14675a02SAndroid Build Coastguard Worker // `FederatedComputeIORouter.input_filepath_tensor_name`. 1134*14675a02SAndroid Build Coastguard Worker 1135*14675a02SAndroid Build Coastguard Worker CheckpointOp write_client_init = 2; 1136*14675a02SAndroid Build Coastguard Worker 1137*14675a02SAndroid Build Coastguard Worker // 3. For each intermediate checkpoint: 1138*14675a02SAndroid Build Coastguard Worker // a. Restore variables from the intermediate checkpoint. 
1139*14675a02SAndroid Build Coastguard Worker // 1140*14675a02SAndroid Build Coastguard Worker // The corresponding read checkpoint op to the write_intermediate_update. 1141*14675a02SAndroid Build Coastguard Worker // This is used instead of read_update for intermediate checkpoints because 1142*14675a02SAndroid Build Coastguard Worker // the format of these updates may be different than those used in updates 1143*14675a02SAndroid Build Coastguard Worker // from clients (which may, for example, be compressed). 1144*14675a02SAndroid Build Coastguard Worker CheckpointOp read_intermediate_update = 9; 1145*14675a02SAndroid Build Coastguard Worker // b. Aggregate the data coming from the intermediate checkpoint. 1146*14675a02SAndroid Build Coastguard Worker // 1147*14675a02SAndroid Build Coastguard Worker // An operation that aggregates the data from `read_intermediate_update`. 1148*14675a02SAndroid Build Coastguard Worker // Generally this will add to accumulators and it may leverage internal data 1149*14675a02SAndroid Build Coastguard Worker // inside the graph to adjust the weights of the Tensors. 1150*14675a02SAndroid Build Coastguard Worker string intermediate_aggregate_into_accumulators_op = 11; 1151*14675a02SAndroid Build Coastguard Worker 1152*14675a02SAndroid Build Coastguard Worker // 4. Write the aggregated intermediate variables to a checkpoint. 1153*14675a02SAndroid Build Coastguard Worker // 1154*14675a02SAndroid Build Coastguard Worker // This is used for downstream, cross-round aggregation of metrics. 1155*14675a02SAndroid Build Coastguard Worker // These variables will be read back into a session with 1156*14675a02SAndroid Build Coastguard Worker // read_intermediate_update. 1157*14675a02SAndroid Build Coastguard Worker // 1158*14675a02SAndroid Build Coastguard Worker // Tasks which do not use FL metrics may unset the CheckpointOp.saver_def 1159*14675a02SAndroid Build Coastguard Worker // to disable writing accumulator checkpoints. 
1160*14675a02SAndroid Build Coastguard Worker CheckpointOp write_accumulators = 12; 1161*14675a02SAndroid Build Coastguard Worker 1162*14675a02SAndroid Build Coastguard Worker // 5. Finalize the round. 1163*14675a02SAndroid Build Coastguard Worker // 1164*14675a02SAndroid Build Coastguard Worker // This can include: 1165*14675a02SAndroid Build Coastguard Worker // - Applying the update aggregated from the intermediate checkpoints to the 1166*14675a02SAndroid Build Coastguard Worker // global model and other updates to cross-round state variables. 1167*14675a02SAndroid Build Coastguard Worker // - Computing final round metric values (e.g. the `report` of a 1168*14675a02SAndroid Build Coastguard Worker // `tff.federated_aggregate`). 1169*14675a02SAndroid Build Coastguard Worker string apply_aggregrated_updates_op = 5; 1170*14675a02SAndroid Build Coastguard Worker 1171*14675a02SAndroid Build Coastguard Worker // 6. Fetch the server aggregated metrics. 1172*14675a02SAndroid Build Coastguard Worker // 1173*14675a02SAndroid Build Coastguard Worker // A list of names of metric variables to fetch from the TensorFlow session. 1174*14675a02SAndroid Build Coastguard Worker repeated Metric metrics = 6; 1175*14675a02SAndroid Build Coastguard Worker 1176*14675a02SAndroid Build Coastguard Worker // 7. Serialize the updated server state (e.g. the coefficients of the global 1177*14675a02SAndroid Build Coastguard Worker // model in FL) using `server_savepoint` in the parent `Plan` message. 1178*14675a02SAndroid Build Coastguard Worker 1179*14675a02SAndroid Build Coastguard Worker // End L2 Aggregation. 1180*14675a02SAndroid Build Coastguard Worker // =========================================================================== 1181*14675a02SAndroid Build Coastguard Worker} 1182*14675a02SAndroid Build Coastguard Worker 1183*14675a02SAndroid Build Coastguard Worker// Represents the server phase in an eligibility computation.
//
// The output of this phase is a checkpoint that is sent to clients, where it
// serves as an input to the clients' task eligibility computations.
// This phase *does not include any aggregation.*
message ServerEligibilityComputationPhase {
  // Short CamelCase name identifying this ServerEligibilityComputationPhase.
  string name = 1;

  // Names of the TensorFlow nodes that are run in order to produce output.
  repeated string target_node_names = 2;

  // Specifies the inputs and outputs of the TensorFlow graph.
  oneof server_eligibility_io_router {
    TEContextServerEligibilityIORouter task_eligibility = 3 [lazy = true];
  }
}

// Inputs and outputs of a `ServerEligibilityComputationPhase` that takes a
// single `TaskEligibilityContext` as its input.
message TEContextServerEligibilityIORouter {
  // Name of the scalar string tensor that must be fed a serialized
  // `TaskEligibilityContext`.
  string context_proto_input_tensor_name = 1;

  // Name of the scalar string tensor that must be fed the path to which the
  // server graph should write the checkpoint file that is sent to the client.
  string output_filepath_tensor_name = 2;
}

// Plan
// =====

// The overall plan for performing federated optimization or personalization,
// in the form handed over to the production system. It is typically split
// into individual pieces for the different production parts, e.g. the server
// side and the client side.
// NEXT_TAG: 15
message Plan {
  reserved 1, 3, 5;

  // The server_*_graph_bytes fields below are expected to actually hold a
  // tensorflow.GraphDef. The TensorFlow graphs are kept in serialized form
  // for two reasons:
  // 1) We may use execution engines other than TensorFlow.
  // 2) We wish to avoid the cost of deserializing and re-serializing large
  //    graphs in the Federated Learning service.

  // During the migration from ServerPhase to ServerPhaseV2,
  // server_graph_bytes, server_graph_prepare_bytes, and
  // server_graph_result_bytes may all be set. A MapReduceForm-based server
  // implementation uses only server_graph_bytes. A
  // DistributeAggregateForm-based server implementation uses only
  // server_graph_prepare_bytes and server_graph_result_bytes.

  // Optional. The TensorFlow graph used for all server processing described
  // by ServerPhase. Not set for personalization.
  google.protobuf.Any server_graph_bytes = 7;

  // Optional. The TensorFlow graph used for all server processing described
  // by ServerPhaseV2.tensorflow_spec_prepare.
  google.protobuf.Any server_graph_prepare_bytes = 13;

  // Optional. The TensorFlow graph used for all server processing described
  // by ServerPhaseV2.tensorflow_spec_result.
  google.protobuf.Any server_graph_result_bytes = 14;

  // Savepoint used to sync the server checkpoint with a persistent storage
  // system. The storage initially holds a seeded checkpoint, which can
  // subsequently be read and updated by this savepoint.
  // Optional -- not present in eligibility computation plans (those with a
  // ServerEligibilityComputationPhase). Used in conjunction with ServerPhase
  // only.
  CheckpointOp server_savepoint = 2;

  // Required. The TensorFlow graph describing the TensorFlow logic a client
  // should perform. It should be consistent with the `TensorflowSpec` field
  // in the `client_phase`. The actual type is expected to be
  // tensorflow.GraphDef. The graph is kept in serialized form for two
  // reasons:
  // 1) We may use execution engines other than TensorFlow.
  // 2) We wish to avoid the cost of deserializing and re-serializing large
  //    graphs in the Federated Learning service.
  google.protobuf.Any client_graph_bytes = 8;

  // Optional. The FlatBuffer used for TFLite training. It carries the same
  // model information as client_graph_bytes, but in a different format.
  bytes client_tflite_graph_bytes = 12;

  // A client phase and a server phase that are processed in sync. The server
  // execution defines how the results of a client phase are aggregated, and
  // how the checkpoints for clients are generated.
  message Phase {
    // Required. The client phase.
    ClientPhase client_phase = 1;

    // Optional. Server phase for TF-based aggregation; not provided for
    // personalization or eligibility tasks.
    ServerPhase server_phase = 2;

    // Optional. Server phase for native aggregation; only provided for tasks
    // that have enabled the corresponding flag.
    ServerPhaseV2 server_phase_v2 = 4;

    // Optional. Only provided for eligibility tasks.
    ServerEligibilityComputationPhase server_eligibility_phase = 3;
  }

  // A pair of client and server computations to run.
  repeated Phase phase = 4;

  // Metrics that persist across different phases, for example counters that
  // track how much work of different kinds has been done.
  repeated Metric metrics = 6;

  // Describes how metrics in both the client and server phases should be
  // aggregated.
  repeated OutputMetric output_metrics = 10;

  // Version of the plan:
  // version == 0 - Old plan without version field, containing b/65131070
  // version >= 1 - plan supports multi-shard aggregation mode (L1/L2)
  int32 version = 9;

  // A TensorFlow ConfigProto packed in an Any.
  //
  // The client implementation may choose a default set of configuration
  // parameters to provide to TensorFlow when any of the following holds:
  //   - this field is unset;
  //   - the Any proto is set but empty;
  //   - the Any proto is populated with an empty ConfigProto (i.e. its
  //     `type_url` field is set, but the `value` field is empty).
  //
  // In all other cases this field must contain a valid packed ConfigProto
  // (invalid values will result in an error at execution time), and the
  // client will then not provide any other configuration parameters by
  // default.
  google.protobuf.Any tensorflow_config_proto = 11;
}

// Represents the client part of a federated optimization plan.
// This is also used to describe a client-only plan for standalone on-device
// training, known as personalization.
// NEXT_TAG: 6
message ClientOnlyPlan {
  reserved 3;

  // The graph to use for training, in binary form.
  bytes graph = 1;

  // Optional. The flatbuffer used for TFLite training.
  // The client code decides whether "graph" or "tflite_graph" is used for
  // training, which allows for a flag-controlled a/b rollout.
  bytes tflite_graph = 5;

  // The client phase to execute.
  ClientPhase phase = 2;

  // A TensorFlow ConfigProto.
  google.protobuf.Any tensorflow_config_proto = 4;
}

// Represents the cross-round aggregation portion for user defined
// measurements. It is used by tools that process / analyze accumulator
// checkpoints after a round of computation, to achieve aggregation beyond a
// single round.
message CrossRoundAggregationExecution {
  // Operation to run before reading an accumulator checkpoint.
  string init_op = 1;

  // Reads an accumulator checkpoint.
  CheckpointOp read_aggregated_update = 2;

  // Operation that merges the loaded checkpoint into the accumulator.
  string merge_op = 3;

  // Reads and writes the final aggregated accumulator vars.
  CheckpointOp read_write_final_accumulators = 6;

  // Metadata mapping the TensorFlow `name` attribute of the `tf.Variable` to
  // the user defined name of the signal.
  repeated Measurement measurements = 4;

  // The `tf.Graph` used for aggregating accumulator checkpoints when loading
  // metrics.
  google.protobuf.Any cross_round_aggregation_graph_bytes = 5;
}

// A single measurement entry, as listed in
// `CrossRoundAggregationExecution.measurements`: the TensorFlow op that reads
// its value, together with its human-readable name and serialized type.
message Measurement {
  reserved 3;

  // Name of a TensorFlow op to run to read/fetch the value of this
  // measurement.
  string read_op_name = 1;

  // A human-readable name for the measurement. By convention names are
  // usually camel case, e.g., 'Loss', 'AbsLoss', or 'Accuracy'.
  string name = 2;

  // A serialized `tff.Type` for the measurement.
  bytes tff_type = 4;
}