/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "absl/types/optional.h"
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/DenseSet.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Casting.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
25 #include "mlir/IR/Attributes.h"  // from @llvm-project
26 #include "mlir/IR/Builders.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
29 #include "mlir/IR/Dialect.h"  // from @llvm-project
30 #include "mlir/IR/Operation.h"  // from @llvm-project
31 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
32 #include "mlir/IR/UseDefLists.h"  // from @llvm-project
33 #include "mlir/IR/Value.h"  // from @llvm-project
34 #include "mlir/Pass/Pass.h"  // from @llvm-project
35 #include "mlir/Pass/PassManager.h"  // from @llvm-project
36 #include "mlir/Support/LLVM.h"  // from @llvm-project
37 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
38 #include "mlir/Transforms/Passes.h"  // from @llvm-project
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
42 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
43 #include "tensorflow/dtensor/cc/constants.h"
44 #include "tensorflow/dtensor/cc/tensor_layout.h"
45 #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
46 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
47 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
48 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
49 #include "tensorflow/dtensor/mlir/layout_parsing.h"
50 #include "tensorflow/dtensor/mlir/op_utils.h"
51 #include "tensorflow/dtensor/mlir/spmd_expander.h"
52 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
53 
namespace tensorflow {
namespace dtensor {
namespace {

constexpr char kMainFunctionName[] = "main";

// Updates the input signature of `function` so that the operand at
// `argument_index` has type `new_arg_type`.
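// For example (an illustrative sketch with assumed shapes): for a function of
// type (tensor<4x4xf32>, tensor<8xf32>) -> tensor<8xf32>, calling this with
// argument_index = 0 and new_arg_type = tensor<2x4xf32> rewrites the function
// type to (tensor<2x4xf32>, tensor<8xf32>) -> tensor<8xf32> and updates the
// corresponding block argument type to match.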
void UpdateFunctionInputShape(const int argument_index,
                              mlir::RankedTensorType new_arg_type,
                              mlir::func::FuncOp function) {
  auto func_type = function.getFunctionType();
  auto input_types = llvm::to_vector<8>(func_type.getInputs());
  input_types[argument_index] = new_arg_type;
  auto new_func_type = mlir::FunctionType::get(
      function.getContext(), input_types, func_type.getResults());
  function.setType(new_func_type);
  function.getBody()
      .getArgument(argument_index)
      .setType(function.getFunctionType().getInput(argument_index));
}

// If `op` is a TF operation, returns itself. If it is a DTensorLayout op,
// returns its consumer TF operation.
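// For example (hypothetical IR), given:
//   %0 = "tf.DTensorLayout"(%arg0) ...
//   %1 = "tf.Identity"(%0) ...
// calling this on the DTensorLayout op returns the tf.Identity op. Chains of
// DTensorLayout ops are followed until a non-layout consumer is found, and
// nullptr is returned if a layout op has no users.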
mlir::Operation* NextTFOp(mlir::Operation* op) {
  while (auto layout = llvm::dyn_cast<mlir::TF::DTensorLayout>(op)) {
    if (op->getUsers().empty()) return nullptr;
    op = *(op->getUsers().begin());
  }
  return op;
}

// Updates the shape of a resource argument if the argument has a `tf._layout`
// attribute.
// For example:
// main(%arg0: tensor<!tf_type.resource<tensor<4x4xf32>>
//                  {tf._layout = "mesh:TPU,x=2,y=2 layout:x,not_sharded"})
//
// will be converted to:
//
// main(%arg0: tensor<!tf_type.resource<tensor<2x4xf32>>
//                   {tf._layout = "mesh:TPU,x=2,y=2 layout:x,not_sharded"})
//
// Note that the resource argument type is still a resource type, but its
// subtype has been changed to reflect the local shape.
// If the resource argument has no subtype, the subtype does not have a static
// shape, or the argument has no corresponding layout attribute, this function
// is a no-op.
mlir::LogicalResult UpdateResourceArgumentType(
    const int arg_index, mlir::func::FuncOp function,
    absl::optional<mlir::RankedTensorType> new_subtype = absl::nullopt) {
  auto resource_arg = function.getArgument(arg_index);
  if (new_subtype) {
    auto new_var_type = mlir::RankedTensorType::get(
        {}, mlir::TF::ResourceType::get(
                mlir::ArrayRef<mlir::TensorType>{*new_subtype},
                function.getContext()));
    UpdateFunctionInputShape(arg_index, new_var_type, function);
    function.setArgAttr(arg_index, kAssignedResourceLocalShape,
                        ConvertTypeToTensorShapeAttr(*new_subtype));
    return mlir::success();
  }

  auto resource_type = resource_arg.getType()
                           .cast<mlir::TensorType>()
                           .getElementType()
                           .dyn_cast<mlir::TF::ResourceType>();
  if (!resource_type) return mlir::success();

  auto sub_types = resource_type.getSubtypes();
  if (sub_types.size() != 1) return mlir::success();

  auto resource_arg_sub_type = sub_types.front();
  if (!resource_arg_sub_type.hasStaticShape()) return mlir::success();

  // The local shape that is to be assigned to this resource argument type. We
  // will either pull it from the assigned local shape attribute or compute it
  // based on the layout.
  // TODO(srujun): use the attribute value only to check the computed shape.
  // This is currently blocked by an "empty_layout" set on the resource
  // arguments, meaning it is not possible to compute local layout.
  llvm::SmallVector<int64_t, 4> local_arg_shape;
  auto assigned_resource_local_shape_attr =
      function.getArgAttrOfType<mlir::TF::ShapeAttr>(
          arg_index, kAssignedResourceLocalShape);
  if (assigned_resource_local_shape_attr) {
    local_arg_shape.append(
        assigned_resource_local_shape_attr.getShape().begin(),
        assigned_resource_local_shape_attr.getShape().end());
  } else {
    auto layout_or_status = ExtractLayoutFromOperand(resource_arg);
    if (!layout_or_status.ok())
      return function.emitOpError(layout_or_status.status().error_message());

    const auto& layout = layout_or_status.ValueOrDie();
    if (!layout) return mlir::success();

    std::vector<int64_t> local_arg_shape_vec =
        layout->LocalShapeFromGlobalShape(resource_arg_sub_type.getShape());
    local_arg_shape.append(local_arg_shape_vec.begin(),
                           local_arg_shape_vec.end());
  }

  auto local_variable_subtype = mlir::RankedTensorType::get(
      local_arg_shape, resource_arg_sub_type.getElementType());
  auto new_var_type = mlir::RankedTensorType::get(
      {}, mlir::TF::ResourceType::get(
              mlir::ArrayRef<mlir::TensorType>{local_variable_subtype},
              function.getContext()));

  UpdateFunctionInputShape(arg_index, new_var_type, function);
  function.setArgAttr(
      arg_index, kAssignedResourceLocalShape,
      mlir::TF::ShapeAttr::get(local_variable_subtype.getContext(),
                               mlir::ArrayRef<int64_t>(local_arg_shape)));

  return mlir::success();
}

// Returns whether `value` is used by an AssignVariable op, skipping any
// intervening DTensorLayout ops.
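// For example (hypothetical IR), with:
//   %0 = "tf.DTensorLayout"(%arg0) ...
//   "tf.AssignVariableOp"(%arg3, %0) ...
// this returns true for %arg0 and sets
// *resource_argument_index_for_assign_variable = 3, assuming the resource
// operand forwards (through any DTensorLayout ops) to block argument %arg3.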
bool IsValueUsedByAssignVariableOp(
    mlir::Value value, int* resource_argument_index_for_assign_variable) {
  for (auto user : value.getUsers()) {
    if (auto assign_variable_op =
            llvm::dyn_cast_or_null<mlir::TF::AssignVariableOp>(
                NextTFOp(user))) {
      *resource_argument_index_for_assign_variable =
          GetForwardedDTensorLayoutInput(assign_variable_op.resource())
              .cast<mlir::BlockArgument>()
              .getArgNumber();
      return true;
    }
  }
  return false;
}

// Updates argument shapes of `function` based on `tf._layout` attributes.
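// For example (assuming a 2x2 mesh and the layout below):
//
// main(%arg0: tensor<4x4xf32>
//                  {tf._layout = "mesh:TPU,x=2,y=2 layout:x,not_sharded"})
//
// would be rewritten to:
//
// main(%arg0: tensor<2x4xf32>
//                  {tf._layout = "mesh:TPU,x=2,y=2 layout:x,not_sharded"})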
mlir::LogicalResult UpdateFunctionArgsUsingLayout(mlir::func::FuncOp function) {
  for (int argument_index = 0; argument_index < function.getNumArguments();
       ++argument_index) {
    auto arg_layout_attr = function.getArgAttrOfType<mlir::StringAttr>(
        argument_index, kCustomDeviceAttr);
    if (!arg_layout_attr) continue;

    auto arg_layout = Layout::FromString(arg_layout_attr.getValue().str());
    if (!arg_layout.ok())
      return function.emitOpError(llvm::formatv(
          "Invalid layout attribute found during SPMD expansion: {0}",
          arg_layout.status().error_message()));

    mlir::Type arg_type = mlir::getElementTypeOrSelf(
        function.getFunctionType().getInput(argument_index));

    // If the argument is a resource type, update the subtype shape information
    // to reflect the local shape of the resource.
    if (arg_type.isa<mlir::TF::ResourceType>()) {
      if (mlir::failed(UpdateResourceArgumentType(argument_index, function)))
        return mlir::failure();
      continue;
    }

    mlir::RankedTensorType ranked_type =
        function.getFunctionType()
            .getInput(argument_index)
            .dyn_cast<mlir::RankedTensorType>();
    if (!ranked_type) continue;

    // If the input value is a non-resource type, update the value to reflect
    // its local shape.
    llvm::ArrayRef<int64_t> arg_shape = ranked_type.getShape();
    const std::vector<int64_t> arg_local_shape =
        arg_layout->LocalShapeFromGlobalShape(arg_shape);
    mlir::RankedTensorType new_arg_type = mlir::RankedTensorType::get(
        arg_local_shape, ranked_type.getElementType());
    UpdateFunctionInputShape(argument_index, new_arg_type, function);

    // If a non-resource value is used by an AssignVariable op, ensure that the
    // shape of the updated/assigned resource is consistent with the local
    // shape of the assigned value.
    int assigned_resource_argument_index = -1;
    if (IsValueUsedByAssignVariableOp(function.getArgument(argument_index),
                                      &assigned_resource_argument_index)) {
      (void)UpdateResourceArgumentType(assigned_resource_argument_index,
                                       function, new_arg_type);
    }
  }
  return mlir::success();
}

// Given the SPMD-expanded `function_operands` of a call to `function`, updates
// the function signature to reflect the local shapes of `function_operands`.
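// For example (an illustrative sketch): if a callsite passes an operand of
// local type tensor<2x4xf32> to a function whose matching input type is still
// the global tensor<4x4xf32>, the function input type is rewritten to
// tensor<2x4xf32>.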
mlir::LogicalResult UpdateFunctionWithLocalInputShapes(
    mlir::MutableArrayRef<mlir::OpOperand> function_operands,
    mlir::func::FuncOp function) {
  for (auto& operand : function_operands) {
    const int index = operand.getOperandNumber();
    auto arg_type = operand.get().getType().dyn_cast<mlir::RankedTensorType>();
    if (!arg_type) continue;

    auto arg_local_shape = arg_type.getShape();
    auto new_arg_type =
        mlir::RankedTensorType::get(arg_local_shape, arg_type.getElementType());
    UpdateFunctionInputShape(index, new_arg_type, function);
  }
  return mlir::success();
}

// Updates the output shapes of the enclosing op or function containing
// `terminator_op` to local shapes.
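// For example (an illustrative sketch): if the terminator's operands now have
// the local type tensor<2x4xf32> instead of the global tensor<4x4xf32>, the
// enclosing function's result type becomes tensor<2x4xf32>, and every callsite
// of that function has its corresponding result type updated to match.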
mlir::LogicalResult UpdateReturnValueShapes(mlir::ModuleOp module,
                                            mlir::Operation* terminator_op) {
  auto parent_op = terminator_op->getBlock()->getParentOp();
  if (!parent_op) return mlir::success();

  auto output_types = llvm::to_vector<8>(terminator_op->getOperandTypes());
  if (auto function = llvm::dyn_cast<mlir::func::FuncOp>(parent_op)) {
    // Update the function output types to have local shapes.
    auto new_func_type = mlir::FunctionType::get(
        function.getContext(), function.getFunctionType().getInputs(),
        output_types);
    function.setType(new_func_type);

    // Find all callsite operations of `function` within the module.
    auto function_uses =
        mlir::SymbolTable::getSymbolUses(function, &module.getBodyRegion());
    if (!function_uses) return mlir::success();

    // Update function callsite operations to reflect local output shapes.
    for (auto function_use : *function_uses) {
      auto callsite_op = function_use.getUser();
      if (!callsite_op) continue;

      for (auto& output_type_and_index : llvm::enumerate(output_types)) {
        int index = output_type_and_index.index();
        const auto& type = output_type_and_index.value();
        callsite_op->getResult(index).setType(type);
      }
    }
  } else {
    for (auto& output_type_and_index : llvm::enumerate(output_types)) {
      int index = output_type_and_index.index();
      const auto& type = output_type_and_index.value();
      parent_op->getResult(index).setType(type);
    }
  }

  return mlir::success();
}

// Conducts SPMD expansion for all ops in `module`. If function call operations
// exist, walks the functions in topological order so that the inputs/outputs
// of each function are updated before SPMD expansion of its callsite
// operations is done.
// Note that the iteration does not support recursive function calls.
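// For example (an illustrative sketch of the intended ordering): if main calls
// @f, the operands of the call to @f are SPMD-expanded first, @f's input
// signature is updated with their local shapes, @f's body is then expanded,
// and finally, when @f's terminator is expanded, the callsite's result types
// are updated to @f's now-local return types.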
mlir::LogicalResult ConductSPMDExpansion(mlir::ModuleOp module) {
  auto main_func = module.lookupSymbol<mlir::func::FuncOp>(kMainFunctionName);
  if (!main_func)
    return module.emitOpError(
        "could not find `main` function in module for SPMD expansion.");

  if (mlir::failed(UpdateFunctionArgsUsingLayout(main_func)))
    return mlir::failure();

  TopologicalIterator iterator(main_func);
  while (iterator.hasNext()) {
    mlir::Operation* op = iterator.next();
    absl::optional<mlir::func::FuncOp> func = MaybeFindFunction(op);
    if (func.has_value()) {
      if (mlir::failed(
              UpdateFunctionWithLocalInputShapes(op->getOpOperands(), *func)))
        return mlir::failure();
    }

    const bool is_terminator_op =
        llvm::isa<mlir::func::ReturnOp, mlir::tf_device::ReturnOp>(op);
    if (auto layout_op = llvm::dyn_cast<mlir::TF::DTensorLayout>(op))
      layout_op.output().setType(layout_op.input().getType());

    mlir::Operation* expanded_op = nullptr;
    auto status = RunSPMDExpansion(op, &expanded_op);
    if (!status.ok() || expanded_op == nullptr) {
      // Sometimes `op` may have been erased and `expanded_op` set. In that
      // case, emit the error on the expanded op.
      mlir::Operation* emit_op = op;
      if (expanded_op != nullptr) emit_op = expanded_op;
      return emit_op->emitError(WithContext(status, __FILE__, __LINE__,
                                            "While computing SPMD expansion")
                                    .error_message());
    }

    // If the expanded op is the terminator of a tf_device.Cluster or a
    // function, make sure to update the function return values as well as the
    // result shapes of its callsite operations.
    if (is_terminator_op)
      if (mlir::failed(UpdateReturnValueShapes(module, expanded_op)))
        return mlir::failure();
  }
  return mlir::success();
}

// DTensorLayout ops only convey layout information about tensors, which is no
// longer needed after SPMD expansion. As such, remove all DTensorLayout ops
// from the graph.
void RemoveDTensorLayoutOps(mlir::ModuleOp module) {
  llvm::SmallVector<mlir::TF::DTensorLayout, 4> layout_ops;
  module.walk(
      [&](mlir::TF::DTensorLayout layout) { layout_ops.emplace_back(layout); });

  for (auto layout_op : layout_ops) RemoveDTensorLayoutOp(layout_op);
}

// Removes temporary attrs created during SPMD expansion.
void RemoveTemporarySPMDAttrs(mlir::ModuleOp module) {
  module.walk([&](mlir::Operation* op) {
    if (op->hasAttr(kDeviceSeedForMeshDims)) {
      op->removeAttr(kDeviceSeedForMeshDims);
    }
  });
}

// MLIR pass that converts a graph in the global view into a local view that
// can be invoked in parallel on a distributed set of devices. This pass
// removes all DTensorLayout ops after the expansion is done. Temporary nodes
// and attributes are also removed once the pass completes.
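// A minimal usage sketch (assuming an existing `mlir::MLIRContext context` and
// a parsed `module`; in practice this pass is scheduled as part of the DTensor
// pass pipeline):
//
//   mlir::PassManager pm(&context);
//   pm.addPass(CreateDTensorSPMDExpansion());
//   if (mlir::failed(pm.run(module))) { /* handle failure */ }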
struct DTensorSPMDExpansion
    : public DTensorSPMDExpansionBase<DTensorSPMDExpansion> {
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<mlir::dtensor::DTensorDialect>();
  }

  void runOnOperation() override {
    auto module = getOperation();
    if (failed(ConductSPMDExpansion(module))) return signalPassFailure();

    RemoveDTensorLayoutOps(module);

    RemoveTemporarySPMDAttrs(module);
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorSPMDExpansion() {
  return std::make_unique<DTensorSPMDExpansion>();
}

}  // namespace dtensor
}  // namespace tensorflow