1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_CLUSTERING_H_ 17 #define TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_CLUSTERING_H_ 18 19 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project 20 #include "mlir/Support/LogicalResult.h" // from @llvm-project 21 #include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h" 22 23 namespace tensorflow { 24 25 // This is a temporary control flag to gradually enable compilation for 26 // operations based on the correctness and performance confidence. For example 27 // Tier 1 operations are simple enough and well tested, so they can be safely 28 // enabled for all models. We'll be introducing new tiers based on the 29 // completeness of lowering and testing, and eventually will remove this flag. 30 enum class JitRtClusteringTier : uint8_t { 31 kCwise = 0x1, 32 kTranspose = 0x2, 33 kMetadata = 0x4, // shape, reshape, ... 34 kReductions = 0x8, // all, any, min, max, mean, prod, sum 35 kGatherScatter = 0x10, // gather, scatter, gather_v2,... 36 37 // Only cwise operations (unary, binary, ternary). 38 kTier0 = kCwise, 39 40 // All cwise operations (unary, binary, ternary) plus a tf.Transpose. 41 kTier1 = kCwise | kTranspose, 42 43 // All tier 1 operations plus metadata operations (shape, reshape). 44 kTier1Metadata = kTier1 | kMetadata, 45 46 // All tier 1 operations plus reductions. 47 kTier1Reductions = kTier1 | kReductions, 48 49 // TODO(ezhulenev): Include metadata (shape, reshape) and slicing into tier 2? 50 // TODO(ezhulenev): Include reductions into tier 3? 51 52 // All operations that do have clustering policy. 53 kAll = 0xff 54 }; 55 56 // Adds policies for clustering operations for TF->JitRt JIT compilation. 57 void populateTfJitRtClusteringPolicies( 58 mlir::TFDevice::ClusteringPolicySet& policies, 59 JitRtClusteringTier tier = JitRtClusteringTier::kAll); 60 61 // Adds policies for propagating constraints through Tensorflow operations. We 62 // do not add `tf.Const` operations to the clusters, however before compilation 63 // we sink some of them into the cluster body, and to properly verify compiled 64 // function body and infer operands constraints we need a policy for constants. 65 void populateTfJitRtConstraintsPolicies( 66 mlir::TFDevice::ClusteringPolicySet& policies, 67 JitRtClusteringTier tier = JitRtClusteringTier::kAll); 68 69 // Returns success if constant value can be sunk into the compiled function. We 70 // currently only support small integer constants that typically correspond to 71 // the reduction dimension, transpose permutation and other similar values that 72 // are required for successful compilation. 73 // 74 // We prefer to keep large constants as `tf.Const` operations outside of the 75 // compiled regions, and rely on the runtime to instantiate them as tensors. 76 mlir::LogicalResult IsCompilableConstant(mlir::ElementsAttr value); 77 78 // Verifies that discovered operations cluster satisfies TF->JitRt JIT 79 // compilation constraints. 80 mlir::LogicalResult VerifyCluster(const mlir::TFDevice::Cluster& cluster); 81 82 } // namespace tensorflow 83 84 #endif // TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_CLUSTERING_H_ 85