xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_
18 
19 #include <string>
20 
21 #include "tensorflow/core/lib/gtl/flatset.h"
22 #include "tensorflow/core/lib/strings/str_util.h"
23 #include "tensorflow/core/util/env_var.h"
24 
25 namespace tensorflow {
26 namespace grappler {
27 
28 // Represents the four lists of ops: the allow list, infer list, deny list, and
29 // clear list. These lists determine which ops are converted to fp16/bf16
30 // (referred to as 'f16' for short) and which ops stay as fp32.
31 class AutoMixedPrecisionLists {
32  public:
~AutoMixedPrecisionLists()33   virtual ~AutoMixedPrecisionLists() {}
34 
35   // Returns the set of ops that are considered numerically-safe (for execution
36   // in f16), performance-critical, and can run in f16. These ops are always
37   // converted to f16.
38   virtual gtl::FlatSet<string> AllowList() = 0;
39   // Returns the set of ops that can run in f16 and are considered numerically-
40   // safe (for execution in f16), but which may be made unsafe by an upstream
41   // denylist op.
42   virtual gtl::FlatSet<string> InferList() = 0;
43   // Returns the set of ops that are considered numerically-dangerous (i.e.,
44   // unsafe for execution in f16) and whose effects may also be observed in
45   // downstream nodes (e.g. for f16, in Exp -> Add, the Add is unsafe due to
46   // the Exp).
47   virtual gtl::FlatSet<string> DenyList() = 0;
48   // Returns the set of ops that do not have numerically-significant effects
49   // (i.e., they are always considered safe for execution in f16 precision), and
50   // can run in f16.
51   virtual gtl::FlatSet<string> ClearList() = 0;
52 
53  protected:
54   // Adds or removes ops from list if certain environmental variables are set.
UpdateList(const string & list_name,gtl::FlatSet<string> * list)55   static void UpdateList(const string& list_name, gtl::FlatSet<string>* list) {
56     CHECK(list_name == "ALLOWLIST" || list_name == "INFERLIST" ||  // Crash OK.
57           list_name == "DENYLIST" || list_name == "CLEARLIST" ||
58           // TODO(reedwm): for bkwds compat; remove when no longer necessary:
59           list_name == "WHITELIST" || list_name == "GRAYLIST" ||
60           list_name == "BLACKLIST");
61     string add_env_var =
62         "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_ADD";
63     string remove_env_var =
64         "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_REMOVE";
65     string to_add, to_remove;
66     TF_CHECK_OK(ReadStringFromEnvVar(add_env_var, "", &to_add));
67     TF_CHECK_OK(ReadStringFromEnvVar(remove_env_var, "", &to_remove));
68     for (const auto& x : str_util::Split(to_add, ",")) {
69       list->insert(x);
70     }
71     for (const auto& x : str_util::Split(to_remove, ",")) {
72       list->erase(x);
73     }
74   }
75 
76   // Subclasses should include these on the ClearList.
AddTensorListOps(gtl::FlatSet<string> * list)77   static void AddTensorListOps(gtl::FlatSet<string>* list) {
78     // Note: if a data structure op (such as TensorListPopBack) is added here,
79     // IsTensorListReaderOp or IsTensorListWriterOp may need to be modified
80     // LINT.IfChange
81     constexpr const char* tensor_list_ops[] = {
82         "TensorListConcat",     "TensorListConcatLists",
83         "TensorListConcatV2",   "TensorListGather",
84         "TensorListGetItem",    "TensorListPopBack",
85         "TensorListPushBack",   "TensorListPushBackBatch",
86         "TensorListFromTensor", "TensorListScatter",
87         "TensorListScatterV2",  "TensorListScatterIntoExistingList",
88         "TensorListSetItem",    "TensorListSplit",
89         "TensorListStack"};
90     // LINT.ThenChange(//tensorflow/core/grappler/optimizers/auto_mixed_precision.cc)
91     for (auto op : tensor_list_ops) {
92       list->insert(op);
93     }
94   }
95 };
96 
97 class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
98  private:
IsPseudoFastMath()99   static bool IsPseudoFastMath() {
100     string optimization_level;
101     TF_CHECK_OK(
102         ReadStringFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "",
103                              &optimization_level));
104     optimization_level = str_util::Uppercase(optimization_level);
105     return optimization_level == "TENSOR_CORES_ONLY";
106   }
107 
108  public:
AutoMixedPrecisionListsCuda(int cuda_version,int cudnn_version)109   AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version)
110       : cuda_version_(cuda_version), cudnn_version_(cudnn_version) {}
111 
AllowList()112   gtl::FlatSet<string> AllowList() override {
113     auto list = gtl::FlatSet<string>{
114         "BlockLSTM",
115         "BlockLSTMV2",
116         "BlockLSTMGrad",
117         "BlockLSTMGradV2",
118         "Conv2D",
119         "Conv2DBackpropFilter",
120         "Conv2DBackpropInput",
121         "CudnnRNN",
122         "CudnnRNNBackprop",
123         "CudnnRNNBackpropV2",
124         "CudnnRNNBackpropV3",
125         "CudnnRNNV2",
126         "CudnnRNNV3",
127         "Einsum",
128         "FusedConv2DBiasActivation",
129         "FusedSparseConvGpuV2",
130         "GRUBlockCell",
131         "GRUBlockCellGrad",
132         "LSTMBlockCell",
133         "LSTMBlockCellGrad",
134         "MatMul",
135         "Mha",
136         "Tmlp",
137     };
138 #if TENSORFLOW_USE_ROCM
139     if (true) {
140 #else
141     if (cuda_version_ >= 9010) {
142       // Fp16 BatchMatMul is slow before CUDA 9.1.
143 #endif
144       list.insert("BatchMatMul");
145       list.insert("BatchMatMulV2");
146     }
147     if (cudnn_version_ >= 7602) {
148       // Fp16 3D conv is slow before CUDNN 7.6.2.
149       list.insert("Conv3D");
150       list.insert("Conv3DBackpropFilter");
151       list.insert("Conv3DBackpropFilterV2");
152       list.insert("Conv3DBackpropInput");
153       list.insert("Conv3DBackpropInputV2");
154     }
155     if (cudnn_version_ >= 8000) {
156       list.insert("DepthwiseConv2dNative");
157       list.insert("DepthwiseConv2dNativeBackpropFilter");
158       list.insert("DepthwiseConv2dNativeBackpropInput");
159     }
160     UpdateList("ALLOWLIST", &list);
161     // For backwards compatibility, keeping the original env variable here.
162     // TODO(reedwm): This should be removed if we don't have active users.
163     UpdateList("WHITELIST", &list);
164 
165     return list;
166   }
167 
168   gtl::FlatSet<string> InferList() override {
169     if (IsPseudoFastMath()) {
170       return gtl::FlatSet<string>{};
171     }
172 
173     auto list = gtl::FlatSet<string>{
174         "Add",
175         "AddN",
176         "AddV2",
177         "AvgPool",
178         "AvgPool3D",
179         "AvgPool3DGrad",
180         "AvgPoolGrad",
181         "BiasAdd",
182         "BiasAddGrad",
183         "BiasAddV1",
184         "Elu",
185         "EluGrad",
186         "Erf",
187         "Erfc",
188         "FloorDiv",
189         "FusedBatchNormV2",
190         "FusedBatchNormGradV2",
191         "FusedBatchNormV3",
192         "FusedBatchNormGradV3",
193         "_FusedBatchNormEx",
194         "Inv",
195         "LeakyRelu",
196         "LeakyReluGrad",
197         "Log",
198         "Log1p",
199         "LogSoftmax",
200         "Mul",
201         "Prod",
202         "RealDiv",
203         "Reciprocal",
204         "Selu",
205         "SeluGrad",
206         "Sigmoid",
207         "SigmoidGrad",
208         "Softmax",
209         "Softplus",
210         "SoftplusGrad",
211         "Softsign",
212         "SoftsignGrad",
213         "Sqrt",
214         "Sub",
215         "Tanh",
216         "TanhGrad",
217     };
218     UpdateList("INFERLIST", &list);
219     // For backwards compatibility, keeping the original env variable here.
220     // TODO(reedwm): This should be removed if we don't have active users.
221     UpdateList("GRAYLIST", &list);
222     return list;
223   }
224 
225   gtl::FlatSet<string> DenyList() override {
226     if (IsPseudoFastMath()) {
227       return gtl::FlatSet<string>{};
228     }
229 
230     auto list = gtl::FlatSet<string>{
231         "Exp",
232         "Expm1",
233         "L2Loss",
234         "Mean",
235         "Pow",
236         "SaveV2",
237         "SoftmaxCrossEntropyWithLogits",
238         "SparseSoftmaxCrossEntropyWithLogits",
239         "Sum",
240     };
241     UpdateList("DENYLIST", &list);
242     // For backwards compatibility, keeping the original env variable here.
243     // TODO(reedwm): This should be removed if we don't have active users.
244     UpdateList("BLACKLIST", &list);
245     return list;
246   }
247 
248   gtl::FlatSet<string> ClearList() override {
249     if (IsPseudoFastMath()) {
250       return gtl::FlatSet<string>{};
251     }
252 
253     auto list = gtl::FlatSet<string>{
254         "Abs",
255         "ArgMax",
256         "ArgMin",
257         "BatchToSpace",
258         "BatchToSpaceND",
259         "BroadcastTo",
260         "Ceil",
261         "CheckNumerics",
262         "ClipByValue",
263         "Concat",
264         "ConcatV2",
265         "DepthToSpace",
266         "DynamicPartition",
267         "DynamicStitch",
268         "Enter",
269         "EnsureShape",
270         "Equal",
271         "Exit",
272         "ExpandDims",
273         "Fill",
274         "Floor",
275         "Gather",
276         "GatherNd",
277         "GatherV2",
278         "Greater",
279         "GreaterEqual",
280         "Identity",
281         "IdentityN",
282         "IsFinite",
283         "IsInf",
284         "IsNan",
285         "Less",
286         "LessEqual",
287         "Max",
288         "MaxPool",
289         "MaxPool3D",
290         "MaxPool3DGrad",
291         "MaxPool3DGradGrad",
292         "MaxPoolGrad",
293         "MaxPoolGradGrad",
294         "MaxPoolGradGradV2",
295         "MaxPoolGradV2",
296         "MaxPoolV2",
297         "Maximum",
298         "Merge",
299         "Min",
300         "Minimum",
301         "MirrorPad",
302         "MirrorPadGrad",
303         "Neg",
304         "NextIteration",
305         "NotEqual",
306         "OneHot",
307         "OnesLike",
308         "Pack",
309         "Pad",
310         "PadV2",
311         "PreventGradient",
312         "Rank",
313         "Relu",
314         "Relu6",
315         "Relu6Grad",
316         "ReluGrad",
317         "Reshape",
318         "ResizeNearestNeighbor",
319         "ResizeNearestNeighborGrad",
320         "Reverse",
321         "ReverseSequence",
322         "ReverseV2",
323         "Round",
324         "Select",
325         "SelectV2",
326         "Shape",
327         "ShapeN",
328         "Sign",
329         "Size",
330         "Slice",
331         "Snapshot",
332         "SpaceToBatch",
333         "SpaceToBatchND",
334         "SpaceToDepth",
335         "Split",
336         "SplitV",
337         "Squeeze",
338         "StopGradient",
339         "StridedSlice",
340         "StridedSliceGrad",
341         "Switch",
342         "Tile",
343         "TopK",
344         "TopKV2",
345         "Transpose",
346         "Unpack",
347         "Where",
348         "ZerosLike",
349     };
350     AddTensorListOps(&list);
351     UpdateList("CLEARLIST", &list);
352     return list;
353   }
354 
355  private:
356   int cuda_version_;
357   int cudnn_version_;
358 };
359 
360 class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
361  public:
AutoMixedPrecisionListsMkl()362   AutoMixedPrecisionListsMkl() {}
363 
364   // Only ops which are supported by MKL in bfloat16 should be added to the
365   // allow list, infer list, or clear list.
AllowList()366   gtl::FlatSet<string> AllowList() override {
367     auto list = gtl::FlatSet<string>{"Conv2D",
368                                      "Conv2DBackpropFilter",
369                                      "Conv2DBackpropInput",
370                                      "Conv3D",
371                                      "Conv3DBackpropFilterV2",
372                                      "Conv3DBackpropInputV2",
373                                      "DepthwiseConv2dNative",
374                                      "DepthwiseConv2dNativeBackpropFilter",
375                                      "DepthwiseConv2dNativeBackpropInput",
376                                      "MatMul",
377                                      "BatchMatMul",
378                                      "BatchMatMulV2"};
379 
380     UpdateList("ALLOWLIST", &list);
381     // For backwards compatibility, keeping the original env variable here.
382     // TODO(reedwm): This should be removed if we don't have active users.
383     UpdateList("WHITELIST", &list);
384     return list;
385   }
386 
InferList()387   gtl::FlatSet<string> InferList() override {
388     auto list = gtl::FlatSet<string>{"Add",
389                                      "AddN",
390                                      "AddV2",
391                                      "AvgPool",
392                                      "AvgPool3D",
393                                      "AvgPool3DGrad",
394                                      "AvgPoolGrad",
395                                      "BiasAdd",
396                                      "BiasAddGrad",
397                                      "BiasAddV1",
398                                      "FusedBatchNormV2",
399                                      "FusedBatchNormGradV2",
400                                      "FusedBatchNormV3",
401                                      "FusedBatchNormGradV3",
402                                      "LeakyRelu",
403                                      "LeakyReluGrad",
404                                      "Mul",
405                                      "Sub",
406                                      "Elu",
407                                      "EluGrad",
408                                      "FloorDiv",
409                                      "_FusedBatchNormEx",
410                                      "Log",
411                                      "Log1p",
412                                      "LogSoftmax",
413                                      "Prod",
414                                      "RealDiv",
415                                      "Reciprocal",
416                                      "Selu",
417                                      "SeluGrad",
418                                      "Sigmoid",
419                                      "SigmoidGrad",
420                                      "Softmax",
421                                      "Softplus",
422                                      "SoftplusGrad",
423                                      "Softsign",
424                                      "SoftsignGrad",
425                                      "Sqrt",
426                                      "Tanh",
427                                      "TanhGrad"};
428     UpdateList("INFERLIST", &list);
429     // For backwards compatibility, keeping the original env variable here.
430     // TODO(reedwm): This should be removed if we don't have active users.
431     UpdateList("GRAYLIST", &list);
432     return list;
433   }
434 
DenyList()435   gtl::FlatSet<string> DenyList() override {
436     auto list = gtl::FlatSet<string>{
437         "Exp",
438         "Expm1",
439         "L2Loss",
440         "Mean",
441         "Pow",
442         "SaveV2",
443         "SoftmaxCrossEntropyWithLogits",
444         "SparseSoftmaxCrossEntropyWithLogits",
445         "Sum",
446     };
447     UpdateList("DENYLIST", &list);
448     // For backwards compatibility, keeping the original env variable here.
449     // TODO(reedwm): This should be removed if we don't have active users.
450     UpdateList("BLACKLIST", &list);
451     return list;
452   }
453 
ClearList()454   gtl::FlatSet<string> ClearList() override {
455     auto list = gtl::FlatSet<string>{
456         "Abs",
457         "ArgMax",
458         "ArgMin",
459         "BatchToSpace",
460         "BatchToSpaceND",
461         "BroadcastTo",
462         "Ceil",
463         "CheckNumerics",
464         "ClipByValue",
465         "Concat",
466         "ConcatV2",
467         "DepthToSpace",
468         "DynamicPartition",
469         "DynamicStitch",
470         "EnsureShape",
471         "Enter",
472         "Equal",
473         "Exit",
474         "ExpandDims",
475         "Fill",
476         "Floor",
477         "Gather",
478         "GatherNd",
479         "GatherV2",
480         "Greater",
481         "GreaterEqual",
482         "Identity",
483         "IsFinite",
484         "IsInf",
485         "IsNan",
486         "Less",
487         "LessEqual",
488         "Max",
489         "Maximum",
490         "MaxPool",
491         "MaxPool3D",
492         "MaxPool3DGrad",
493         "MaxPoolGrad",
494         "MaxPoolGradGrad",
495         "MaxPoolGradGradV2",
496         "MaxPoolGradV2",
497         "MaxPoolV2",
498         "Merge",
499         "Min",
500         "Minimum",
501         "MirrorPad",
502         "MirrorPadGrad",
503         "Neg",
504         "NextIteration",
505         "NotEqual",
506         "OnesLike",
507         "Pack",
508         "Pad",
509         "PadV2",
510         "PreventGradient",
511         "Rank",
512         "Relu",
513         "Relu6",
514         "Relu6Grad",
515         "ReluGrad",
516         "Reshape",
517         "ResizeNearestNeighbor",
518         "ResizeNearestNeighborGrad",
519         "Reverse",
520         "ReverseSequence",
521         "ReverseV2",
522         "Round",
523         "Select",
524         "SelectV2",
525         "Shape",
526         "ShapeN",
527         "Sign",
528         "Slice",
529         "Snapshot",
530         "SpaceToBatch",
531         "SpaceToBatchND",
532         "SpaceToDepth",
533         "Split",
534         "SplitV",
535         "Squeeze",
536         "StopGradient",
537         "StridedSlice",
538         "StridedSliceGrad",
539         "Switch",
540         "Tile",
541         "TopK",
542         "TopKV2",
543         "Transpose",
544         "Where",
545         "Unpack",
546         "ZerosLike",
547     };
548     AddTensorListOps(&list);
549     UpdateList("CLEARLIST", &list);
550     return list;
551   }
552 };
553 
554 }  // end namespace grappler
555 }  // end namespace tensorflow
556 
557 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_
558