// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package google.internal.federated.plan;

import "google/protobuf/any.proto";
import "tensorflow/core/framework/tensor.proto";
import "tensorflow/core/framework/types.proto";
import "tensorflow/core/protobuf/saver.proto";
import "tensorflow/core/protobuf/struct.proto";

option java_package = "com.google.internal.federated.plan";
option java_multiple_files = true;
option java_outer_classname = "PlanProto";

// Primitives
// ===========

// Represents an operation to save or restore from a checkpoint. Some
// instances of this message may only be used either for restore or for
// save, others for both directions. This is documented together with
// their usage.
//
// This op has four essential uses:
//   1. read and apply a checkpoint.
//   2. write a checkpoint.
//   3. read and apply from an aggregated side channel.
//   4. write to a side channel (grouped with writing a checkpoint).
// We should consider splitting this into four separate messages.
message CheckpointOp {
  // An optional standard saver def. If not provided, only the
  // op(s) below will be executed. This must be a version 1 SaverDef.
  tensorflow.SaverDef saver_def = 1;

  // An optional operation to run before the saver_def is executed for
  // restore.
  string before_restore_op = 2;

  // An optional operation to run after the saver_def has been
  // executed for restore. If side_channel_tensors are provided, then
  // they should be provided in a feed_dict to this op.
  string after_restore_op = 3;

  // An optional operation to run before the saver_def will be
  // executed for save.
  string before_save_op = 4;

  // An optional operation to run after the saver_def has been
  // executed for save. If there are side_channel_tensors, this op
  // should be run after the side_channel_tensors have been fetched.
  string after_save_op = 5;

  // In addition to being saved and restored from a checkpoint, one can
  // also save and restore via a side channel. The keys in this map are
  // the names of the tensors transmitted by the side channel. These (key)
  // tensors should be read off just before saving with the SaverDef and used
  // by the code that handles the side channel. Any variables provided this
  // way should NOT be saved in the SaverDef.
  //
  // For restoring, the variables that are provided by the side channel
  // are restored differently than those from a checkpoint. For those from
  // the side channel, these should be restored by calling the
  // before_restore_op with a feed dict whose keys are the restore_names in
  // the SideChannel and whose values are the values to be restored.
  map<string, SideChannel> side_channel_tensors = 6;

  // An optional name of a tensor into which a unique token for the current
  // session should be written.
  //
  // This session identifier allows TensorFlow ops such as `ServeSlices` or
  // `ExternalDataset` to refer to callbacks and other session-global objects
  // registered before running the session.
  string session_token_tensor_name = 7;
}
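
// For illustration only (all names below are invented), a CheckpointOp used
// for use case 3 above might restore a secure-aggregation side channel roughly
// like this:
//
//   before_restore_op: "restore_sum_op"
//   side_channel_tensors {
//     key: "sum_update"
//     value {
//       secure_aggregand { dtype: DT_FLOAT }
//       restore_name: "sum_update_placeholder"
//     }
//   }
//
// Here the runtime would feed the aggregated "sum_update" value under the key
// "sum_update_placeholder" when running "restore_sum_op".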

message SideChannel {
  // A side channel whose variables are processed via SecureAggregation.
  // This side channel implements aggregation via sum over a set of
  // clients, so the restored tensor will be a sum of multiple clients'
  // inputs into the side channel. Hence this will restore during the
  // read_aggregated_update restore, not the per-client read_update restore.
  message SecureAggregand {
    message Dimension {
      int64 size = 1;
    }

    // Dimensions of the aggregand. This is used by the secure aggregation
    // protocol in its early rounds, not as redundant info which could be
    // obtained by reading the dimensions of the tensor itself.
    repeated Dimension dimension = 3;

    // The data type anticipated by the server-side graph.
    tensorflow.DataType dtype = 4;

    // SecureAggregation will compute sum modulo this modulus.
    message FixedModulus {
      uint64 modulus = 1;
    }

    // SecureAggregation will, for each shard, compute the sum modulo m with m
    // at least (1 + shard_size * (base_modulus - 1)), then aggregate the
    // shard results with non-modular addition. Here, shard_size is the number
    // of clients in the shard.
    //
    // Note that the modulus for each shard will be greater than the largest
    // possible (non-modular) sum of the inputs to that shard. That is,
    // assuming each client has input in the range [0, base_modulus), the
    // result will be identical to non-modular addition (i.e. federated_sum).
    //
    // While any m >= (1 + shard_size * (base_modulus - 1)) would work, the
    // current implementation takes
    // m = 2**ceil(log_2(1 + shard_size * (base_modulus - 1))), which is the
    // smallest possible value of m that is also a power of 2. This choice is
    // made because (a) it uses the same number of bits per vector entry as any
    // valid smaller m, using the current on-the-wire encoding scheme, and (b)
    // it enables the underlying mask-generation PRNG to run in its most
    // computationally efficient mode, which can be up to 2x faster.
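    //
    // For illustration (hypothetical values): with base_modulus = 256 and a
    // shard of 100 clients, the minimum required modulus is
    //   1 + 100 * (256 - 1) = 25501,
    // so the implementation would pick
    //   m = 2**ceil(log_2(25501)) = 2**15 = 32768.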
    message ModulusTimesShardSize {
      uint64 base_modulus = 1;
    }

    oneof modulus_scheme {
      // Bitwidth of the aggregand.
      //
      // This is the bitwidth of an input value (i.e. the bitwidth that
      // quantization should target). The Secure Aggregation bitwidth (i.e.,
      // the bitwidth of the *sum* of the input values) will be a function of
      // this bitwidth and the number of participating clients, as negotiated
      // with the server when the protocol is initiated.
      //
      // Deprecated; prefer fixed_modulus instead.
      int32 quantized_input_bitwidth = 2 [deprecated = true];

      FixedModulus fixed_modulus = 5;
      ModulusTimesShardSize modulus_times_shard_size = 6;
    }

    reserved 1;
  }

  // What type of side channel is used.
  oneof type {
    SecureAggregand secure_aggregand = 1;
  }

  // When restoring, the name of the tensor to restore to. This is the name
  // (key) supplied in the feed_dict in the before_restore_op in order to
  // restore the tensor provided by the side channel (which will be the
  // value in the feed_dict).
  string restore_name = 2;
}

// Container for a metric used by the internal toolkit.
message Metric {
  // Name of an Op to run to read the value.
  string variable_name = 1;

  // A human-readable name for the statistic. Metric names are usually
  // camel case by convention, e.g., 'Loss', 'AbsLoss', or 'Accuracy'.
  // Must be 7-bit ASCII and under 122 characters.
  string stat_name = 2;

  // The human-readable name of another metric by which this metric should be
  // normalized, if any. If empty, this Metric should be aggregated with simple
  // summation. If not empty, the Metric is aggregated according to
  //   weighted_metric_sum = sum_i (metric_i * weight_i)
  //   weight_sum = sum_i weight_i
  //   average_metric_value = weighted_metric_sum / weight_sum
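  //
  // For example (hypothetical values): if one client reports Loss = 0.5 with
  // weight 10 and another reports Loss = 1.0 with weight 30, then
  //   weighted_metric_sum  = 0.5 * 10 + 1.0 * 30 = 35
  //   weight_sum           = 10 + 30 = 40
  //   average_metric_value = 35 / 40 = 0.875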
  string weight_name = 3;
}

// Controls the format of the output metrics users receive, i.e., represents
// instructions for how metrics are to be output to users, controlling the end
// format of the metrics users receive.
message OutputMetric {
  // Metric name.
  string name = 1;

  oneof value_source {
    // A metric representing one stat with aggregation type sum.
    SumOptions sum = 2;

    // A metric representing a ratio between metrics with aggregation
    // type average.
    AverageOptions average = 3;

    // A metric that is not aggregated by the MetricReportAggregator or
    // metrics_loader. This includes metrics like 'num_server_updates' that are
    // aggregated in TensorFlow.
    NoneOptions none = 4;

    // A metric representing one stat with aggregation type only sample.
    // Samples at most 101 clients' values.
    OnlySampleOptions only_sample = 5;
  }

  // Iff True, the metric will be plotted in the default view of the
  // task level Colab automatically.
  oneof visualization_info {
    bool auto_plot = 6 [deprecated = true];
    VisualizationSpec plot_spec = 7;
  }
}

message VisualizationSpec {
  // Different allowable plot types.
  enum VisualizationType {
    NONE = 0;
    DEFAULT_PLOT_FOR_TASK_TYPE = 1;
    LINE_PLOT = 2;
    LINE_PLOT_WITH_PERCENTILES = 3;
    HISTOGRAM = 4;
  }

  // Defines the plot type to provide downstream.
  VisualizationType plot_type = 1;

  // The x-axis to provide for the given metric. Must be the name of a
  // metric or counter. Recommended x_axis options are source_round, round,
  // or time.
  string x_axis = 2;

  // Iff True, the metric will be displayed on a population-level dashboard.
  bool plot_on_population_dashboard = 3;
}

// A metric representing one stat with aggregation type sum.
message SumOptions {
  // Name for corresponding Metric stat_name field.
  string stat_name = 1;

  // Iff True, a cumulative sum over rounds will be provided in addition to a
  // sum per round for the value metric.
  bool include_cumulative_sum = 2;

  // Iff True, sample at most 101 clients' values.
  // Used to calculate quantiles in the downstream visualization pipeline.
  bool include_client_samples = 3;
}

// A metric representing a ratio between metrics with aggregation type average.
// Represents: numerator stat / denominator stat.
message AverageOptions {
  // Numerator stat name pointing to corresponding Metric stat_name.
  string numerator_stat_name = 1;

  // Denominator stat name pointing to corresponding Metric stat_name.
  string denominator_stat_name = 2;

  // Name for corresponding Metric stat_name that is the ratio of the
  // numerator stat / denominator stat.
  string average_stat_name = 3;

  // Iff True, sample at most 101 clients' values.
  // Used to calculate quantiles in the downstream visualization pipeline.
  bool include_client_samples = 4;
}
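
// For illustration only (the stat names are invented and would have to match
// Metric.stat_name values defined elsewhere in the plan), an OutputMetric
// reporting a weighted average could look like:
//
//   name: "AverageLoss"
//   average {
//     numerator_stat_name: "WeightedLossSum"
//     denominator_stat_name: "WeightSum"
//     average_stat_name: "AverageLoss"
//   }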

// A metric representing one stat with aggregation type none.
message NoneOptions {
  // Name for corresponding Metric stat_name field.
  string stat_name = 1;
}

// A metric representing one stat with aggregation type only sample.
message OnlySampleOptions {
  // Name for corresponding Metric stat_name field.
  string stat_name = 1;
}

// Represents a data set. This is used for testing.
message Dataset {
  // Represents the data set for one client.
  message ClientDataset {
    // A string identifying the client.
    string client_id = 1;

    // A list of serialized tf.Example protos.
    repeated bytes example = 2;

    // Represents a dataset whose examples are selected by an ExampleSelector.
    message SelectedExample {
      ExampleSelector selector = 1;
      repeated bytes example = 2;
    }

    // A list of (selector, dataset) pairs. Used in testing some *TFF-based
    // tasks* that require multiple datasets as client input, e.g., a TFF-based
    // personalization eval task requires each client to provide at least two
    // datasets: one for train, and the other for test.
    repeated SelectedExample selected_example = 3;
  }

  // A list of client data.
  repeated ClientDataset client_data = 1;
}

// Represents predicates over metrics - i.e., expectations. This is used in
// training/eval tests to encode metric names and values expected to be
// reported by a client execution.
message MetricTestPredicates {
  // The value must lie in [lower_bound; upper_bound]. Can also be used for
  // approximate matching (lower == value - epsilon; upper = value + epsilon).
  message Interval {
    double lower_bound = 1;
    double upper_bound = 2;
  }

  // The value must be a real value as long as the value of the weight_name
  // metric is non-zero. If the weight metric is zero, then it is acceptable
  // for the value to be non-real.
  message RealIfNonzeroWeight {
    string weight_name = 1;
  }

  message MetricCriterion {
    // Name of the metric.
    string name = 1;

    // FL training round this metric is expected to appear in.
    int32 training_round_index = 2;

    // If none of the following is set, no matching is performed; but the
    // metric is still expected to be present (with whatever value).
    oneof Criterion {
      // The reported metric must be < lt.
      float lt = 3;
      // The reported metric must be > gt.
      float gt = 4;
      // The reported metric must be <= le.
      float le = 5;
      // The reported metric must be >= ge.
      float ge = 6;
      // The reported metric must be == eq.
      float eq = 7;
      // The reported metric must lie in the interval.
      Interval interval = 8;
      // The reported metric is not NaN or +/- infinity.
      bool real = 9;
      // The reported metric is real (i.e., not NaN or +/- infinity) if the
      // value of an associated weight is not 0.
      RealIfNonzeroWeight real_if_nonzero_weight = 10;
    }
  }

  repeated MetricCriterion metric_criterion = 1;

  reserved 2;
}
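
// For illustration only (hypothetical values), a test expecting the 'Loss'
// metric of training round 1 to lie within [0.0, 2.0] could use:
//
//   metric_criterion {
//     name: "Loss"
//     training_round_index: 1
//     interval { lower_bound: 0.0 upper_bound: 2.0 }
//   }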

// Client Phase
// ============

// A `TensorflowSpec` that is executed on the client in a single `tf.Session`.
// In federated optimization, this will correspond to one `ServerPhase`.
message ClientPhase {
  // A short CamelCase name for the ClientPhase.
  string name = 2;

  // Minimum number of clients in aggregation.
  // In secure aggregation mode this is used to configure the protocol instance
  // in a way that the server can't learn aggregated values with a number of
  // participants lower than this number.
  // Without secure aggregation, the server still respects this parameter,
  // ensuring that aggregated values never leave server RAM unless they include
  // data from (at least) the specified number of participants.
  int32 minimum_number_of_participants = 3;

  // If populated, `io_router` must be specified.
  oneof spec {
    // A functional interface for the TensorFlow logic the client should
    // perform.
    TensorflowSpec tensorflow_spec = 4 [lazy = true];
    // Spec for client plans that issue example queries and send the query
    // results directly to an aggregator with no or little additional
    // processing.
    ExampleQuerySpec example_query_spec = 9 [lazy = true];
  }

  // The specification of the inputs coming either from customer apps
  // (Local Compute) or the federated protocol (Federated Compute).
  oneof io_router {
    FederatedComputeIORouter federated_compute = 5 [lazy = true];
    LocalComputeIORouter local_compute = 6 [lazy = true];
    FederatedComputeEligibilityIORouter federated_compute_eligibility = 7
        [lazy = true];
    FederatedExampleQueryIORouter federated_example_query = 8 [lazy = true];
  }

  reserved 1;
}

// The TensorflowSpec message describes a single call into TensorFlow: the
// expected input tensors that must be fed when making that call, the output
// tensors to be fetched, and any operations that have no output but must be
// run. The TensorFlow session will then use the input tensors to do some
// computation, generally reading from one or more datasets, and provide some
// outputs.
//
// Conceptually, client or server code uses this proto along with an IORouter
// to build maps of names to input tensors, vectors of output tensor names,
// and vectors of target nodes:
//
//   CreateTensorflowArguments(
//       TensorflowSpec& spec,
//       IORouter& io_router,
//       const vector<pair<string, Tensor>>* input_tensors,
//       const vector<string>* output_tensor_names,
//       const vector<string>* target_node_names);
//
// where `input_tensors`, `output_tensor_names` and `target_node_names`
// correspond to the arguments of the TensorFlow C++ API for
// `tensorflow::Session::Run()`, and the client executes only a single
// invocation.
//
// Note: the execution engine never sees any concepts related to the federated
// protocol, e.g. input checkpoints or aggregation protocols. This is a
// "tensors in, tensors out" interface. New aggregation methods can be added
// without having to modify the execution engine / TensorflowSpec message;
// instead they should modify the IORouter messages.
//
// Note: both `input_tensor_specs` and `output_tensor_specs` are full
// `tensorflow.TensorSpecProto` messages, though TensorFlow technically
// only requires the names to feed the values into the session. The additional
// dtype/shape information must always be included in case the runtime
// executing this TensorflowSpec wants to perform additional, optional static
// assertions. The runtimes however are free to ignore the dtypes/shapes and
// only rely on the names if so desired.
//
// Assertions:
//   - all names in `input_tensor_specs`, `output_tensor_specs`, and
//     `target_node_names` must appear in the serialized GraphDef where
//     the TF execution will be invoked.
//   - `output_tensor_specs` or `target_node_names` must be non-empty,
//     otherwise there is nothing to execute in the graph.
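//
// As a rough sketch (the tensor and node names are hypothetical), a
// TensorflowSpec whose input_tensor_specs name "input_path:0", whose
// output_tensor_specs name "loss:0", and whose target_node_names contain
// "save_update" would be executed roughly as:
//
//   session->Run({{"input_path:0", path_tensor}},  // inputs, from the IORouter
//                {"loss:0"},                       // output tensor names
//                {"save_update"},                  // target node names
//                &outputs);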
message TensorflowSpec {
  // The name of a tensor into which a unique token for the current session
  // should be written. The corresponding tensor is a scalar string tensor and
  // is separate from `input_tensors` as there is only one.
  //
  // A session token allows TensorFlow ops such as `ServeSlices` or
  // `ExternalDataset` to refer to callbacks and other session-global objects
  // registered before running the session. In the `ExternalDataset` case, a
  // single dataset_token is valid for multiple `tf.data.Dataset` objects as
  // the token can be thought of as a handle to a dataset factory.
  string dataset_token_tensor_name = 1;

  // TensorSpecs of inputs which will be passed to TF.
  //
  // Corresponds to the `feed_dict` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, excluding the dataset_token listed above.
  //
  // Assertions:
  //   - All the tensor names designated as inputs in the corresponding
  //     IORouter must be listed (otherwise the IORouter input work is unused).
  //   - All placeholders in the TF graph must be listed here, with the
  //     exception of the dataset_token which is explicitly set above
  //     (otherwise TensorFlow will fail to execute).
  repeated tensorflow.TensorSpecProto input_tensor_specs = 2;

  // TensorSpecs that should be fetched from TF after execution.
  //
  // Corresponds to the `fetches` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, and the `output_tensor_names` in TensorFlow's
  // C++ API.
  //
  // Assertions:
  //   - The set of tensor names here must strictly match the tensor names
  //     designated as outputs in the corresponding IORouter (if any exist).
  repeated tensorflow.TensorSpecProto output_tensor_specs = 3;

  // Node names in the graph that should be executed, but whose output is not
  // returned.
  //
  // Corresponds to the `fetches` parameter of `tf.Session.run()` in
  // TensorFlow's Python API, and the `target_node_names` in TensorFlow's C++
  // API.
  //
  // This is intended for use with operations that do not produce tensors, but
  // nonetheless are required to run (e.g. serializing checkpoints).
  repeated string target_node_names = 4;

  // Map of tensor names to constant inputs.
  // Note: tensors specified via this message should not be included in
  // input_tensor_specs.
  map<string, tensorflow.TensorProto> constant_inputs = 5;

  // The fields below are added by the OnDevicePersonalization module.
  // Specifies an example selection procedure.
  ExampleSelector example_selector = 999;
}

// The ExampleQuerySpec message describes client execution that issues example
// queries and sends the query results directly to an aggregator with no or
// little additional processing.
// This message describes one or more example store queries that perform the
// client-side analytics computation in C++. The corresponding output vectors
// will be converted into the expected federated protocol output format.
// This must be used in conjunction with the `FederatedExampleQueryIORouter`.
message ExampleQuerySpec {
  message OutputVectorSpec {
    // The output vector name.
    string vector_name = 1;

    // Supported data types for the vector of information.
    enum DataType {
      UNSPECIFIED = 0;
      INT32 = 1;
      INT64 = 2;
      BOOL = 3;
      FLOAT = 4;
      DOUBLE = 5;
      BYTES = 6;
      STRING = 7;
    }

    // The data type for each entry in the vector.
    DataType data_type = 2;
  }

  message ExampleQuery {
    // The `ExampleSelector` to issue the query with.
    ExampleSelector example_selector = 1;

    // Indicates that the query returns vector data and must return a single
    // ExampleQueryResult containing a VectorData entry matching each
    // OutputVectorSpec.vector_name.
    //
    // If the query instead returns no result, then it will be treated as if
    // an error was returned. In that case, or if the query explicitly returns
    // an error, the client will abort its session.
    //
    // The keys in the map are the names the vectors should be aggregated
    // under, and must match the keys in
    // FederatedExampleQueryIORouter.aggregations.
    map<string, OutputVectorSpec> output_vector_specs = 2;
  }

  // The queries to run.
  repeated ExampleQuery example_queries = 1;
}
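
// For illustration only (all names are invented), a single-query spec whose
// output vector is aggregated under the key "counts" (which must also appear
// in FederatedExampleQueryIORouter.aggregations) could look like:
//
//   example_queries {
//     example_selector { collection_uri: "app:/someCollection" }
//     output_vector_specs {
//       key: "counts"
//       value { vector_name: "event_counts" data_type: INT64 }
//     }
//   }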

// The input and output router for Federated Compute plans.
//
// This proto is the glue between the federated protocol and the TensorFlow
// execution engine. This message describes how to prepare data coming from the
// incoming `CheckinResponse` (defined in fcp/protos/federated_api.proto) for
// the `TensorflowSpec`, and what to do with outputs from `TensorflowSpec`
// (e.g. how to aggregate them back on the server).
//
// TODO(team): we could replace `input_checkpoint_file_tensor_name` with an
// `input_tensors` field, which would then be a tensor that contains the input
// TensorProtos directly, skipping disk I/O, rather than referring to a
// checkpoint file path.
message FederatedComputeIORouter {
  // ===========================================================================
  // Inputs
  // ===========================================================================
  // The name of the scalar string tensor that is fed the file path to the
  // initial checkpoint (e.g. as provided via AcceptanceInfo.init_checkpoint).
  //
  // The federated protocol code would copy the `CheckinResponse`'s initial
  // checkpoint to a temporary file and then pass that file path through this
  // tensor.
  //
  // Ops may be added to the client graph that take this tensor as input and
  // read the path.
  //
  // This field is optional. It may be omitted if the client graph does not use
  // an initial checkpoint.
  string input_filepath_tensor_name = 1;

  // The name of the scalar string tensor that is fed the file path to which
  // client work should serialize the bytes to send back to the server.
  //
  // The federated protocol code generates a temporary file and passes the file
  // path through this tensor.
  //
  // Ops may be added to the client graph that use this tensor as an argument
  // to write files (e.g. writing checkpoints to disk).
  //
  // This field is optional. It must be omitted if the client graph does not
  // generate any output files (e.g. when all output tensors of
  // `TensorflowSpec` use Secure Aggregation). If this field is not set, then
  // the `ReportRequest` message in the federated protocol will not have the
  // `Report.update_checkpoint` field set. This absence of a value here can be
  // used to validate that the plan only uses Secure Aggregation.
  //
  // Conversely, if this field is set but executing the associated
  // TensorflowSpec does not write to the path, that is an indication of an
  // internal framework error. The runtime should notify the caller that the
  // computation was set up incorrectly.
  string output_filepath_tensor_name = 2;

  // ===========================================================================
  // Outputs
  // ===========================================================================
  // Describes which output tensors should be aggregated using an aggregation
  // protocol, and the configuration for those protocols.
  //
  // Assertions:
  //   - All keys must exist in the associated `TensorflowSpec` as
  //     `output_tensor_specs.name` values.
  map<string, AggregationConfig> aggregations = 3;
}
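
// For illustration only (the tensor name is invented), an aggregations entry
// requesting Secure Aggregation for an output tensor named "sum_update" (the
// key must match an output_tensor_specs.name in the associated TensorflowSpec)
// could look like:
//
//   aggregations {
//     key: "sum_update"
//     value { secure_aggregation {} }
//   }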

// The input and output router for client plans that do not use TensorFlow.
//
// This proto is the glue between the federated protocol and the example query
// execution engine, describing how the query results should ultimately be
// aggregated.
message FederatedExampleQueryIORouter {
  // Describes how each output vector should be aggregated using an aggregation
  // protocol, and the configuration for those protocols.
  // Keys must match the keys in ExampleQuerySpec.output_vector_specs.
  // Note that currently only the TFV1CheckpointAggregation config is
  // supported.
  map<string, AggregationConfig> aggregations = 1;
}

// The specification for how to aggregate the associated tensor across clients
// on the server.
message AggregationConfig {
  oneof protocol_config {
    // Indicates that the given output tensor should be processed using Secure
    // Aggregation, using the specified config options.
    SecureAggregationConfig secure_aggregation = 2;

    // Note: in the future we could add a `SimpleAggregationConfig` to add
    // support for simple aggregation without writing to an intermediate
    // checkpoint file first.

    // Indicates that the given output tensor or vector (e.g. as produced by an
    // ExampleQuerySpec) should be placed in an output TF v1 checkpoint.
    //
    // Currently only ExampleQuerySpec output vectors are supported by this
    // aggregation type (i.e. it cannot be used with TensorflowSpec output
    // tensors). Each vector will be stored in the checkpoint as a 1-D Tensor
    // of its corresponding data type.
    TFV1CheckpointAggregation tf_v1_checkpoint_aggregation = 3;
  }
}

// Parameters for the SecAgg protocol (go/secagg).
//
// Currently only the server uses the SecAgg parameters, so we only use this
// message to signify usage of SecAgg.
message SecureAggregationConfig {}

// Parameters for the TFV1 Checkpoint Aggregation protocol.
//
// Currently only ExampleQuerySpec output vectors are supported by this
// aggregation type (i.e. it cannot be used with TensorflowSpec output
// tensors). Each vector will be stored in the checkpoint as a 1-D Tensor of
// its corresponding data type.
message TFV1CheckpointAggregation {}

// The input and output router for eligibility-computing plans. These plans
// compute which other plans a client is eligible to run, and are returned to
// clients via an `EligibilityEvalCheckinResponse` (defined in
// fcp/protos/federated_api.proto).
message FederatedComputeEligibilityIORouter {
  // The name of the scalar string tensor that is fed the file path to the
  // initial checkpoint (e.g. as provided via
  // `EligibilityEvalPayload.init_checkpoint`).
  //
  // For more detail see
  // `FederatedComputeIORouter.input_filepath_tensor_name`, which has the same
  // semantics.
  //
  // This field is optional. It may be omitted if the client graph does not use
  // an initial checkpoint.
  //
  // This tensor name must exist in the associated
  // `TensorflowSpec.input_tensor_specs` list.
  string input_filepath_tensor_name = 1;

  // Name of the output tensor (a string scalar) containing the serialized
  // `google.internal.federatedml.v2.TaskEligibilityInfo` proto output. The
  // client code will parse this proto and place it in the
  // `task_eligibility_info` field of the subsequent `CheckinRequest`.
  //
  // This tensor name must exist in the associated
  // `TensorflowSpec.output_tensor_specs` list.
  string task_eligibility_info_tensor_name = 2;
}

// The input and output router for Local Compute plans.
//
// This proto is the glue between the customer's app and the TensorFlow
// execution engine. This message describes how to prepare data coming from the
// customer app (e.g. the input directory the app set up), and the temporary,
// scratch output directory that the customer app will be notified of upon
// completion of `TensorflowSpec`.
message LocalComputeIORouter {
  // ===========================================================================
  // Inputs
  // ===========================================================================
  // The name of the placeholder tensor representing the input resource
  // path(s). It can be a single input directory or file path (in this case
  // `input_dir_tensor_name` is populated) or multiple input resources
  // represented as a map from names to input directories or file paths (in
  // this case `multiple_input_resources` is populated).
  //
  // In the multiple input resources case, the placeholder tensors are
  // represented as a map: the keys are the input resource names defined by the
  // users when constructing the `LocalComputation` Python object, and the
  // values are the corresponding placeholder tensor names created by the local
  // computation plan builder.
  //
  // Apps will have the ability to create contracts between their Android code
  // and `LocalComputation` toolkit code to place files inside the input
  // resource paths with known names (Android code) and create graphs with ops
  // to read from these paths (file names can be specified in toolkit code).
  oneof input_resource {
    string input_dir_tensor_name = 1;
    // Directly using the `map` field is not allowed in `oneof`, so we have to
    // wrap it in a new message.
    MultipleInputResources multiple_input_resources = 3;
  }

  // Scalar string tensor name that will contain the output directory path.
  //
  // The provided directory should be considered temporary scratch that will be
  // deleted, not persisted. It is the responsibility of the calling app to
  // move the desired files to a permanent location once the client returns
  // this directory back to the calling app.
  string output_dir_tensor_name = 2;

  // ===========================================================================
  // Outputs
  // ===========================================================================
  // NOTE: LocalCompute has no outputs other than what the client graph writes
  // to the `output_dir` specified above.
}

// Describes the multiple input resources in `LocalComputeIORouter`.
message MultipleInputResources {
  // The keys are the input resource names (defined by the users when
  // constructing the `LocalComputation` Python object), and the values are the
  // corresponding placeholder tensor names created by the local computation
  // plan builder.
  map<string, string> input_resource_tensor_name_map = 1;
}
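
// For illustration only (all names are invented), a LocalComputeIORouter with
// two user-defined input resources could use:
//
//   multiple_input_resources {
//     input_resource_tensor_name_map {
//       key: "train_data"
//       value: "train_data_path:0"
//     }
//     input_resource_tensor_name_map {
//       key: "eval_data"
//       value: "eval_data_path:0"
//     }
//   }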

// Describes a queue to which input is fed.
message AsyncInputFeed {
  // The op for enqueuing an example input.
  string enqueue_op = 1;

  // The input placeholders for the enqueue op.
  repeated string enqueue_params = 2;

  // The op for closing the input queue.
  string close_op = 3;

  // Whether the work that should be fed asynchronously is the data itself
  // or a description of where that data lives.
  bool feed_values_are_data = 4;
}

message DatasetInput {
  // Initializer of the iterator corresponding to the tf.data.Dataset object
  // which handles the input data. Stores the name of an op in the graph.
  string initializer = 1;

  // Placeholders necessary to initialize the dataset.
  DatasetInputPlaceholders placeholders = 2;

  // Batch size to be used in tf.data.Dataset.
  int32 batch_size = 3;
}

message DatasetInputPlaceholders {
  // Name of the placeholder corresponding to the filename(s) of the SSTable(s)
  // to read data from.
  string filename = 1;

  // Name of the placeholder corresponding to the key_prefix initializing the
  // SSTableDataset. Note the value fed should be a unique user id, not a
  // prefix.
  string key_prefix = 2;

  // Name of the placeholder corresponding to the number of rounds the local
  // training should be run for.
  string num_epochs = 3;

  // Name of the placeholder corresponding to the batch size.
  string batch_size = 4;
}

// Specifies an example selection procedure.
message ExampleSelector {
  // Selection criteria following a contract agreed upon between client and
  // model designers.
  google.protobuf.Any criteria = 1;

  // A URI identifying the example collection to read from. The format should
  // adhere to "${COLLECTION}://${APP_NAME}${COLLECTION_NAME}". The URI
  // segments should adhere to the following rules:
  //   - The scheme ${COLLECTION} should be one of:
  //     - "app" for app-hosted examples
  //     - "simulation" for collections not connected to an app (e.g., if used
  //       purely for simulation)
  //   - The authority ${APP_NAME} identifies the owner of the example
  //     collection and should be either the app's package name, or be left
  //     empty (which means "the current app package name").
  //   - The path ${COLLECTION_NAME} can be any valid URI path. NB: it starts
  //     with a forward slash ("/").
  //   - The query and fragment are currently not used, but they may become
  //     used for something in the future. To keep open that possibility they
  //     must currently be left empty.
  //
  // Example: "app://com.google.some.app/someCollection/name"
  // identifies the collection "/someCollection/name" owned and hosted by the
  // app with package name "com.google.some.app".
  //
  // Example: "app:/someCollection/name" or "app:///someCollection/name"
  // both identify the collection "/someCollection/name" owned and hosted by
  // the app associated with the training job in which this URI appears.
  //
  // The path will not be interpreted by the runtime, and will be passed to the
  // example collection implementation for interpretation. Thus, in the case of
  // app-hosted example stores, the path segment's interpretation is a contract
  // between the app's example store developers and the app's model designers.
  //
  // If an `app://` URI is set, then the `TrainerOptions` collection name must
  // not be set.
  string collection_uri = 2;

  // Resumption token following a contract agreed upon between client and
  // model designers.
  google.protobuf.Any resumption_token = 3;
}

// Selector for slices to fetch as part of a `federated_select` operation.
message SlicesSelector {
  // The string ID under which the slices are served.
  //
  // This value must have been returned by a previous call to the
  // `serve_slices` op run during the `write_client_init` operation.
  string served_at_id = 1;

  // The indices of slices to fetch.
  repeated int32 keys = 2;
}

// Represents slice data to be served as part of a `federated_select`
// operation. This is used for testing.
message SlicesTestDataset {
  // The test data to use. The keys map to the `SlicesSelector.served_at_id`
  // field. E.g. test slice data for a slice with `served_at_id`="foo" and
  // `keys`=2 would be stored in `dataset["foo"].slice_data[2]`.
  map<string, SlicesTestData> dataset = 1;
}

message SlicesTestData {
  // The test slice data to serve. Each entry's index corresponds to the slice
  // key it is the test data for.
  repeated bytes slice_data = 2;
}

// Server Phase V2
// ===============

// Represents a server phase with three distinct components: pre-broadcast,
// aggregation, and post-aggregation.
//
// The pre-broadcast and post-aggregation components are described with
// the tensorflow_spec_prepare and tensorflow_spec_result TensorflowSpec
// messages, respectively. These messages in combination with the server
// IORouter messages specify how to set up a single TF sess.run call for each
// component.
//
// The pre-broadcast logic is obtained by transforming the server_prepare TFF
// computation in the DistributeAggregateForm. It takes the server state as
// input, and it generates the checkpoint to broadcast to the clients and
// potentially an intermediate server state. The intermediate server state may
// be used by the aggregation and post-aggregation logic.
//
// The aggregation logic represents the aggregation of client results at the
// server and is described using a list of ServerAggregationConfig messages.
// Each ServerAggregationConfig message describes a single aggregation
// operation on a set of input/output tensors. The input tensors may represent
// parts of either the client results or the intermediate server state. These
// messages are obtained by transforming the client_to_server_aggregation TFF
// computation in the DistributeAggregateForm.
//
// The post-aggregation logic is obtained by transforming the server_result TFF
// computation in the DistributeAggregateForm. It takes the intermediate server
// state and the aggregated client results as input, and it generates the new
// server state and potentially other server-side output.
//
// Note that while a ServerPhaseV2 message can be generated for all types of
// intrinsics, it is currently only compatible with the ClientPhase message if
// the aggregations being used are exclusively federated_sum (not SecAgg). If
// this compatibility requirement is satisfied, it is also valid to run the
// aggregation portion of this ServerPhaseV2 message alongside the pre- and
// post-aggregation logic from the original ServerPhase message. Ultimately,
// we expect the full ServerPhaseV2 message to be run and the ServerPhase
// message to be deprecated.
message ServerPhaseV2 {
  // A short CamelCase name for the ServerPhaseV2.
  string name = 1;

  // A functional interface for the TensorFlow logic the server should perform
  // prior to the server-to-client broadcast. This should be used with the
  // TensorFlow graph defined in server_graph_prepare_bytes.
  TensorflowSpec tensorflow_spec_prepare = 3;

  // The specification of inputs needed by the server_prepare TF logic.
  oneof server_prepare_io_router {
    ServerPrepareIORouter prepare_router = 4;
  }

  // A list of client-to-server aggregations to perform.
  repeated ServerAggregationConfig aggregations = 2;

  // A functional interface for the TensorFlow logic the server should perform
  // post-aggregation. This should be used with the TensorFlow graph defined
  // in server_graph_result_bytes.
  TensorflowSpec tensorflow_spec_result = 5;

  // The specification of inputs and outputs needed by the server_result TF
  // logic.
  oneof server_result_io_router {
    ServerResultIORouter result_router = 6;
  }
}

// Routing for server_prepare graph
message ServerPrepareIORouter {
  // The name of the scalar string tensor in the server_prepare TF graph that
  // is fed the filepath to the initial server state checkpoint. The
  // server_prepare logic reads from this filepath.
  string prepare_server_state_input_filepath_tensor_name = 1;

  // The name of the scalar string tensor in the server_prepare TF graph that
  // is fed the filepath where the client checkpoint should be stored. The
  // server_prepare logic writes to this filepath.
  string prepare_output_filepath_tensor_name = 2;

  // The name of the scalar string tensor in the server_prepare TF graph that
  // is fed the filepath where the intermediate state checkpoint should be
  // stored. The server_prepare logic writes to this filepath. The intermediate
  // state checkpoint will be consumed by both the logic used to set parameters
  // for aggregation and the post-aggregation logic.
  string prepare_intermediate_state_output_filepath_tensor_name = 3;
}

// Routing for server_result graph
message ServerResultIORouter {
  // The name of the scalar string tensor in the server_result TF graph that is
  // fed the filepath to the intermediate state checkpoint. The server_result
  // logic reads from this filepath.
  string result_intermediate_state_input_filepath_tensor_name = 1;

  // The name of the scalar string tensor in the server_result TF graph that is
  // fed the filepath to the aggregated client result checkpoint. The
  // server_result logic reads from this filepath.
  string result_aggregate_result_input_filepath_tensor_name = 2;

  // The name of the scalar string tensor in the server_result TF graph that is
  // fed the filepath where the updated server state should be stored. The
  // server_result logic writes to this filepath.
  string result_server_state_output_filepath_tensor_name = 3;
}

// Represents a single aggregation operation, combining one or more input
// tensors from a collection of clients into one or more output tensors on the
// server.
message ServerAggregationConfig {
  // The uri of the aggregation intrinsic (e.g. 'federated_sum').
  string intrinsic_uri = 1;

  // Describes an argument to the aggregation operation.
  message IntrinsicArg {
    oneof arg {
      // Refers to a tensor within the checkpoint provided by each client.
      tensorflow.TensorSpecProto input_tensor = 2;

      // Refers to a tensor within the intermediate server state checkpoint.
      tensorflow.TensorSpecProto state_tensor = 3;
    }
  }

  // List of arguments for the aggregation operation. The arguments can be
  // dependent on client data (in which case they must be retrieved from
  // clients) or they can be independent of client data (in which case they
  // can be configured server-side). For now we assume all client-independent
  // arguments are constants. The arguments must be in the order expected by
  // the server.
  repeated IntrinsicArg intrinsic_args = 4;

  // List of server-side outputs produced by the aggregation operation.
  repeated tensorflow.TensorSpecProto output_tensors = 5;

  // List of inner aggregation intrinsics. This can be used to delegate parts
  // of the aggregation logic (e.g. a groupby intrinsic may want to delegate
  // a sum operation to a sum intrinsic).
  repeated ServerAggregationConfig inner_aggregations = 6;
}
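
// For illustration only (the tensor name is invented), a configuration that
// sums a scalar float tensor named "loss_sum" across clients could look like:
//
//   intrinsic_uri: "federated_sum"
//   intrinsic_args {
//     input_tensor { name: "loss_sum" dtype: DT_FLOAT shape {} }
//   }
//   output_tensors { name: "loss_sum" dtype: DT_FLOAT shape {} }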

// Server Phase
// ============

// Represents a server phase which implements TF-based aggregation of multiple
// client updates.
//
// There are two different modes of aggregation that are described
// by the values in this message. The first is aggregation that is
// coming from coordinated sets of clients. This includes aggregation
// done via checkpoints from clients or aggregation done over a set
// of clients by a process like secure aggregation. The results of
// this first aggregation are saved to intermediate aggregation
// checkpoints. The second aggregation then comes from taking
// these intermediate checkpoints and aggregating over them.
//
// These two different modes of aggregation are done on different
// servers, the first in the 'L1' servers and the second in the
// 'L2' servers, so we use this nomenclature to describe these
// phases below.
//
// The ServerPhase message is currently in the process of being replaced by the
// ServerPhaseV2 message as we switch the plan building pipeline to use
// DistributeAggregateForm instead of MapReduceForm. During the migration
// process, we may generate both messages and use components from either
// message during execution.
message ServerPhase {
  // A short CamelCase name for the ServerPhase.
  string name = 8;

  // ===========================================================================
  // L1 "Intermediate" Aggregation.
  //
  // This is the initial aggregation that creates partial aggregates from
  // client results. L1 Aggregation may be run on many different instances.
  //
  // Pre-condition:
  // The execution environment has loaded the graph from `server_graph_bytes`.

  // 1. Initialize the phase.
  //
  // Operation to run before the first aggregation happens.
  // For instance, clears the accumulators so that a new aggregation can begin.
  string phase_init_op = 1;

  // 2. For each client in the set of clients:
  //    a. Restore variables from the client checkpoint.
  //
  // Loads a checkpoint from a single client written via
  // `FederatedComputeIORouter.output_filepath_tensor_name`. This is done once
  // for every client checkpoint in a round.
  CheckpointOp read_update = 3;
  //    b. Aggregate the data coming from the client checkpoint.
  //
  // An operation that aggregates the data from read_update.
  // Generally this will add to accumulators and it may leverage internal data
  // inside the graph to adjust the weights of the Tensors.
  //
  // Executed once for each `read_update`, to (for example) update accumulator
  // variables using the values loaded during `read_update`.
  string aggregate_into_accumulators_op = 4;

  // 3. After all clients have been aggregated, possibly restore
  //    variables that have been aggregated via a separate process.
  //
  // Optionally restores variables where aggregation is done across
  // an entire round of client data updates. In contrast to `read_update`,
  // which restores once per client, this occurs after all clients
  // in a round have been processed. This allows, for example, side
  // channels where aggregation is done by a separate process (such
  // as in secure aggregation), in which the side channel aggregated
  // tensor is passed to the `before_restore_op`, which ensures the
  // variables are restored properly. The `after_restore_op` will then
  // be responsible for performing the accumulation.
  //
  // Note that in current use this should not have a SaverDef, but
  // should only be used for side channels.
  CheckpointOp read_aggregated_update = 10;

  // 4. Write the aggregated variables to an intermediate checkpoint.
  //
  // We require that `aggregate_into_accumulators_op` is associative and
  // commutative, so that the aggregates can be computed across
  // multiple TensorFlow sessions.
  // As an example, say we are computing the sum of 5 client updates:
  //   A = X1 + X2 + X3 + X4 + X5
  // We can always do this in one session by calling `read_update` and
  // `aggregate_into_accumulators_op` once for each client checkpoint.
  //
  // Alternatively, we could compute:
  //   A1 = X1 + X2 in one TensorFlow session, and
  //   A2 = X3 + X4 + X5 in a different session.
  // Each of these sessions can then write their accumulator state
  // with the `write_intermediate_update` CheckpointOp, and yet another
  // session can then call `read_intermediate_update` and
  // `aggregate_into_accumulators_op` on each of these checkpoints to compute:
  //   A = A1 + A2 = (X1 + X2) + (X3 + X4 + X5).
  CheckpointOp write_intermediate_update = 7;
  // End L1 "Intermediate" Aggregation.
  // ===========================================================================

  // ===========================================================================
  // L2 Aggregation and Coordinator.
  //
  // This aggregates intermediate checkpoints from L1 Aggregation and performs
  // the finalizing of the update. Unlike L1, there will only be one instance
  // that does this aggregation.
  //
  // Pre-condition:
  // The execution environment has loaded the graph from `server_graph_bytes`
  // and restored the global model using `server_savepoint` from the parent
  // `Plan` message.

  // 1. Initialize the phase.
  //
  // This currently re-uses the `phase_init_op` from L1 aggregation above.

  // 2. Write a checkpoint that can be sent to the client.
  //
  // Generates a checkpoint to be sent to the client, to be read by
  // `FederatedComputeIORouter.input_filepath_tensor_name`.
  CheckpointOp write_client_init = 2;

  // 3. For each intermediate checkpoint:
  //    a. Restore variables from the intermediate checkpoint.
  //
  // The read checkpoint op corresponding to write_intermediate_update.
  // This is used instead of read_update for intermediate checkpoints because
  // the format of these updates may be different than that used in updates
  // from clients (which may, for example, be compressed).
  CheckpointOp read_intermediate_update = 9;
  //    b. Aggregate the data coming from the intermediate checkpoint.
  //
  // An operation that aggregates the data from `read_intermediate_update`.
  // Generally this will add to accumulators and it may leverage internal data
  // inside the graph to adjust the weights of the Tensors.
  string intermediate_aggregate_into_accumulators_op = 11;

  // 4. Write the aggregated intermediate variables to a checkpoint.
  //
  // This is used for downstream, cross-round aggregation of metrics.
  // These variables will be read back into a session with
  // read_intermediate_update.
  //
  // Tasks which do not use FL metrics may unset the CheckpointOp.saver_def
  // to disable writing accumulator checkpoints.
  CheckpointOp write_accumulators = 12;

  // 5. Finalize the round.
  //
  // This can include:
  //   - Applying the update aggregated from the intermediate checkpoints to
  //     the global model and other updates to cross-round state variables.
  //   - Computing final round metric values (e.g. the `report` of a
  //     `tff.federated_aggregate`).
  string apply_aggregrated_updates_op = 5;

  // 6. Fetch the server aggregated metrics.
  //
  // A list of names of metric variables to fetch from the TensorFlow session.
  repeated Metric metrics = 6;

  // 7. Serialize the updated server state (e.g. the coefficients of the global
  //    model in FL) using `server_savepoint` in the parent `Plan` message.

  // End L2 Aggregation.
  // ===========================================================================
}
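
// As a rough sketch of the control flow described by ServerPhase above (not an
// exact implementation), an L1 aggregation session proceeds roughly as:
//
//   run(phase_init_op)
//   for each client checkpoint:
//     restore via read_update
//     run(aggregate_into_accumulators_op)
//   restore side-channel aggregates via read_aggregated_update
//   save via write_intermediate_update
//
// and the single L2 / coordinator session proceeds roughly as:
//
//   restore the global model via server_savepoint; run(phase_init_op)
//   save the client broadcast checkpoint via write_client_init
//   for each intermediate checkpoint:
//     restore via read_intermediate_update
//     run(intermediate_aggregate_into_accumulators_op)
//   save via write_accumulators
//   run(apply_aggregrated_updates_op)
//   fetch metrics and save the updated server state via server_savepoint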

// Represents the server phase in an eligibility computation.
//
// This phase produces a checkpoint to be sent to clients. This checkpoint is
// then used as an input to the clients' task eligibility computations.
// This phase *does not include any aggregation.*
message ServerEligibilityComputationPhase {
  // A short CamelCase name for the ServerEligibilityComputationPhase.
  string name = 1;

  // The names of the TensorFlow nodes to run in order to produce output.
  repeated string target_node_names = 2;

  // The specification of inputs and outputs to the TensorFlow graph.
  oneof server_eligibility_io_router {
    TEContextServerEligibilityIORouter task_eligibility = 3 [lazy = true];
  }
}

// Represents the inputs and outputs of a `ServerEligibilityComputationPhase`
// which takes a single `TaskEligibilityContext` as input.
message TEContextServerEligibilityIORouter {
  // The name of the scalar string tensor that must be fed a serialized
  // `TaskEligibilityContext`.
  string context_proto_input_tensor_name = 1;

  // The name of the scalar string tensor that must be fed the path to which
  // the server graph should write the checkpoint file to be sent to the
  // client.
  string output_filepath_tensor_name = 2;
}

// Plan
// =====

// Represents the overall plan for performing federated optimization or
// personalization, as handed over to the production system. This will
// typically be split down into individual pieces for different production
// parts, e.g. server and client side.
// NEXT_TAG: 15
message Plan {
  reserved 1, 3, 5;

  // The actual type of the server_*_graph_bytes fields below is expected to be
  // tensorflow.GraphDef. The TensorFlow graphs are stored in serialized form
  // for two reasons.
  //   1) We may use execution engines other than TensorFlow.
  //   2) We wish to avoid the cost of deserializing and re-serializing large
  //      graphs, in the Federated Learning service.

  // While we migrate from ServerPhase to ServerPhaseV2, server_graph_bytes,
  // server_graph_prepare_bytes, and server_graph_result_bytes may all be set.
  // If we're using a MapReduceForm-based server implementation, only
  // server_graph_bytes will be used. If we're using a DistributeAggregateForm-
  // based server implementation, only server_graph_prepare_bytes and
  // server_graph_result_bytes will be used.

  // Optional. The TensorFlow graph used for all server processing described by
  // ServerPhase. For personalization, this will not be set.
  google.protobuf.Any server_graph_bytes = 7;

  // Optional. The TensorFlow graph used for all server processing described by
  // ServerPhaseV2.tensorflow_spec_prepare.
  google.protobuf.Any server_graph_prepare_bytes = 13;

  // Optional. The TensorFlow graph used for all server processing described by
  // ServerPhaseV2.tensorflow_spec_result.
  google.protobuf.Any server_graph_result_bytes = 14;

  // A savepoint to sync the server checkpoint with a persistent
  // storage system. The storage initially holds a seeded checkpoint
  // which can subsequently be read and updated by this savepoint.
  // Optional -- not present in eligibility computation plans (those with a
  // ServerEligibilityComputationPhase). This is used in conjunction with
  // ServerPhase only.
  CheckpointOp server_savepoint = 2;

  // Required. The TensorFlow graph that describes the TensorFlow logic a
  // client should perform. It should be consistent with the `TensorflowSpec`
  // field in the `client_phase`. The actual type is expected to be
  // tensorflow.GraphDef. The TensorFlow graph is stored in serialized form for
  // two reasons.
  //   1) We may use execution engines other than TensorFlow.
  //   2) We wish to avoid the cost of deserializing and re-serializing large
  //      graphs, in the Federated Learning service.
  google.protobuf.Any client_graph_bytes = 8;

  // Optional. The FlatBuffer used for TFLite training.
  // It contains the same model information as the client_graph_bytes, but with
  // a different format.
  bytes client_tflite_graph_bytes = 12;

  // A pair of client phase and server phase which are processed in
  // sync. The server execution defines how the results of a client
  // phase are aggregated, and how the checkpoints for clients are
  // generated.
  message Phase {
    // Required. The client phase.
    ClientPhase client_phase = 1;

    // Optional. Server phase for TF-based aggregation; not provided for
    // personalization or eligibility tasks.
    ServerPhase server_phase = 2;

    // Optional. Server phase for native aggregation; only provided for tasks
    // that have enabled the corresponding flag.
    ServerPhaseV2 server_phase_v2 = 4;

    // Optional. Only provided for eligibility tasks.
    ServerEligibilityComputationPhase server_eligibility_phase = 3;
  }

  // A pair of client and server computations to run.
  repeated Phase phase = 4;

  // Metrics that are persistent across different phases. This
  // includes, for example, counters that track how much work of
  // different kinds has been done.
  repeated Metric metrics = 6;

  // Describes how metrics in both the client and server phases should be
  // aggregated.
  repeated OutputMetric output_metrics = 10;

  // Version of the plan:
  //   version == 0 - Old plan without a version field, containing b/65131070
  //   version >= 1 - plan supports the multi-shard aggregation mode (L1/L2)
  int32 version = 9;

  // A TensorFlow ConfigProto packed in an Any.
  //
  // If this field is unset, if the Any proto is set but empty, or if the Any
  // proto is populated with an empty ConfigProto (i.e. its `type_url` field is
  // set, but the `value` field is empty), then the client implementation may
  // choose a set of configuration parameters to provide to TensorFlow by
  // default.
  //
  // In all other cases this field must contain a valid packed ConfigProto
  // (invalid values will result in an error at execution time), and in this
  // case the client will not provide any other configuration parameters by
  // default.
  google.protobuf.Any tensorflow_config_proto = 11;
}

// Represents a client part of the plan of federated optimization.
// This is also used to describe a client-only plan for standalone on-device
// training, known as personalization.
// NEXT_TAG: 6
message ClientOnlyPlan {
  reserved 3;

  // The graph to use for training, in binary form.
  bytes graph = 1;

  // Optional. The FlatBuffer used for TFLite training.
  // Whether "graph" or "tflite_graph" is used for training is up to the client
  // code, to allow for a flag-controlled a/b rollout.
  bytes tflite_graph = 5;

  // The client phase to execute.
  ClientPhase phase = 2;

  // A TensorFlow ConfigProto.
  google.protobuf.Any tensorflow_config_proto = 4;
}

// Represents the cross-round aggregation portion for user-defined
// measurements. This is used by tools that process / analyze accumulator
// checkpoints after a round of computation, to achieve aggregation beyond a
// round.
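//
// As a rough sketch (not an exact implementation), a tool aggregating several
// accumulator checkpoints across rounds would proceed roughly as:
//
//   run(init_op)
//   for each accumulator checkpoint:
//     restore via read_aggregated_update
//     run(merge_op)
//   read/write the final aggregated values via read_write_final_accumulators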
message CrossRoundAggregationExecution {
  // Operation to run before reading accumulator checkpoint.
  string init_op = 1;

  // Reads accumulator checkpoint.
  CheckpointOp read_aggregated_update = 2;

  // Operation to merge loaded checkpoint into accumulator.
  string merge_op = 3;

  // Reads and writes the final aggregated accumulator vars.
  CheckpointOp read_write_final_accumulators = 6;

  // Metadata for mapping the TensorFlow `name` attribute of the `tf.Variable`
  // to the user defined name of the signal.
  repeated Measurement measurements = 4;

  // The `tf.Graph` used for aggregating accumulator checkpoints when
  // loading metrics.
  google.protobuf.Any cross_round_aggregation_graph_bytes = 5;
}

message Measurement {
  // Name of a TensorFlow op to run to read/fetch the value of this
  // measurement.
  string read_op_name = 1;

  // A human-readable name for the measurement. Names are usually
  // camel case by convention, e.g., 'Loss', 'AbsLoss', or 'Accuracy'.
  string name = 2;

  reserved 3;

  // A serialized `tff.Type` for the measurement.
  bytes tff_type = 4;
}