xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/cluster_function_conversion.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <utility>
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/Builders.h"  // from @llvm-project
24 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
25 #include "mlir/IR/Operation.h"  // from @llvm-project
26 #include "mlir/IR/Value.h"  // from @llvm-project
27 #include "mlir/Pass/Pass.h"  // from @llvm-project
28 #include "mlir/Pass/PassManager.h"  // from @llvm-project
29 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
30 #include "mlir/Transforms/Passes.h"  // from @llvm-project
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
33 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
34 #include "tensorflow/dtensor/cc/constants.h"
35 #include "tensorflow/dtensor/cc/tensor_layout.h"
36 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
37 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
38 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
39 #include "tensorflow/dtensor/mlir/layout_parsing.h"
40 #include "tensorflow/dtensor/mlir/op_utils.h"
41 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
42 
43 namespace tensorflow {
44 namespace dtensor {
45 namespace {
46 
47 // Attach layouts for all the returned values so that custom device could get
48 // layouts for the handles.
AttachRetvalLayouts(mlir::OpBuilder * builder,mlir::TF::StatefulPartitionedCallOp sp_call_op)49 mlir::LogicalResult AttachRetvalLayouts(
50     mlir::OpBuilder* builder, mlir::TF::StatefulPartitionedCallOp sp_call_op) {
51   // Find the FuncOp that the StatefulPartitionedCallOp is invoking.
52   mlir::SymbolRefAttr sym =
53       sp_call_op.getCallableForCallee().dyn_cast<mlir::SymbolRefAttr>();
54   if (!sym)
55     return sp_call_op.emitOpError(
56         "has no symbolRef for given StatefulPartitionedCallOp");
57 
58   auto func = mlir::dyn_cast<mlir::func::FuncOp>(
59       mlir::SymbolTable::lookupNearestSymbolFrom(sp_call_op, sym));
60   if (!func)
61     return sp_call_op.emitOpError() << "found no FuncOp for symbol " << sym;
62 
63   llvm::SmallVector<absl::optional<Layout>, 8> retvals_layouts;
64   retvals_layouts.reserve(func.getNumResults());
65   for (auto operand : func.front().getTerminator()->getOperands()) {
66     auto result_layout_or_status = ExtractLayoutFromOperand(operand);
67     if (!result_layout_or_status.ok()) {
68       return func.emitOpError("error while parsing result layout for function");
69     }
70 
71     auto result_layout = result_layout_or_status.ValueOrDie();
72 
73     // When function returns its arguments directly, layout information for the
74     // return value of `func` may be only obtainable by looking at it's callsite
75     // operations. In that case, query the input layouts for function callsite
76     // operations for layout information.
77     if (!result_layout) {
78       if (auto block_arg = operand.dyn_cast<mlir::BlockArgument>()) {
79         auto layout_or_status = ExtractLayoutFromOperand(
80             sp_call_op.getOperand(block_arg.getArgNumber()));
81         if (!layout_or_status.ok())
82           return func.emitOpError(
83               "error while parsing result layout for function");
84         result_layout = std::move(layout_or_status.ValueOrDie());
85       }
86 
87       if (!result_layout)
88         return func.emitOpError(
89             llvm::formatv("missing result layout attribute for function. All "
90                           "DTensor functions "
91                           "must have layouts for its results."));
92     }
93     retvals_layouts.emplace_back(result_layout.value());
94   }
95 
96   // Note that we set this unconditionally - retvals_layout could be empty, but
97   // that is fine and we will have an empty _layout for the
98   // StatefulPartitionedCallOp. This is fine as for op without return values,
99   // all we need is a placeholder layout so that no special case is needed in
100   // dtensor_device.
101   SetLayoutOnOp(sp_call_op,
102                 absl::Span<const absl::optional<Layout>>(
103                     retvals_layouts.data(), retvals_layouts.size()));
104 
105   return mlir::success();
106 }
107 
108 // Add an anotation to skip xla compilation for VarHandleOp and
109 // DestroyResourceOp.
MaybeSkipXlaCompilation(mlir::OpBuilder * builder,mlir::Operation * call_op)110 void MaybeSkipXlaCompilation(mlir::OpBuilder* builder,
111                              mlir::Operation* call_op) {
112   auto function = MaybeFindFunction(call_op);
113   const auto& body_ops = function->getBody().front().without_terminator();
114   // VarHandleOp and DestroyResourceOp run on op-by-op mode, so there is only
115   // one op in the function body.
116   if (std::distance(std::begin(body_ops), std::end(body_ops)) == 1 &&
117       llvm::isa<mlir::TF::VarHandleOp, mlir::TF::DestroyResourceOp>(
118           body_ops.begin())) {
119     call_op->setAttr(kSkipXlaCompilation, builder->getBoolAttr(true));
120   }
121 }
122 
ReplaceClusterWithPartitionCallOp(mlir::OpBuilder * builder,mlir::tf_device::ClusterFuncOp cluster_func)123 mlir::LogicalResult ReplaceClusterWithPartitionCallOp(
124     mlir::OpBuilder* builder, mlir::tf_device::ClusterFuncOp cluster_func) {
125   auto mesh_attr = cluster_func->getAttrOfType<mlir::StringAttr>(kMeshAttr);
126   if (!mesh_attr)
127     return cluster_func.emitOpError()
128            << "requires " << llvm::StringRef(kMeshAttr) << " attribute";
129 
130   llvm::SmallVector<mlir::Type, 8> output_types{
131       cluster_func.getResultTypes().begin(),
132       cluster_func.getResultTypes().end()};
133 
134   auto function_name = cluster_func.funcAttr();
135 
136   builder->setInsertionPoint(cluster_func);
137   auto call_op = builder->create<mlir::TF::StatefulPartitionedCallOp>(
138       cluster_func.getLoc(), output_types, cluster_func.getOperands(),
139       function_name, mesh_attr, /*config_proto=*/builder->getStringAttr(""),
140       /*executor_type=*/builder->getStringAttr(""));
141 
142   MaybeSkipXlaCompilation(builder, call_op);
143 
144   if (mlir::failed(ValidateMetadataAttributes(cluster_func)))
145     return mlir::failure();
146 
147   // All attributes beginning with `_` is validate, perform copy.
148   mlir::TF::CopyUnderscoredAttributes(cluster_func, call_op);
149 
150   cluster_func.replaceAllUsesWith(call_op.getResults());
151   cluster_func.erase();
152 
153   return AttachRetvalLayouts(builder, call_op);
154 }
155 
156 // MLIR pass that converts tf_device.cluster_func to TF partitioned call
157 // op with device mesh config added to `config` attribute.
158 struct DTensorClusterFunctionConversion
159     : public DTensorClusterFunctionConversionBase<
160           DTensorClusterFunctionConversion> {
runOnOperationtensorflow::dtensor::__anon6078223c0111::DTensorClusterFunctionConversion161   void runOnOperation() override {
162     mlir::MLIRContext& context = getContext();
163 
164     // Find all tf_device.ClusterFunc ops and visit them in post order. This
165     // order guarantees that ops in function definition is visited before
166     // function call site operations. When python graph includes tf.functions
167     // this leads to nested tf_device.ClusterFunc ops. As we infer the layout
168     // of function call operations with layout attached to return values in the
169     // function definition, ClusterFunc op in nested/inner functions must be
170     // visited before ClusterFunc op in outer functions.
171     llvm::SmallVector<mlir::tf_device::ClusterFuncOp, 8> clusters;
172     getOperation().walk([&](mlir::tf_device::ClusterFuncOp cluster_func) {
173       clusters.emplace_back(cluster_func);
174     });
175 
176     mlir::OpBuilder op_builder(&context);
177     for (auto cluster_func : llvm::reverse(clusters)) {
178       if (mlir::failed(
179               ReplaceClusterWithPartitionCallOp(&op_builder, cluster_func))) {
180         return signalPassFailure();
181       }
182     }
183   };
184 };
185 
186 }  // namespace
187 
188 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorClusterFunctionConversion()189 CreateDTensorClusterFunctionConversion() {
190   return std::make_unique<DTensorClusterFunctionConversion>();
191 }
192 
193 }  // namespace dtensor
194 }  // namespace tensorflow
195