/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "absl/types/optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dtensor_attributes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Creates tf.DTensorLayout op that forwards `input` value.
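// For example (a sketch only; op and attribute syntax abbreviated),
// forwarding a value %v through a layout op looks like:
//
//   %v = "tf.SomeOp"() : () -> tensor<4x4xf32>
//   %l = "tf.DTensorLayout"(%v) {global_shape = ..., layout = ...}
//
// where every former use of %v (other than the layout op itself) now reads
// %l instead.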
void CreateDTensorLayoutOp(const Layout& layout, mlir::Value input,
                           mlir::TensorType& type, mlir::Location loc,
                           mlir::OpBuilder* builder,
                           mlir::MLIRContext* context) {
  if (layout.IsEmpty()) return;

  auto layout_op = builder->create<mlir::TF::DTensorLayout>(
      loc, input, mlir::dtensor::LayoutAttr::get(context, layout),
      mlir::TF::ShapeAttr::get(context, type));
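  // Reroute all existing uses of `input` through the new layout op, excluding
  // the layout op itself so that it still consumes the original value.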
  llvm::SmallPtrSet<mlir::Operation*, 4> exception{layout_op};
  input.replaceAllUsesExcept(layout_op.output(), exception);
}

// Adds a DTensorLayout op after each Relayout operation to ensure that the
// tensor produced by `relayout` has a fixed layout.
mlir::LogicalResult PropagateDTensorLayoutForRelayout(
    mlir::MLIRContext& c, mlir::TF::RelayoutOp relayout) {
  const std::string layout_str = relayout.layout().str();
  auto layout_or_status = Layout::FromString(layout_str);
  if (!layout_or_status.ok()) {
    return relayout.emitOpError(
        llvm::formatv("found Relayout op with incorrect/unparsable layout. "
                      "Found layout: {0}",
                      layout_str));
  }
  const Layout& layout = layout_or_status.ValueOrDie();

  // Skip adding a DTensorLayout if Relayout is 'dynamic'. Any dimension with
  // MATCH for the layout will have its layout preserved in layout propagation.
  for (const std::string& sharding_spec : layout.sharding_spec_strs())
    if (sharding_spec == Layout::kMatch) return mlir::success();

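  // Insert the DTensorLayout op immediately after the Relayout op within the
  // same block so it can take over the Relayout op's uses.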
  mlir::OpBuilder builder(relayout->getBlock(),
                          ++mlir::Block::iterator(relayout));
  mlir::TensorType type = relayout.getType().dyn_cast<mlir::TensorType>();
  if (!type) return relayout.emitOpError("type required for Relayout op");

  CreateDTensorLayoutOp(layout, relayout.output(), type, relayout.getLoc(),
                        &builder, &c);
  return mlir::success();
}

// Creates a tf.DTensorLayout op connected to each function argument that
// carries a layout attribute.
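// For example (a sketch; the actual attribute key is kCustomDeviceAttr from
// tensorflow/dtensor/cc/constants.h), an argument annotated as
//
//   func.func @f(%arg0: tensor<8xf32> {<layout attr> = "<serialized layout>"})
//
// gets a tf.DTensorLayout op at the top of the function body, and all other
// uses of %arg0 are rerouted through it.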
mlir::LogicalResult PropagateFunctionArgAttrToLayoutOp(
    mlir::MLIRContext& c, mlir::func::FuncOp function) {
  for (int arg_index = 0; arg_index < function.getNumArguments(); ++arg_index) {
    auto layout_attr = function.getArgAttrOfType<mlir::StringAttr>(
        arg_index, kCustomDeviceAttr);
    if (!layout_attr) continue;
    const auto layout_str = layout_attr.getValue().str();
    auto layout_or_status = Layout::FromString(layout_str);
    if (!layout_or_status.ok())
      return function.emitOpError(llvm::formatv(
          "function includes attribute {0} for {1}-th arg that cannot be "
          "deserialized into a valid layout. Found attribute {2}",
          kCustomDeviceAttr, arg_index, layout_str));

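    // Insert at the start of the function body so the layout op dominates
    // every other use of the argument.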
    mlir::OpBuilder builder(function.getBody());
    auto arg = function.getArgument(arg_index);
    mlir::Type tensor_type = GetSubtypeOrSelf(arg);
    if (auto type = tensor_type.dyn_cast<mlir::TensorType>()) {
      CreateDTensorLayoutOp(layout_or_status.ValueOrDie(), arg, type,
                            function.getLoc(), &builder, &c);
    } else {
      return function.emitOpError()
             << "is missing tensor type for argument " << arg_index;
    }
  }

  return mlir::success();
}

// Creates a tf.DTensorLayout op connected to the terminator of `function` if
// the function has a default layout attribute representing the layout of its
// outputs.
mlir::LogicalResult PropagateFunctionDefaultLayoutAttrToLayoutOp(
    mlir::MLIRContext& c, mlir::func::FuncOp function) {
  for (int ret_index = 0; ret_index < function.getNumResults(); ++ret_index) {
    auto layout_attr_from_func_result =
        function.getResultAttrOfType<mlir::StringAttr>(
            ret_index, kCustomDefaultLayoutAttr);
    if (!layout_attr_from_func_result) continue;

    const std::string layout_string =
        layout_attr_from_func_result.getValue().str();
    auto result_layout_or_status = Layout::FromString(layout_string);
    if (!result_layout_or_status.ok())
      return function.emitOpError(
          llvm::formatv("function includes default layout attribute {0} for "
                        "{1}-th output that cannot be deserialized into a "
                        "valid layout. Found attribute {2}",
                        kCustomDefaultLayoutAttr, ret_index, layout_string));

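    // Insert directly before the function terminator so the layout op sits
    // between the returned value and the return itself.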
    auto function_terminator = function.getBody().front().getTerminator();
    mlir::OpBuilder builder(function_terminator);
    auto return_value = function_terminator->getOperand(ret_index);

    if (auto type = return_value.getType().dyn_cast<mlir::TensorType>())
      CreateDTensorLayoutOp(result_layout_or_status.ValueOrDie(), return_value,
                            type, function.getLoc(), &builder, &c);
    else
      return function.emitOpError()
             << "is missing tensor type for result " << ret_index;
  }

  return mlir::success();
}

// MLIR pass that inserts tf.DTensorLayout ops for user-annotated layouts on
// operations, function arguments, and function results.
struct DTensorPropagateDefaultLayout
    : public DTensorPropagateDefaultLayoutBase<DTensorPropagateDefaultLayout> {
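  // This pass creates tf.DTensorLayout ops carrying DTensor layout
  // attributes, so the DTensor dialect must be registered as a dependent
  // dialect.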
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<mlir::dtensor::DTensorDialect>();
  }

  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    auto function = getOperation();

    auto walk_result =
        getOperation().walk([&](mlir::Operation* op) -> mlir::WalkResult {
          if (auto relayout = llvm::dyn_cast<mlir::TF::RelayoutOp>(op)) {
            (void)PropagateDTensorLayoutForRelayout(context, relayout);
            return mlir::WalkResult::advance();
          }

          // Set user-annotated layouts on operations.
          auto layout_or_status = ExtractLayoutFromOp(op);
          if (!layout_or_status.ok()) {
            op->emitOpError(llvm::formatv(
                "op has layout attribute {0} that cannot be deserialized.",
                kLayoutAttr));
            return mlir::WalkResult::interrupt();
          }

          mlir::OpBuilder builder(&context);
          builder.setInsertionPointAfter(op);
          const auto layouts = layout_or_status.ValueOrDie();
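          // `layouts` holds one optional layout per op result; results
          // without a user-specified (non-empty) layout are skipped.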
          for (const auto& layout_and_index : llvm::enumerate(layouts)) {
            const int index = layout_and_index.index();
            const auto& layout = layout_and_index.value();
            if (!layout || layout->IsEmpty()) continue;

            auto op_output = op->getResult(index);
            if (auto type = op_output.getType().dyn_cast<mlir::TensorType>()) {
              auto layout_op = builder.create<mlir::TF::DTensorLayout>(
                  function.getLoc(), op_output,
                  mlir::dtensor::LayoutAttr::get(&context, *layout),
                  mlir::TF::ShapeAttr::get(&context, type));
              llvm::SmallPtrSet<mlir::Operation*, 4> exception{layout_op};
              op_output.replaceAllUsesExcept(layout_op.output(), exception);
            } else {
              return op->emitOpError()
                     << "type for output " << index << " is not a TensorType";
            }
          }

          return mlir::WalkResult::advance();
        });

    if (walk_result.wasInterrupted()) return signalPassFailure();

    // Set user-annotated layouts on function arguments.
    if (mlir::failed(PropagateFunctionArgAttrToLayoutOp(context, function)))
      return signalPassFailure();

    // Set user-annotated layouts on function outputs.
    if (mlir::failed(
            PropagateFunctionDefaultLayoutAttrToLayoutOp(context, function)))
      return signalPassFailure();
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDTensorPropagateDefaultLayout() {
  return std::make_unique<DTensorPropagateDefaultLayout>();
}

}  // namespace dtensor
}  // namespace tensorflow