/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_TPU_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_
#define TENSORFLOW_CORE_TPU_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_

#include <string>
#include <vector>

#include "absl/base/casts.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h"

namespace tensorflow {
namespace tpu {

using OptimizationAlgorithm = OptimizationParameters::ParametersCase;

// Returns the name of the optimization algorithm.
string GetOptimizationAlgorithmName(OptimizationAlgorithm alg);

// Returns a user-friendly name for the optimization algorithm.
string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg);

// Returns all supported optimization algorithms.
std::vector<OptimizationAlgorithm> GetOptimizationAlgorithms();
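
// Example (illustrative sketch, not part of the original header): listing the
// supported algorithms together with their registered and user-friendly names.
//
//   for (OptimizationAlgorithm alg : GetOptimizationAlgorithms()) {
//     LOG(INFO) << GetOptimizationAlgorithmName(alg) << " ("
//               << GetOptimizationAlgorithmFriendlyName(alg) << ")";
//   }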

enum class GradientAccumulationSupport {
  // Accumulation cannot be used with this optimizer.
  kNotSupported,

  // Accumulation is allowed and changes optimizer behavior.
  kSupported,
};

// Returns the number of optimization parameter vectors used by the
// optimization algorithm, excluding the weights themselves and assuming no
// gradient accumulation.
Status GetBaseAuxiliaryParameterCount(const OptimizationParameters &params,
                                      int *count);
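
// Example (illustrative sketch, not part of the original header): querying the
// auxiliary parameter count for a configured optimizer; `params` is assumed to
// be a populated OptimizationParameters proto.
//
//   int aux_count = 0;
//   TF_RETURN_IF_ERROR(GetBaseAuxiliaryParameterCount(params, &aux_count));
//   // e.g. 1 for Adagrad (its accumulator), 0 for plain SGD.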

// Returns whether (and how) an optimization algorithm supports gradient
// accumulation.
Status GetGradientAccumulationSupport(const OptimizationParameters &params,
                                      GradientAccumulationSupport *support);

// Returns whether the given set of optimization parameters has gradient
// accumulation enabled and the algorithm in use either supports it or ignores
// the setting. Returns an error if gradient accumulation is enabled and the
// algorithm does not support it.
Status UseGradientAccumulation(const OptimizationParameters &params,
                               bool *use_gradient_accumulation);
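
// Example (illustrative sketch, not part of the original header): checking
// whether gradient accumulation should actually be used for a configuration.
//
//   GradientAccumulationSupport support;
//   TF_RETURN_IF_ERROR(GetGradientAccumulationSupport(params, &support));
//   bool use_accumulation = false;
//   TF_RETURN_IF_ERROR(UseGradientAccumulation(params, &use_accumulation));
//   if (use_accumulation) {
//     // Reserve an extra state variable for the accumulator.
//   }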

// Returns the parameter specifications for the optimization algorithm (the
// main parameters first, followed by any auxiliary parameters such as Adagrad
// accumulators).
Status GetOptimizationAlgorithmStateVariables(
    const OptimizationParameters &params,
    std::vector<StateVariableSpecification> *state_variables);
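
// Example (illustrative sketch, not part of the original header): enumerating
// the state variables expected for the configured optimizer.
//
//   std::vector<StateVariableSpecification> state_variables;
//   TF_RETURN_IF_ERROR(
//       GetOptimizationAlgorithmStateVariables(params, &state_variables));
//   // state_variables[0] describes the main parameters; any further entries
//   // describe auxiliary parameters (e.g. Adagrad accumulators).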

// Maximum value of auxiliary_parameter_count for any optimization algorithm.
// This count is used by TPU embedding load/retrieve ops and needs to be
// independent of any particular TPU version, so we take the maximum across all
// TPU versions.
static constexpr int kMaxAuxiliaryParameterCount = 7;

// Fill value for gradient accumulators. This is a denormal so that it will be
// flushed to zero on the current TPU platforms and needs to continue to have
// the following properties in the future:
//
// 1. Does not have the same bit pattern as a zero and can be distinguished from
//    it using integer operations.
// 2. Treated as zero by floating-point arithmetic operations (at least addition
//    and subtraction).
// 3. Cannot be produced by any floating-point arithmetic operation, including
//    those involving itself.
//
// It does not need to compare equal or not equal to zero in floating point. We
// need to use a non-zero value here because some optimization algorithms are
// not no-ops on zero gradients, so we need to distinguish an accumulated
// gradient of zero from one that has been cleared after its gradients have
// already been applied to the parameters and accumulators.
inline float GradientAccumulatorInitialValue() {
  return absl::bit_cast<float, uint32>(1);
}
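
// Example (illustrative sketch, not part of the original header):
// distinguishing a cleared accumulator from a genuine zero gradient by
// comparing bit patterns rather than floating-point values.
//
//   inline bool IsClearedAccumulator(float accumulator) {  // hypothetical helper
//     return absl::bit_cast<uint32, float>(accumulator) ==
//            absl::bit_cast<uint32, float>(GradientAccumulatorInitialValue());
//   }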

// Generic shape function for per-optimization-algorithm load ops.
class LoadOpShapeFunction {
 public:
  // Computes resulting shape and does parameter checking.
  Status operator()(shape_inference::InferenceContext *c) const;
};

// Generic shape function for per-optimization-algorithm retrieve ops.
class RetrieveOpShapeFunction {
 public:
  // Computes resulting shape and does parameter checking.
  Status operator()(shape_inference::InferenceContext *c) const;
};
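
// Example (illustrative sketch, not part of the original header): these
// functors are intended to be passed to SetShapeFn() when registering the
// per-algorithm load/retrieve ops (the op name below is hypothetical):
//
//   REGISTER_OP("LoadTPUEmbeddingFooParameters")
//       ...
//       .SetShapeFn(LoadOpShapeFunction());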

}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_TPU_EMBEDDING_OPTIMIZATION_PARAMETERS_UTILS_H_