/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Attaches layouts to all returned values so that the custom device can look
// up layouts for the output handles.
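//
// Sketch of the result (illustrative only; real serialized layout strings
// include the full mesh definition): a call op with one sharded and one
// replicated result might end up carrying an attribute such as
//
//   _layout = ["sharding_specs:x,unsharded, mesh:<serialized mesh>",
//              "sharding_specs:unsharded,unsharded, mesh:<serialized mesh>"]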
mlir::LogicalResult AttachRetvalLayouts(
    mlir::OpBuilder* builder, mlir::TF::StatefulPartitionedCallOp sp_call_op) {
  // Find the FuncOp that the StatefulPartitionedCallOp is invoking.
  mlir::SymbolRefAttr sym =
      sp_call_op.getCallableForCallee().dyn_cast<mlir::SymbolRefAttr>();
  if (!sym)
    return sp_call_op.emitOpError(
        "has no symbolRef for given StatefulPartitionedCallOp");

  auto func = mlir::dyn_cast<mlir::func::FuncOp>(
      mlir::SymbolTable::lookupNearestSymbolFrom(sp_call_op, sym));
  if (!func)
    return sp_call_op.emitOpError() << "found no FuncOp for symbol " << sym;

  llvm::SmallVector<absl::optional<Layout>, 8> retvals_layouts;
  retvals_layouts.reserve(func.getNumResults());
  for (auto operand : func.front().getTerminator()->getOperands()) {
    auto result_layout_or_status = ExtractLayoutFromOperand(operand);
    if (!result_layout_or_status.ok()) {
      return func.emitOpError(
          "error while parsing result layout for function");
    }

    auto result_layout = result_layout_or_status.ValueOrDie();

    // When a function returns its arguments directly, layout information for
    // the return value of `func` may only be obtainable by looking at its
    // call-site operations. In that case, query the input layouts of the
    // function call-site operations for the layout information.
    if (!result_layout) {
      if (auto block_arg = operand.dyn_cast<mlir::BlockArgument>()) {
        auto layout_or_status = ExtractLayoutFromOperand(
            sp_call_op.getOperand(block_arg.getArgNumber()));
        if (!layout_or_status.ok())
          return func.emitOpError(
              "error while parsing result layout for function");
        result_layout = std::move(layout_or_status.ValueOrDie());
      }

      if (!result_layout)
        return func.emitOpError(
            llvm::formatv("missing result layout attribute for function. All "
                          "DTensor functions must have layouts for their "
                          "results."));
    }
    retvals_layouts.emplace_back(result_layout.value());
  }

  // Note that we set this unconditionally: retvals_layouts could be empty,
  // but that is fine; the StatefulPartitionedCallOp will simply get an empty
  // _layout. For ops without return values, all we need is a placeholder
  // layout so that no special casing is required in dtensor_device.
  SetLayoutOnOp(sp_call_op,
                absl::Span<const absl::optional<Layout>>(
                    retvals_layouts.data(), retvals_layouts.size()));

  return mlir::success();
}

// Adds an annotation to skip XLA compilation for VarHandleOp and
// DestroyResourceOp.
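//
// Illustrative callee shape that triggers the annotation (a sketch, not taken
// from this file; types and attributes abbreviated):
//
//   func.func @handle_fn() -> tensor<!tf_type.resource> {
//     %0 = "tf.VarHandleOp"() {...} : () -> tensor<!tf_type.resource>
//     func.return %0 : tensor<!tf_type.resource>
//   }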
void MaybeSkipXlaCompilation(mlir::OpBuilder* builder,
                             mlir::Operation* call_op) {
  auto function = MaybeFindFunction(call_op);
  const auto& body_ops = function->getBody().front().without_terminator();
  // VarHandleOp and DestroyResourceOp run in op-by-op mode, so there is only
  // one op in the function body.
  if (std::distance(std::begin(body_ops), std::end(body_ops)) == 1 &&
      llvm::isa<mlir::TF::VarHandleOp, mlir::TF::DestroyResourceOp>(
          body_ops.front())) {
    call_op->setAttr(kSkipXlaCompilation, builder->getBoolAttr(true));
  }
}

mlir::LogicalResult ReplaceClusterWithPartitionCallOp(
    mlir::OpBuilder* builder, mlir::tf_device::ClusterFuncOp cluster_func) {
  auto mesh_attr = cluster_func->getAttrOfType<mlir::StringAttr>(kMeshAttr);
  if (!mesh_attr)
    return cluster_func.emitOpError()
           << "requires " << llvm::StringRef(kMeshAttr) << " attribute";

  llvm::SmallVector<mlir::Type, 8> output_types{
      cluster_func.getResultTypes().begin(),
      cluster_func.getResultTypes().end()};

  auto function_name = cluster_func.funcAttr();

  builder->setInsertionPoint(cluster_func);
  auto call_op = builder->create<mlir::TF::StatefulPartitionedCallOp>(
      cluster_func.getLoc(), output_types, cluster_func.getOperands(),
      function_name, mesh_attr, /*config_proto=*/builder->getStringAttr(""),
      /*executor_type=*/builder->getStringAttr(""));

  MaybeSkipXlaCompilation(builder, call_op);

  if (mlir::failed(ValidateMetadataAttributes(cluster_func)))
    return mlir::failure();

  // All attributes beginning with `_` have already been validated; copy them
  // over to the new call op.
  mlir::TF::CopyUnderscoredAttributes(cluster_func, call_op);

  cluster_func.replaceAllUsesWith(call_op.getResults());
  cluster_func.erase();

  return AttachRetvalLayouts(builder, call_op);
}

// MLIR pass that converts tf_device.cluster_func ops to TF
// StatefulPartitionedCall ops, with the device mesh recorded in the `config`
// attribute.
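//
// For example (a sketch; the mesh string, function name, and types are
// abbreviated):
//
//   %0 = "tf_device.cluster_func"(%arg0) {func = @fn, _mesh = "<mesh>"}
//            : (tensor<i32>) -> tensor<i32>
//
// is rewritten to:
//
//   %0 = "tf.StatefulPartitionedCall"(%arg0)
//            {f = @fn, config = "<mesh>", config_proto = "",
//             executor_type = ""} : (tensor<i32>) -> tensor<i32>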
struct DTensorClusterFunctionConversion
    : public DTensorClusterFunctionConversionBase<
          DTensorClusterFunctionConversion> {
  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();

    // Find all tf_device.ClusterFunc ops and visit them in post order. This
    // order guarantees that ops in a function definition are visited before
    // the function's call-site operations. When the Python graph includes
    // tf.functions, this leads to nested tf_device.ClusterFunc ops. As we
    // infer the layouts of function call operations from the layouts attached
    // to return values in the function definition, ClusterFunc ops in
    // nested/inner functions must be visited before ClusterFunc ops in outer
    // functions.
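    //
    // Illustrative nesting (an assumed typical program, not taken from this
    // file): an outer tf_device.cluster_func calls @outer_fn, whose body
    // contains another tf_device.cluster_func calling @inner_fn. The reversed
    // walk below converts the @inner_fn cluster first, so its call op already
    // carries a _layout attribute when the outer cluster is converted.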
    llvm::SmallVector<mlir::tf_device::ClusterFuncOp, 8> clusters;
    getOperation().walk([&](mlir::tf_device::ClusterFuncOp cluster_func) {
      clusters.emplace_back(cluster_func);
    });

    mlir::OpBuilder op_builder(&context);
    for (auto cluster_func : llvm::reverse(clusters)) {
      if (mlir::failed(
              ReplaceClusterWithPartitionCallOp(&op_builder, cluster_func))) {
        return signalPassFailure();
      }
    }
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorClusterFunctionConversion() {
  return std::make_unique<DTensorClusterFunctionConversion>();
}

}  // namespace dtensor
}  // namespace tensorflow