/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"

namespace tensorflow {
namespace tpu {

std::string GetOptimizationAlgorithmName(OptimizationAlgorithm alg) {
  switch (alg) {
    case OptimizationAlgorithm::kAdagrad:
      return "Adagrad";
    case OptimizationAlgorithm::kAdagradMomentum:
      return "AdagradMomentum";
    case OptimizationAlgorithm::kBoundedAdagrad:
      return "BoundedAdagrad";
    case OptimizationAlgorithm::kStochasticGradientDescent:
      return "StochasticGradientDescent";
    case OptimizationAlgorithm::kFtrl:
      return "FTRL";
    case OptimizationAlgorithm::kAdam:
      return "ADAM";
    case OptimizationAlgorithm::kMomentum:
      return "Momentum";
    case OptimizationAlgorithm::kRmsProp:
      return "RMSProp";
    case OptimizationAlgorithm::kCenteredRmsProp:
      return "CenteredRMSProp";
    case OptimizationAlgorithm::kMdlAdagradLight:
      return "MDLAdagradLight";
    case OptimizationAlgorithm::kAdadelta:
      return "Adadelta";
    case OptimizationAlgorithm::kProximalAdagrad:
      return "ProximalAdagrad";
    case OptimizationAlgorithm::kOnlineYogi:
      return "OnlineYogi";
    case OptimizationAlgorithm::kProximalYogi:
      return "ProximalYogi";
    case OptimizationAlgorithm::kFrequencyEstimator:
      return "FrequencyEstimator";
    case OptimizationAlgorithm::kUserDefinedProgram:
      return "UserDefinedProgram";
    case OptimizationAlgorithm::kAssign:
      return "Assign";
    case OptimizationAlgorithm::PARAMETERS_NOT_SET:
      return "*** Not set ***";
  }
  return "*** Not set ***";
}

std::string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) {
  switch (alg) {
    case OptimizationAlgorithm::kAdagrad:
      return "Adagrad";
    case OptimizationAlgorithm::kAdagradMomentum:
      return "Adagrad with Momentum";
    case OptimizationAlgorithm::kBoundedAdagrad:
      return "Bounded Adagrad";
    case OptimizationAlgorithm::kStochasticGradientDescent:
      return "stochastic gradient descent";
    case OptimizationAlgorithm::kFtrl:
      return "FTRL";
    case OptimizationAlgorithm::kAdam:
      return "ADAM";
    case OptimizationAlgorithm::kMomentum:
      return "Momentum";
    case OptimizationAlgorithm::kRmsProp:
      return "RMSProp";
    case OptimizationAlgorithm::kCenteredRmsProp:
      return "centered RMSProp";
    case OptimizationAlgorithm::kMdlAdagradLight:
      return "MDL Adagrad Light";
    case OptimizationAlgorithm::kAdadelta:
      return "Adadelta";
    case OptimizationAlgorithm::kProximalAdagrad:
      return "proximal Adagrad";
    case OptimizationAlgorithm::kOnlineYogi:
      return "online Yogi";
    case OptimizationAlgorithm::kProximalYogi:
      return "proximal Yogi";
    case OptimizationAlgorithm::kFrequencyEstimator:
      return "frequency estimator";
    case OptimizationAlgorithm::kUserDefinedProgram:
      return "UserDefinedProgram";
    case OptimizationAlgorithm::kAssign:
      return "Assign";
    case OptimizationAlgorithm::PARAMETERS_NOT_SET:
      return "unknown (not specified)";
  }
  return "unknown (not specified)";
}

// Returns the number of optimization parameter vectors used by the
// optimization algorithm, excluding the weights themselves and assuming no
// gradient accumulation.
Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params,
                                      int* count) {
  switch (params.parameters_case()) {
    case OptimizationAlgorithm::kAdagrad:
      *count = 1;
      return OkStatus();
    case OptimizationAlgorithm::kAdagradMomentum:
      *count = 2;
      return OkStatus();
    case OptimizationAlgorithm::kBoundedAdagrad:
      *count = 1;
      return OkStatus();
    case OptimizationAlgorithm::kStochasticGradientDescent:
      *count = 0;
      return OkStatus();
    case OptimizationAlgorithm::kFtrl:
      *count = 2;
      return OkStatus();
    case OptimizationAlgorithm::kAdam:
      *count = 2;
      return OkStatus();
    case OptimizationAlgorithm::kMomentum:
      *count = 1;
      return OkStatus();
    case OptimizationAlgorithm::kRmsProp:
      *count = 2;
      return OkStatus();
    case OptimizationAlgorithm::kCenteredRmsProp:
      *count = 3;
      return OkStatus();
    case OptimizationAlgorithm::kMdlAdagradLight:
      *count = 3;
      return OkStatus();
    case OptimizationAlgorithm::kAdadelta:
      *count = 2;
      return OkStatus();
    case OptimizationAlgorithm::kProximalAdagrad:
      *count = 1;
      return OkStatus();
    case OptimizationAlgorithm::kOnlineYogi:
      *count = 2;
      return OkStatus();
    case OptimizationAlgorithm::kProximalYogi:
      *count = 2;
      return OkStatus();
    case OptimizationAlgorithm::kFrequencyEstimator:
      *count = 1;
      return OkStatus();
    case OptimizationAlgorithm::kUserDefinedProgram: {
      const xla::ProgramShapeProto& program_shape =
          params.user_defined_program().program().host_program_shape();

      const int num_inputs = program_shape.parameters_size();
      const int num_outputs = program_shape.result().tuple_shapes_size();

      if ((num_inputs < 2) || ((num_inputs != num_outputs + 1) &&
                               (num_inputs != num_outputs + 2))) {
        return errors::InvalidArgument(
            "User-defined TPU embedding optimizer program must have at least "
            "two inputs and the number of outputs must be 1 or 2 less than the "
            "number of inputs. Received ",
            num_inputs, " input(s) and ", num_outputs, " output(s).");
      }

      *count = num_outputs - 1;

      return OkStatus();
    }
    case OptimizationAlgorithm::kAssign:
      *count = 0;
      return OkStatus();
    case OptimizationAlgorithm::PARAMETERS_NOT_SET:
      return errors::InvalidArgument("No optimization algorithm specified");
  }
  return errors::InvalidArgument("No optimization algorithm specified");
}
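
// Illustrative usage sketch (assumes the OptimizationParameters proto exposes
// a mutable_adam() accessor for the kAdam case): the count excludes the
// "parameters" table itself and any gradient-accumulation slot.
//
//   OptimizationParameters params;
//   params.mutable_adam();  // Adam tracks first and second moments.
//   int count = 0;
//   TF_CHECK_OK(GetBaseAuxiliaryParameterCount(params, &count));
//   // count == 2 here: one slot for momenta, one for velocities.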

Status GetGradientAccumulationSupport(const OptimizationParameters& params,
                                      GradientAccumulationSupport* support) {
  int auxiliary_parameter_count;
  TF_RETURN_IF_ERROR(
      GetBaseAuxiliaryParameterCount(params, &auxiliary_parameter_count));
  *support = auxiliary_parameter_count + 1 <= kMaxAuxiliaryParameterCount
                 ? GradientAccumulationSupport::kSupported
                 : GradientAccumulationSupport::kNotSupported;
  return OkStatus();
}

Status UseGradientAccumulation(const OptimizationParameters& params,
                               bool* use_gradient_accumulation) {
  GradientAccumulationSupport support;
  TF_RETURN_IF_ERROR(GetGradientAccumulationSupport(params, &support));
  bool raw_gradient_accumulation_status = false;
  switch (params.gradient_accumulation_status()) {
    case GradientAccumulationStatus::UNSPECIFIED: {
      // Gradient accumulation is turned on by default when the status is not
      // explicitly specified.
      raw_gradient_accumulation_status = true;
      break;
    }
    case GradientAccumulationStatus::DISABLED: {
      raw_gradient_accumulation_status = false;
      break;
    }
    case GradientAccumulationStatus::ENABLED: {
      raw_gradient_accumulation_status = true;
      break;
    }
    default:
      return errors::Internal(
          absl::StrCat("Unsupported gradient accumulation status ",
                       GradientAccumulationStatus_Status_Name(
                           params.gradient_accumulation_status())));
  }
  switch (support) {
    case GradientAccumulationSupport::kSupported: {
      *use_gradient_accumulation = raw_gradient_accumulation_status;
      break;
    }
    case GradientAccumulationSupport::kNotSupported: {
      if (raw_gradient_accumulation_status) {
        return errors::InvalidArgument(strings::Printf(
            "Optimization algorithm %s does not support gradient accumulation "
            "but parameters specify it.",
            GetOptimizationAlgorithmName(params.parameters_case()).c_str()));
      }
      *use_gradient_accumulation = false;
      break;
    }
  }
  return OkStatus();
}
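
// Usage sketch (illustrative; the algorithm named below is just an example):
// with the status left UNSPECIFIED, accumulation is requested by default and
// granted whenever the algorithm leaves room for the extra
// "gradient_accumulators" slot.
//
//   bool use_acc = false;
//   TF_CHECK_OK(UseGradientAccumulation(params, &use_acc));
//   // For Adagrad (one auxiliary slot) use_acc becomes true. For an
//   // algorithm whose slots already reach kMaxAuxiliaryParameterCount,
//   // accumulation must be explicitly DISABLED; otherwise the call returns
//   // an InvalidArgument error.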

Status GetOptimizationAlgorithmStateVariables(
    const OptimizationParameters& params,
    std::vector<StateVariableSpecification>* state_variables) {
  // The parameter set for the weights themselves is required to be named
  // "parameters". The rest should stay stable for compatibility. There is an
  // internal function, GetOptimizationAlgorithmStateVariableInternalIndices,
  // that needs to be updated along with this one.
  bool use_gradient_accumulation;
  TF_RETURN_IF_ERROR(
      UseGradientAccumulation(params, &use_gradient_accumulation));

  auto add_state_variable = [&](const std::string& name) {
    StateVariableSpecification spec;
    spec.set_name(name);
    (void)spec.mutable_user_defined();
    state_variables->push_back(spec);
  };

  switch (params.parameters_case()) {
    case OptimizationAlgorithm::kAdagrad: {
      add_state_variable("parameters");
      add_state_variable("accumulators");
      break;
    }
    case OptimizationAlgorithm::kAdagradMomentum: {
      add_state_variable("parameters");
      add_state_variable("accumulators");
      add_state_variable("momenta");
      break;
    }
    case OptimizationAlgorithm::kBoundedAdagrad: {
      add_state_variable("parameters");
      add_state_variable("accumulators");
      break;
    }
    case OptimizationAlgorithm::kStochasticGradientDescent: {
      add_state_variable("parameters");
      break;
    }
    case OptimizationAlgorithm::kFtrl: {
      add_state_variable("parameters");
      add_state_variable("accumulators");
      add_state_variable("linears");
      break;
    }
    case OptimizationAlgorithm::kAdam: {
      add_state_variable("parameters");
      add_state_variable("momenta");
      add_state_variable("velocities");
      break;
    }
    case OptimizationAlgorithm::kMomentum: {
      add_state_variable("parameters");
      add_state_variable("momenta");
      break;
    }
    case OptimizationAlgorithm::kRmsProp: {
      add_state_variable("parameters");
      add_state_variable("ms");
      add_state_variable("mom");
      break;
    }
    case OptimizationAlgorithm::kCenteredRmsProp: {
      add_state_variable("parameters");
      add_state_variable("ms");
      add_state_variable("mom");
      add_state_variable("mg");
      break;
    }
    case OptimizationAlgorithm::kMdlAdagradLight: {
      add_state_variable("parameters");
      add_state_variable("accumulators");
      add_state_variable("weights");
      add_state_variable("benefits");
      break;
    }
    case OptimizationAlgorithm::kAdadelta: {
      add_state_variable("parameters");
      add_state_variable("accumulators");
      add_state_variable("updates");
      break;
    }
    case OptimizationAlgorithm::kProximalAdagrad: {
      add_state_variable("parameters");
      add_state_variable("accumulators");
      break;
    }
    case OptimizationAlgorithm::kOnlineYogi: {
      add_state_variable("parameters");
      add_state_variable("vs");
      add_state_variable("linears");
      break;
    }
    case OptimizationAlgorithm::kProximalYogi: {
      add_state_variable("parameters");
      add_state_variable("v");
      add_state_variable("m");
      break;
    }
    case OptimizationAlgorithm::kFrequencyEstimator: {
      add_state_variable("parameters");
      add_state_variable("last_hit_step");
      break;
    }
    case OptimizationAlgorithm::kUserDefinedProgram: {
      add_state_variable("parameters");
      int num_slots = -1;
      TF_RETURN_IF_ERROR(GetBaseAuxiliaryParameterCount(params, &num_slots));
      for (int i = 0; i < num_slots; ++i) {
        add_state_variable(absl::StrCat("Slot_", i));
      }
      break;
    }
    case OptimizationAlgorithm::kAssign: {
      add_state_variable("parameters");
      break;
    }
    case OptimizationAlgorithm::PARAMETERS_NOT_SET: {
      return errors::InvalidArgument("No optimization algorithm specified");
    }
  }

  // This needs to be last for compatibility.
  if (use_gradient_accumulation) {
    StateVariableSpecification gradient_acc;
    gradient_acc.set_name("gradient_accumulators");
    gradient_acc.mutable_fill_with_constant()->set_initial_value(
        GradientAccumulatorInitialValue());
    state_variables->push_back(std::move(gradient_acc));
  }

  if (state_variables->size() > kMaxAuxiliaryParameterCount + 1) {
    return errors::InvalidArgument(
        "Optimization algorithm ",
        GetOptimizationAlgorithmName(params.parameters_case()),
        " does not support gradient accumulation because it "
        "already has too many other accumulators");
  }
  return OkStatus();
}
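
// Illustrative result (state variable names as emitted above): for the kAdam
// case with gradient accumulation enabled, the returned specifications are,
// in order, "parameters", "momenta", "velocities", "gradient_accumulators".
//
//   std::vector<StateVariableSpecification> vars;
//   TF_CHECK_OK(GetOptimizationAlgorithmStateVariables(params, &vars));
//   // vars.front().name() == "parameters"; vars.back().name() ==
//   // "gradient_accumulators" when accumulation is in use.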

std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms() {
  return {
      OptimizationAlgorithm::kAdagrad,
      OptimizationAlgorithm::kAdagradMomentum,
      OptimizationAlgorithm::kBoundedAdagrad,
      OptimizationAlgorithm::kStochasticGradientDescent,
      OptimizationAlgorithm::kFtrl,
      OptimizationAlgorithm::kAdam,
      OptimizationAlgorithm::kMomentum,
      OptimizationAlgorithm::kRmsProp,
      OptimizationAlgorithm::kCenteredRmsProp,
      OptimizationAlgorithm::kMdlAdagradLight,
      OptimizationAlgorithm::kAdadelta,
      OptimizationAlgorithm::kProximalAdagrad,
      OptimizationAlgorithm::kOnlineYogi,
      OptimizationAlgorithm::kProximalYogi,
      OptimizationAlgorithm::kFrequencyEstimator,
      OptimizationAlgorithm::kUserDefinedProgram,
      OptimizationAlgorithm::kAssign,
  };
}

Status LoadOpShapeFunction::operator()(
    shape_inference::InferenceContext* c) const {
  int table_id;
  TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
  string table_name;
  TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
  // Exactly one of table_id and table_name must be non-default.
  if ((table_id >= 0) == (!table_name.empty())) {
    return errors::InvalidArgument(
        "exactly one of table_id or table_name must be non-default");
  }
  int num_shards;
  TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
  int shard_id;
  TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));

  // Verify that every input has rank 2 and that all inputs have mutually
  // compatible shapes.
  shape_inference::ShapeHandle parameter_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &parameter_shape));
  for (int j = 1; j < c->num_inputs(); ++j) {
    shape_inference::ShapeHandle accumulator_j_shape;
    TF_RETURN_IF_ERROR(c->WithRank(c->input(j), 2, &accumulator_j_shape));
    shape_inference::ShapeHandle merged;
    TF_RETURN_IF_ERROR(c->Merge(parameter_shape, accumulator_j_shape, &merged));
  }

  return OkStatus();
}
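
// Shape contract sketch (illustrative example, not tied to a specific op): a
// load op for the kAdam case takes three rank-2 inputs -- parameters, momenta,
// and velocities -- and the shape function above requires all of them to merge
// to the same two-dimensional table-shard shape.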

Status RetrieveOpShapeFunction::operator()(
    shape_inference::InferenceContext* c) const {
  int table_id;
  TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
  string table_name;
  TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
  // Exactly one of table_id and table_name must be non-default.
  if ((table_id >= 0) == (!table_name.empty())) {
    return errors::InvalidArgument(
        "exactly one of table_id or table_name must be non-default");
  }
  int num_shards;
  TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
  int shard_id;
  TF_RETURN_IF_ERROR(c->GetAttr("shard_id", &shard_id));
  for (int j = 0; j < c->num_outputs(); ++j) {
    c->set_output(j, c->MakeShape(std::vector<shape_inference::DimensionHandle>(
                         2, c->UnknownDim())));
  }
  return OkStatus();
}
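
// Sketch of the retrieve contract (illustrative): each output is marked as a
// rank-2 shape with both dimensions unknown, since the table-shard dimensions
// are not derivable from the attributes checked above.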

}  // namespace tpu
}  // namespace tensorflow