/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "absl/types/optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dtensor_attributes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Creates a tf.DTensorLayout op that forwards the `input` value.
void CreateDTensorLayoutOp(const Layout& layout, mlir::Value input,
                           mlir::TensorType& type, mlir::Location loc,
                           mlir::OpBuilder* builder,
                           mlir::MLIRContext* context) {
  if (layout.IsEmpty()) return;

  auto layout_op = builder->create<mlir::TF::DTensorLayout>(
      loc, input, mlir::dtensor::LayoutAttr::get(context, layout),
      mlir::TF::ShapeAttr::get(context, type));
  llvm::SmallPtrSet<mlir::Operation*, 4> exception{layout_op};
  input.replaceAllUsesExcept(layout_op.output(), exception);
}
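
// For illustration, a minimal sketch of the rewrite above in MLIR textual
// form (the example op and the exact attribute syntax are hypothetical and
// depend on the TF/MLIR version):
//
//   %0 = "tf.Identity"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
//
// becomes
//
//   %0 = "tf.Identity"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
//   %1 = "tf.DTensorLayout"(%0) {global_shape = #tf_type.shape<4x4>,
//                                layout = #dtensor.layout<...>}
//          : (tensor<4x4xf32>) -> tensor<4x4xf32>
//
// All former uses of %0, except the new DTensorLayout op itself, are rewired
// to consume %1.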

// Adds a DTensorLayout op after each Relayout operation to ensure that the
// tensor produced by `relayout` has a fixed layout.
mlir::LogicalResult PropagateDTensorLayoutForRelayout(
    mlir::MLIRContext& c, mlir::TF::RelayoutOp relayout) {
  const std::string layout_str = relayout.layout().str();
  auto layout_or_status = Layout::FromString(layout_str);
  if (!layout_or_status.ok()) {
    return relayout.emitOpError(
        llvm::formatv("found Relayout op with incorrect/unparsable layout. "
                      "Found layout: {0}",
                      layout_str));
  }
  const Layout& layout = layout_or_status.ValueOrDie();

  // Skip adding a DTensorLayout if Relayout is 'dynamic'. Any dimension with
  // MATCH for the layout will have its layout preserved in layout propagation.
  for (const std::string& sharding_spec : layout.sharding_spec_strs())
    if (sharding_spec == Layout::kMatch) return mlir::success();

  mlir::OpBuilder builder(relayout->getBlock(),
                          ++mlir::Block::iterator(relayout));
  mlir::TensorType type = relayout.getType().dyn_cast<mlir::TensorType>();
  if (!type) return relayout.emitOpError("type required for Relayout op");

  CreateDTensorLayoutOp(layout, relayout.output(), type, relayout.getLoc(),
                        &builder, &c);
  return mlir::success();
}
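
// Sketch of the 'dynamic' case above (the serialized layout string shown is
// hypothetical): a Relayout whose layout contains the special MATCH sharding
// spec, e.g. something like
//
//   "sharding_specs:match,unsharded, mesh:..."
//
// is skipped here, so the dimension marked MATCH keeps whatever layout the
// later layout-propagation pass assigns to it.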

// Creates a tf.DTensorLayout op connected to each function argument that
// carries a layout attribute.
mlir::LogicalResult PropagateFunctionArgAttrToLayoutOp(
    mlir::MLIRContext& c, mlir::func::FuncOp function) {
  for (int arg_index = 0; arg_index < function.getNumArguments(); ++arg_index) {
    auto layout_attr = function.getArgAttrOfType<mlir::StringAttr>(
        arg_index, kCustomDeviceAttr);
    if (!layout_attr) continue;
    const auto layout_str = layout_attr.getValue().str();
    auto layout_or_status = Layout::FromString(layout_str);
    if (!layout_or_status.ok())
      return function.emitOpError(llvm::formatv(
          "function includes attribute {0} for {1}-th arg that cannot be "
          "deserialized to a correct layout format. Found attribute {2}",
          kCustomDeviceAttr, arg_index, layout_str));

    mlir::OpBuilder builder(function.getBody());
    auto arg = function.getArgument(arg_index);
    mlir::Type tensor_type = GetSubtypeOrSelf(arg);
    if (auto type = tensor_type.dyn_cast<mlir::TensorType>()) {
      CreateDTensorLayoutOp(layout_or_status.ValueOrDie(), arg, type,
                            function.getLoc(), &builder, &c);
    } else {
      return function.emitOpError()
             << "is missing tensor type for argument " << arg_index;
    }
  }

  return mlir::success();
}
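
// Illustrative sketch (the attribute key is the value of kCustomDeviceAttr,
// written symbolically here; the layout string is hypothetical):
//
//   func.func @main(
//       %arg0: tensor<8x8xf32> {<kCustomDeviceAttr> = "sharding_specs:..."})
//
// For such an argument, a tf.DTensorLayout op is created at the start of the
// function body, and all other uses of %arg0 are rewired through it.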

// Creates a tf.DTensorLayout op connected to the terminator op of the
// function if the function carries a default layout attribute that represents
// the layout of the function outputs.
mlir::LogicalResult PropagateFunctionDefaultLayoutAttrToLayoutOp(
    mlir::MLIRContext& c, mlir::func::FuncOp function) {
  for (int ret_index = 0; ret_index < function.getNumResults(); ++ret_index) {
    auto layout_attr_from_func_result =
        function.getResultAttrOfType<mlir::StringAttr>(
            ret_index, kCustomDefaultLayoutAttr);
    if (!layout_attr_from_func_result) continue;

    const std::string layout_string =
        layout_attr_from_func_result.getValue().str();
    auto result_layout_or_status = Layout::FromString(layout_string);
    if (!result_layout_or_status.ok())
      return function.emitOpError(
          llvm::formatv("function includes default layout attribute {0} for "
                        "{1}-th output that cannot be deserialized to a "
                        "correct layout format. Found attribute {2}",
                        kCustomDefaultLayoutAttr, ret_index, layout_string));

    auto function_terminator = function.getBody().front().getTerminator();
    mlir::OpBuilder builder(function_terminator);
    auto return_value = function_terminator->getOperand(ret_index);

    if (auto type = return_value.getType().dyn_cast<mlir::TensorType>())
      CreateDTensorLayoutOp(result_layout_or_status.ValueOrDie(), return_value,
                            type, function.getLoc(), &builder, &c);
    else
      return function.emitOpError()
             << "is missing tensor type for result " << ret_index;
  }

  return mlir::success();
}
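
// Analogous sketch for function results (attribute key written symbolically;
// the layout string is hypothetical):
//
//   func.func @main(...) -> (tensor<8x8xf32>
//       {<kCustomDefaultLayoutAttr> = "sharding_specs:..."})
//
// Here the tf.DTensorLayout op is inserted just before the function
// terminator so the returned value flows through it.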

// MLIR pass that materializes user-annotated layouts as tf.DTensorLayout ops
// on op results, function arguments, and function outputs.
struct DTensorPropagateDefaultLayout
    : public DTensorPropagateDefaultLayoutBase<DTensorPropagateDefaultLayout> {
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<mlir::dtensor::DTensorDialect>();
  }

  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::OpBuilder builder(&context);

    auto function = getOperation();

    auto walk_result =
        getOperation().walk([&](mlir::Operation* op) -> mlir::WalkResult {
          if (auto relayout = llvm::dyn_cast<mlir::TF::RelayoutOp>(op)) {
            (void)PropagateDTensorLayoutForRelayout(context, relayout);
            return mlir::WalkResult::advance();
          }

          // Set user annotated layout on operations.
          auto layout_or_status = ExtractLayoutFromOp(op);
          if (!layout_or_status.ok()) {
            op->emitOpError(llvm::formatv(
                "op has layout attribute {0} that cannot be deserialized.",
                kLayoutAttr));
            return mlir::WalkResult::interrupt();
          }

          mlir::OpBuilder builder(&context);
          builder.setInsertionPointAfter(op);
          const auto layouts = layout_or_status.ValueOrDie();
          for (const auto& layout_and_index : llvm::enumerate(layouts)) {
            const int index = layout_and_index.index();
            const auto& layout = layout_and_index.value();
            if (!layout || layout->IsEmpty()) continue;

            auto op_output = op->getResult(index);
            if (auto type = op_output.getType().dyn_cast<mlir::TensorType>()) {
              auto layout_op = builder.create<mlir::TF::DTensorLayout>(
                  function.getLoc(), op_output,
                  mlir::dtensor::LayoutAttr::get(&context, *layout),
                  mlir::TF::ShapeAttr::get(&context, type));
              llvm::SmallPtrSet<mlir::Operation*, 4> exception{layout_op};
              op_output.replaceAllUsesExcept(layout_op.output(), exception);
            } else {
              return op->emitOpError()
                     << "type for output " << index << " is not a TensorType";
            }
          }

          return mlir::WalkResult::advance();
        });

    if (walk_result.wasInterrupted()) return signalPassFailure();

    // Set user annotated layout on function arguments.
    if (mlir::failed(PropagateFunctionArgAttrToLayoutOp(context, function)))
      return signalPassFailure();

    // Set user annotated layout on function outputs.
    if (mlir::failed(
            PropagateFunctionDefaultLayoutAttrToLayoutOp(context, function)))
      return signalPassFailure();
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDTensorPropagateDefaultLayout() {
  return std::make_unique<DTensorPropagateDefaultLayout>();
}
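
// A minimal usage sketch, assuming a standard MLIR pass-manager setup (the
// driver code below is hypothetical; in DTensor this pass is registered as
// part of the larger DTensor pass pipeline):
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       CreateDTensorPropagateDefaultLayout());
//   if (mlir::failed(pm.run(module))) {
//     // handle pass failure
//   }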

}  // namespace dtensor
}  // namespace tensorflow