xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_CLUSTERING_H_
17 #define TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_CLUSTERING_H_
18 
19 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
20 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
21 #include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h"
22 
23 namespace tensorflow {
24 
// Temporary control flag used to roll out compilation gradually, enabling
// operations as confidence in their correctness and performance grows. For
// example, Tier 1 operations are simple enough and well tested, so they can be
// safely enabled for all models. New tiers will be introduced based on the
// completeness of lowering and testing, and eventually this flag will be
// removed.
//
// The enumerators form a bit mask: each operation category owns a single bit,
// and every tier is the bitwise union of the categories it enables.
enum class JitRtClusteringTier : uint8_t {
  // ---- Individual operation categories (one bit each). ----
  kCwise = 1 << 0,
  kTranspose = 1 << 1,
  kMetadata = 1 << 2,       // shape, reshape, ...
  kReductions = 1 << 3,     // all, any, min, max, mean, prod, sum
  kGatherScatter = 1 << 4,  // gather, scatter, gather_v2,...

  // ---- Tiers (unions of the categories above). ----

  // Tier 0: cwise operations only (unary, binary, ternary).
  kTier0 = kCwise,

  // Tier 1: every cwise operation (unary, binary, ternary) and tf.Transpose.
  kTier1 = kCwise | kTranspose,

  // Tier 1 extended with metadata operations (shape, reshape).
  kTier1Metadata = kTier1 | kMetadata,

  // Tier 1 extended with reductions.
  kTier1Reductions = kTier1 | kReductions,

  // TODO(ezhulenev): Include metadata (shape, reshape) and slicing into tier 2?
  // TODO(ezhulenev): Include reductions into tier 3?

  // Every operation that has a clustering policy (all bits set).
  kAll = 0xFF
};
55 
56 // Adds policies for clustering operations for TF->JitRt JIT compilation.
57 void populateTfJitRtClusteringPolicies(
58     mlir::TFDevice::ClusteringPolicySet& policies,
59     JitRtClusteringTier tier = JitRtClusteringTier::kAll);
60 
61 // Adds policies for propagating constraints through Tensorflow operations. We
62 // do not add `tf.Const` operations to the clusters, however before compilation
63 // we sink some of them into the cluster body, and to properly verify compiled
64 // function body and infer operands constraints we need a policy for constants.
65 void populateTfJitRtConstraintsPolicies(
66     mlir::TFDevice::ClusteringPolicySet& policies,
67     JitRtClusteringTier tier = JitRtClusteringTier::kAll);
68 
69 // Returns success if constant value can be sunk into the compiled function. We
70 // currently only support small integer constants that typically correspond to
71 // the reduction dimension, transpose permutation and other similar values that
72 // are required for successful compilation.
73 //
74 // We prefer to keep large constants as `tf.Const` operations outside of the
75 // compiled regions, and rely on the runtime to instantiate them as tensors.
76 mlir::LogicalResult IsCompilableConstant(mlir::ElementsAttr value);
77 
78 // Verifies that discovered operations cluster satisfies TF->JitRt JIT
79 // compilation constraints.
80 mlir::LogicalResult VerifyCluster(const mlir::TFDevice::Cluster& cluster);
81 
82 }  // namespace tensorflow
83 
84 #endif  // TENSORFLOW_COMPILER_MLIR_TFRT_JIT_TF_JITRT_CLUSTERING_H_
85