1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_ 18 19 #include <string> 20 21 #include "tensorflow/core/lib/gtl/flatset.h" 22 #include "tensorflow/core/lib/strings/str_util.h" 23 #include "tensorflow/core/util/env_var.h" 24 25 namespace tensorflow { 26 namespace grappler { 27 28 // Represents the four lists of ops: the allow list, infer list, deny list, and 29 // clear list. These lists determine which ops are converted to fp16/bf16 30 // (referred to as 'f16' for short) and which ops stay as fp32. 31 class AutoMixedPrecisionLists { 32 public: ~AutoMixedPrecisionLists()33 virtual ~AutoMixedPrecisionLists() {} 34 35 // Returns the set of ops that are considered numerically-safe (for execution 36 // in f16), performance-critical, and can run in f16. These ops are always 37 // converted to f16. 38 virtual gtl::FlatSet<string> AllowList() = 0; 39 // Returns the set of ops that can run in f16 and are considered numerically- 40 // safe (for execution in f16), but which may be made unsafe by an upstream 41 // denylist op. 42 virtual gtl::FlatSet<string> InferList() = 0; 43 // Returns the set of ops that are considered numerically-dangerous (i.e., 44 // unsafe for execution in f16) and whose effects may also be observed in 45 // downstream nodes (e.g. for f16, in Exp -> Add, the Add is unsafe due to 46 // the Exp). 47 virtual gtl::FlatSet<string> DenyList() = 0; 48 // Returns the set of ops that do not have numerically-significant effects 49 // (i.e., they are always considered safe for execution in f16 precision), and 50 // can run in f16. 51 virtual gtl::FlatSet<string> ClearList() = 0; 52 53 protected: 54 // Adds or removes ops from list if certain environmental variables are set. UpdateList(const string & list_name,gtl::FlatSet<string> * list)55 static void UpdateList(const string& list_name, gtl::FlatSet<string>* list) { 56 CHECK(list_name == "ALLOWLIST" || list_name == "INFERLIST" || // Crash OK. 57 list_name == "DENYLIST" || list_name == "CLEARLIST" || 58 // TODO(reedwm): for bkwds compat; remove when no longer necessary: 59 list_name == "WHITELIST" || list_name == "GRAYLIST" || 60 list_name == "BLACKLIST"); 61 string add_env_var = 62 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_ADD"; 63 string remove_env_var = 64 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_REMOVE"; 65 string to_add, to_remove; 66 TF_CHECK_OK(ReadStringFromEnvVar(add_env_var, "", &to_add)); 67 TF_CHECK_OK(ReadStringFromEnvVar(remove_env_var, "", &to_remove)); 68 for (const auto& x : str_util::Split(to_add, ",")) { 69 list->insert(x); 70 } 71 for (const auto& x : str_util::Split(to_remove, ",")) { 72 list->erase(x); 73 } 74 } 75 76 // Subclasses should include these on the ClearList. AddTensorListOps(gtl::FlatSet<string> * list)77 static void AddTensorListOps(gtl::FlatSet<string>* list) { 78 // Note: if a data structure op (such as TensorListPopBack) is added here, 79 // IsTensorListReaderOp or IsTensorListWriterOp may need to be modified 80 // LINT.IfChange 81 constexpr const char* tensor_list_ops[] = { 82 "TensorListConcat", "TensorListConcatLists", 83 "TensorListConcatV2", "TensorListGather", 84 "TensorListGetItem", "TensorListPopBack", 85 "TensorListPushBack", "TensorListPushBackBatch", 86 "TensorListFromTensor", "TensorListScatter", 87 "TensorListScatterV2", "TensorListScatterIntoExistingList", 88 "TensorListSetItem", "TensorListSplit", 89 "TensorListStack"}; 90 // LINT.ThenChange(//tensorflow/core/grappler/optimizers/auto_mixed_precision.cc) 91 for (auto op : tensor_list_ops) { 92 list->insert(op); 93 } 94 } 95 }; 96 97 class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { 98 private: IsPseudoFastMath()99 static bool IsPseudoFastMath() { 100 string optimization_level; 101 TF_CHECK_OK( 102 ReadStringFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", 103 &optimization_level)); 104 optimization_level = str_util::Uppercase(optimization_level); 105 return optimization_level == "TENSOR_CORES_ONLY"; 106 } 107 108 public: AutoMixedPrecisionListsCuda(int cuda_version,int cudnn_version)109 AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version) 110 : cuda_version_(cuda_version), cudnn_version_(cudnn_version) {} 111 AllowList()112 gtl::FlatSet<string> AllowList() override { 113 auto list = gtl::FlatSet<string>{ 114 "BlockLSTM", 115 "BlockLSTMV2", 116 "BlockLSTMGrad", 117 "BlockLSTMGradV2", 118 "Conv2D", 119 "Conv2DBackpropFilter", 120 "Conv2DBackpropInput", 121 "CudnnRNN", 122 "CudnnRNNBackprop", 123 "CudnnRNNBackpropV2", 124 "CudnnRNNBackpropV3", 125 "CudnnRNNV2", 126 "CudnnRNNV3", 127 "Einsum", 128 "FusedConv2DBiasActivation", 129 "FusedSparseConvGpuV2", 130 "GRUBlockCell", 131 "GRUBlockCellGrad", 132 "LSTMBlockCell", 133 "LSTMBlockCellGrad", 134 "MatMul", 135 "Mha", 136 "Tmlp", 137 }; 138 #if TENSORFLOW_USE_ROCM 139 if (true) { 140 #else 141 if (cuda_version_ >= 9010) { 142 // Fp16 BatchMatMul is slow before CUDA 9.1. 143 #endif 144 list.insert("BatchMatMul"); 145 list.insert("BatchMatMulV2"); 146 } 147 if (cudnn_version_ >= 7602) { 148 // Fp16 3D conv is slow before CUDNN 7.6.2. 149 list.insert("Conv3D"); 150 list.insert("Conv3DBackpropFilter"); 151 list.insert("Conv3DBackpropFilterV2"); 152 list.insert("Conv3DBackpropInput"); 153 list.insert("Conv3DBackpropInputV2"); 154 } 155 if (cudnn_version_ >= 8000) { 156 list.insert("DepthwiseConv2dNative"); 157 list.insert("DepthwiseConv2dNativeBackpropFilter"); 158 list.insert("DepthwiseConv2dNativeBackpropInput"); 159 } 160 UpdateList("ALLOWLIST", &list); 161 // For backwards compatibility, keeping the original env variable here. 162 // TODO(reedwm): This should be removed if we don't have active users. 163 UpdateList("WHITELIST", &list); 164 165 return list; 166 } 167 168 gtl::FlatSet<string> InferList() override { 169 if (IsPseudoFastMath()) { 170 return gtl::FlatSet<string>{}; 171 } 172 173 auto list = gtl::FlatSet<string>{ 174 "Add", 175 "AddN", 176 "AddV2", 177 "AvgPool", 178 "AvgPool3D", 179 "AvgPool3DGrad", 180 "AvgPoolGrad", 181 "BiasAdd", 182 "BiasAddGrad", 183 "BiasAddV1", 184 "Elu", 185 "EluGrad", 186 "Erf", 187 "Erfc", 188 "FloorDiv", 189 "FusedBatchNormV2", 190 "FusedBatchNormGradV2", 191 "FusedBatchNormV3", 192 "FusedBatchNormGradV3", 193 "_FusedBatchNormEx", 194 "Inv", 195 "LeakyRelu", 196 "LeakyReluGrad", 197 "Log", 198 "Log1p", 199 "LogSoftmax", 200 "Mul", 201 "Prod", 202 "RealDiv", 203 "Reciprocal", 204 "Selu", 205 "SeluGrad", 206 "Sigmoid", 207 "SigmoidGrad", 208 "Softmax", 209 "Softplus", 210 "SoftplusGrad", 211 "Softsign", 212 "SoftsignGrad", 213 "Sqrt", 214 "Sub", 215 "Tanh", 216 "TanhGrad", 217 }; 218 UpdateList("INFERLIST", &list); 219 // For backwards compatibility, keeping the original env variable here. 220 // TODO(reedwm): This should be removed if we don't have active users. 221 UpdateList("GRAYLIST", &list); 222 return list; 223 } 224 225 gtl::FlatSet<string> DenyList() override { 226 if (IsPseudoFastMath()) { 227 return gtl::FlatSet<string>{}; 228 } 229 230 auto list = gtl::FlatSet<string>{ 231 "Exp", 232 "Expm1", 233 "L2Loss", 234 "Mean", 235 "Pow", 236 "SaveV2", 237 "SoftmaxCrossEntropyWithLogits", 238 "SparseSoftmaxCrossEntropyWithLogits", 239 "Sum", 240 }; 241 UpdateList("DENYLIST", &list); 242 // For backwards compatibility, keeping the original env variable here. 243 // TODO(reedwm): This should be removed if we don't have active users. 244 UpdateList("BLACKLIST", &list); 245 return list; 246 } 247 248 gtl::FlatSet<string> ClearList() override { 249 if (IsPseudoFastMath()) { 250 return gtl::FlatSet<string>{}; 251 } 252 253 auto list = gtl::FlatSet<string>{ 254 "Abs", 255 "ArgMax", 256 "ArgMin", 257 "BatchToSpace", 258 "BatchToSpaceND", 259 "BroadcastTo", 260 "Ceil", 261 "CheckNumerics", 262 "ClipByValue", 263 "Concat", 264 "ConcatV2", 265 "DepthToSpace", 266 "DynamicPartition", 267 "DynamicStitch", 268 "Enter", 269 "EnsureShape", 270 "Equal", 271 "Exit", 272 "ExpandDims", 273 "Fill", 274 "Floor", 275 "Gather", 276 "GatherNd", 277 "GatherV2", 278 "Greater", 279 "GreaterEqual", 280 "Identity", 281 "IdentityN", 282 "IsFinite", 283 "IsInf", 284 "IsNan", 285 "Less", 286 "LessEqual", 287 "Max", 288 "MaxPool", 289 "MaxPool3D", 290 "MaxPool3DGrad", 291 "MaxPool3DGradGrad", 292 "MaxPoolGrad", 293 "MaxPoolGradGrad", 294 "MaxPoolGradGradV2", 295 "MaxPoolGradV2", 296 "MaxPoolV2", 297 "Maximum", 298 "Merge", 299 "Min", 300 "Minimum", 301 "MirrorPad", 302 "MirrorPadGrad", 303 "Neg", 304 "NextIteration", 305 "NotEqual", 306 "OneHot", 307 "OnesLike", 308 "Pack", 309 "Pad", 310 "PadV2", 311 "PreventGradient", 312 "Rank", 313 "Relu", 314 "Relu6", 315 "Relu6Grad", 316 "ReluGrad", 317 "Reshape", 318 "ResizeNearestNeighbor", 319 "ResizeNearestNeighborGrad", 320 "Reverse", 321 "ReverseSequence", 322 "ReverseV2", 323 "Round", 324 "Select", 325 "SelectV2", 326 "Shape", 327 "ShapeN", 328 "Sign", 329 "Size", 330 "Slice", 331 "Snapshot", 332 "SpaceToBatch", 333 "SpaceToBatchND", 334 "SpaceToDepth", 335 "Split", 336 "SplitV", 337 "Squeeze", 338 "StopGradient", 339 "StridedSlice", 340 "StridedSliceGrad", 341 "Switch", 342 "Tile", 343 "TopK", 344 "TopKV2", 345 "Transpose", 346 "Unpack", 347 "Where", 348 "ZerosLike", 349 }; 350 AddTensorListOps(&list); 351 UpdateList("CLEARLIST", &list); 352 return list; 353 } 354 355 private: 356 int cuda_version_; 357 int cudnn_version_; 358 }; 359 360 class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { 361 public: AutoMixedPrecisionListsMkl()362 AutoMixedPrecisionListsMkl() {} 363 364 // Only ops which are supported by MKL in bfloat16 should be added to the 365 // allow list, infer list, or clear list. AllowList()366 gtl::FlatSet<string> AllowList() override { 367 auto list = gtl::FlatSet<string>{"Conv2D", 368 "Conv2DBackpropFilter", 369 "Conv2DBackpropInput", 370 "Conv3D", 371 "Conv3DBackpropFilterV2", 372 "Conv3DBackpropInputV2", 373 "DepthwiseConv2dNative", 374 "DepthwiseConv2dNativeBackpropFilter", 375 "DepthwiseConv2dNativeBackpropInput", 376 "MatMul", 377 "BatchMatMul", 378 "BatchMatMulV2"}; 379 380 UpdateList("ALLOWLIST", &list); 381 // For backwards compatibility, keeping the original env variable here. 382 // TODO(reedwm): This should be removed if we don't have active users. 383 UpdateList("WHITELIST", &list); 384 return list; 385 } 386 InferList()387 gtl::FlatSet<string> InferList() override { 388 auto list = gtl::FlatSet<string>{"Add", 389 "AddN", 390 "AddV2", 391 "AvgPool", 392 "AvgPool3D", 393 "AvgPool3DGrad", 394 "AvgPoolGrad", 395 "BiasAdd", 396 "BiasAddGrad", 397 "BiasAddV1", 398 "FusedBatchNormV2", 399 "FusedBatchNormGradV2", 400 "FusedBatchNormV3", 401 "FusedBatchNormGradV3", 402 "LeakyRelu", 403 "LeakyReluGrad", 404 "Mul", 405 "Sub", 406 "Elu", 407 "EluGrad", 408 "FloorDiv", 409 "_FusedBatchNormEx", 410 "Log", 411 "Log1p", 412 "LogSoftmax", 413 "Prod", 414 "RealDiv", 415 "Reciprocal", 416 "Selu", 417 "SeluGrad", 418 "Sigmoid", 419 "SigmoidGrad", 420 "Softmax", 421 "Softplus", 422 "SoftplusGrad", 423 "Softsign", 424 "SoftsignGrad", 425 "Sqrt", 426 "Tanh", 427 "TanhGrad"}; 428 UpdateList("INFERLIST", &list); 429 // For backwards compatibility, keeping the original env variable here. 430 // TODO(reedwm): This should be removed if we don't have active users. 431 UpdateList("GRAYLIST", &list); 432 return list; 433 } 434 DenyList()435 gtl::FlatSet<string> DenyList() override { 436 auto list = gtl::FlatSet<string>{ 437 "Exp", 438 "Expm1", 439 "L2Loss", 440 "Mean", 441 "Pow", 442 "SaveV2", 443 "SoftmaxCrossEntropyWithLogits", 444 "SparseSoftmaxCrossEntropyWithLogits", 445 "Sum", 446 }; 447 UpdateList("DENYLIST", &list); 448 // For backwards compatibility, keeping the original env variable here. 449 // TODO(reedwm): This should be removed if we don't have active users. 450 UpdateList("BLACKLIST", &list); 451 return list; 452 } 453 ClearList()454 gtl::FlatSet<string> ClearList() override { 455 auto list = gtl::FlatSet<string>{ 456 "Abs", 457 "ArgMax", 458 "ArgMin", 459 "BatchToSpace", 460 "BatchToSpaceND", 461 "BroadcastTo", 462 "Ceil", 463 "CheckNumerics", 464 "ClipByValue", 465 "Concat", 466 "ConcatV2", 467 "DepthToSpace", 468 "DynamicPartition", 469 "DynamicStitch", 470 "EnsureShape", 471 "Enter", 472 "Equal", 473 "Exit", 474 "ExpandDims", 475 "Fill", 476 "Floor", 477 "Gather", 478 "GatherNd", 479 "GatherV2", 480 "Greater", 481 "GreaterEqual", 482 "Identity", 483 "IsFinite", 484 "IsInf", 485 "IsNan", 486 "Less", 487 "LessEqual", 488 "Max", 489 "Maximum", 490 "MaxPool", 491 "MaxPool3D", 492 "MaxPool3DGrad", 493 "MaxPoolGrad", 494 "MaxPoolGradGrad", 495 "MaxPoolGradGradV2", 496 "MaxPoolGradV2", 497 "MaxPoolV2", 498 "Merge", 499 "Min", 500 "Minimum", 501 "MirrorPad", 502 "MirrorPadGrad", 503 "Neg", 504 "NextIteration", 505 "NotEqual", 506 "OnesLike", 507 "Pack", 508 "Pad", 509 "PadV2", 510 "PreventGradient", 511 "Rank", 512 "Relu", 513 "Relu6", 514 "Relu6Grad", 515 "ReluGrad", 516 "Reshape", 517 "ResizeNearestNeighbor", 518 "ResizeNearestNeighborGrad", 519 "Reverse", 520 "ReverseSequence", 521 "ReverseV2", 522 "Round", 523 "Select", 524 "SelectV2", 525 "Shape", 526 "ShapeN", 527 "Sign", 528 "Slice", 529 "Snapshot", 530 "SpaceToBatch", 531 "SpaceToBatchND", 532 "SpaceToDepth", 533 "Split", 534 "SplitV", 535 "Squeeze", 536 "StopGradient", 537 "StridedSlice", 538 "StridedSliceGrad", 539 "Switch", 540 "Tile", 541 "TopK", 542 "TopKV2", 543 "Transpose", 544 "Where", 545 "Unpack", 546 "ZerosLike", 547 }; 548 AddTensorListOps(&list); 549 UpdateList("CLEARLIST", &list); 550 return list; 551 } 552 }; 553 554 } // end namespace grappler 555 } // end namespace tensorflow 556 557 #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_ 558