1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "absl/types/optional.h"
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/ADT/DenseSet.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/Support/Casting.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/FormatVariadic.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
25 #include "mlir/IR/Attributes.h" // from @llvm-project
26 #include "mlir/IR/Builders.h" // from @llvm-project
27 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
28 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
29 #include "mlir/IR/Dialect.h" // from @llvm-project
30 #include "mlir/IR/Operation.h" // from @llvm-project
31 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
32 #include "mlir/IR/UseDefLists.h" // from @llvm-project
33 #include "mlir/IR/Value.h" // from @llvm-project
34 #include "mlir/Pass/Pass.h" // from @llvm-project
35 #include "mlir/Pass/PassManager.h" // from @llvm-project
36 #include "mlir/Support/LLVM.h" // from @llvm-project
37 #include "mlir/Support/LogicalResult.h" // from @llvm-project
38 #include "mlir/Transforms/Passes.h" // from @llvm-project
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
41 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
42 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
43 #include "tensorflow/dtensor/cc/constants.h"
44 #include "tensorflow/dtensor/cc/tensor_layout.h"
45 #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
46 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
47 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
48 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
49 #include "tensorflow/dtensor/mlir/layout_parsing.h"
50 #include "tensorflow/dtensor/mlir/op_utils.h"
51 #include "tensorflow/dtensor/mlir/spmd_expander.h"
52 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
53
54 namespace tensorflow {
55 namespace dtensor {
56 namespace {
57
// Symbol name of the module's entry-point function looked up before SPMD
// expansion begins.
constexpr char kMainFunctionName[] = "main";
59
60 // Updates `function` input signature operand at `argument_index` with
61 // `new_shape`.
UpdateFunctionInputShape(const int argument_index,mlir::RankedTensorType new_arg_type,mlir::func::FuncOp function)62 void UpdateFunctionInputShape(const int argument_index,
63 mlir::RankedTensorType new_arg_type,
64 mlir::func::FuncOp function) {
65 auto func_type = function.getFunctionType();
66 auto input_types = llvm::to_vector<8>(func_type.getInputs());
67 input_types[argument_index] = new_arg_type;
68 auto new_func_type = mlir::FunctionType::get(
69 function.getContext(), input_types, func_type.getResults());
70 function.setType(new_func_type);
71 function.getBody()
72 .getArgument(argument_index)
73 .setType(function.getFunctionType().getInput(argument_index));
74 }
75
76 // If `op` is a TF operation, return itself. If it is an DTensorLayout op,
77 // return it's consumer TF operation.
NextTFOp(mlir::Operation * op)78 mlir::Operation* NextTFOp(mlir::Operation* op) {
79 while (auto layout = llvm::dyn_cast<mlir::TF::DTensorLayout>(op)) {
80 if (op->getUsers().empty()) return nullptr;
81 op = *(op->getUsers().begin());
82 }
83 return op;
84 }
85
86 // Updates the shape of resource argument if argument has `tf._layout`
87 // attribute.
88 // For example:
89 // main(%arg0: tensor<!tf_type.resource<tensor<4x4xf32>>
90 // {tf._layout = "mesh:TPU,x=2,y=2 layout:x,not_sharded"})
91 //
92 // will be converted to:
93 //
94 // main(%arg0: tensor<!tf_type.resource<tensor<2x4xf32>>
95 // {tf._layout = "mesh:TPU,x=2,y=2 layout:x,not_sharded"})
96 //
97 // Note that resource argument type is still a resource type. But it's subtype
98 // has been changed to reflect local shape.
99 // If resource argument does not have subtype or subtype does not have static
100 // shapes or if resource argument does not have corresponding layout attribute,
101 // this function is an no-op.
mlir::LogicalResult UpdateResourceArgumentType(
    const int arg_index, mlir::func::FuncOp function,
    absl::optional<mlir::RankedTensorType> new_subtype = absl::nullopt) {
  auto resource_arg = function.getArgument(arg_index);
  // Fast path: the caller already computed the local subtype. Wrap it in a
  // fresh scalar resource type, rewrite the signature, and record the local
  // shape as an argument attribute so later passes can read it back.
  if (new_subtype) {
    auto new_var_type = mlir::RankedTensorType::get(
        {}, mlir::TF::ResourceType::get(
                mlir::ArrayRef<mlir::TensorType>{*new_subtype},
                function.getContext()));
    UpdateFunctionInputShape(arg_index, new_var_type, function);
    function.setArgAttr(arg_index, kAssignedResourceLocalShape,
                        ConvertTypeToTensorShapeAttr(*new_subtype));
    return mlir::success();
  }

  // Otherwise this is a no-op unless the argument is a resource with exactly
  // one statically shaped subtype (see function comment).
  auto resource_type = resource_arg.getType()
                           .cast<mlir::TensorType>()
                           .getElementType()
                           .dyn_cast<mlir::TF::ResourceType>();
  if (!resource_type) return mlir::success();

  auto sub_types = resource_type.getSubtypes();
  if (sub_types.size() != 1) return mlir::success();

  auto resource_arg_sub_type = sub_types.front();
  if (!resource_arg_sub_type.hasStaticShape()) return mlir::success();

  // The local shape that is to be assigned to this resource argument type. We
  // will either pull it from the assigned local shape attribute or compute it
  // based on the layout.
  // TODO(srujun): use the attribute value only to check the computed shape.
  // This is currently blocked by an "empty_layout" set on the resource
  // arguments, meaning it is not possible to compute local layout.
  llvm::SmallVector<int64_t, 4> local_arg_shape;
  auto assigned_resource_local_shape_attr =
      function.getArgAttrOfType<mlir::TF::ShapeAttr>(
          arg_index, kAssignedResourceLocalShape);
  if (assigned_resource_local_shape_attr) {
    // A local shape was recorded earlier (e.g. by the fast path above on a
    // previous call); reuse it verbatim.
    local_arg_shape.append(
        assigned_resource_local_shape_attr.getShape().begin(),
        assigned_resource_local_shape_attr.getShape().end());
  } else {
    // Derive the local shape from the argument's layout. A missing layout
    // leaves the argument untouched; a malformed one is a hard error.
    auto layout_or_status = ExtractLayoutFromOperand(resource_arg);
    if (!layout_or_status.ok())
      return function.emitOpError(layout_or_status.status().error_message());

    const auto& layout = layout_or_status.ValueOrDie();
    if (!layout) return mlir::success();

    std::vector<int64_t> local_arg_shape_vec =
        layout->LocalShapeFromGlobalShape(resource_arg_sub_type.getShape());
    local_arg_shape.append(local_arg_shape_vec.begin(),
                           local_arg_shape_vec.end());
  }

  // Build the new resource type around the local subtype and persist the
  // computed local shape as an argument attribute for downstream consumers.
  auto local_variable_subtype = mlir::RankedTensorType::get(
      local_arg_shape, resource_arg_sub_type.getElementType());
  auto new_var_type = mlir::RankedTensorType::get(
      {}, mlir::TF::ResourceType::get(
              mlir::ArrayRef<mlir::TensorType>{local_variable_subtype},
              function.getContext()));

  UpdateFunctionInputShape(arg_index, new_var_type, function);
  function.setArgAttr(
      arg_index, kAssignedResourceLocalShape,
      mlir::TF::ShapeAttr::get(local_variable_subtype.getContext(),
                               mlir::ArrayRef<int64_t>(local_arg_shape)));

  return mlir::success();
}
172
173 // Returns whether `value` is used by AssignVariable op, skipping DTensorLayout
174 // op.
IsValueUsedByAssignVariableOp(mlir::Value value,int * resource_argument_index_for_assign_variable)175 bool IsValueUsedByAssignVariableOp(
176 mlir::Value value, int* resource_argument_index_for_assign_variable) {
177 for (auto user : value.getUsers()) {
178 if (auto assign_variable_op =
179 llvm::dyn_cast_or_null<mlir::TF::AssignVariableOp>(
180 NextTFOp(user))) {
181 *resource_argument_index_for_assign_variable =
182 GetForwardedDTensorLayoutInput(assign_variable_op.resource())
183 .cast<mlir::BlockArgument>()
184 .getArgNumber();
185 return true;
186 }
187 }
188 return false;
189 }
190
191 // Updates argument shapes of `function` based on `tf._layout` attribute.
mlir::LogicalResult UpdateFunctionArgsUsingLayout(mlir::func::FuncOp function) {
  for (int argument_index = 0; argument_index < function.getNumArguments();
       ++argument_index) {
    // Arguments without a layout attribute are left untouched.
    auto arg_layout_attr = function.getArgAttrOfType<mlir::StringAttr>(
        argument_index, kCustomDeviceAttr);
    if (!arg_layout_attr) continue;

    auto arg_layout = Layout::FromString(arg_layout_attr.getValue().str());
    if (!arg_layout.ok())
      return function.emitOpError(llvm::formatv(
          "Invalid layout attribute found during SPMD expansion: {0}",
          arg_layout.status().error_message()));

    mlir::Type arg_type = mlir::getElementTypeOrSelf(
        function.getFunctionType().getInput(argument_index));

    // If argument is a resource type update the subtype shape information
    // to reflect local shape of resources.
    if (arg_type.isa<mlir::TF::ResourceType>()) {
      if (mlir::failed(UpdateResourceArgumentType(argument_index, function)))
        return mlir::failure();
      continue;
    }

    // Unranked arguments carry no static shape to localize; skip them.
    mlir::RankedTensorType ranked_type =
        function.getFunctionType()
            .getInput(argument_index)
            .dyn_cast<mlir::RankedTensorType>();
    if (!ranked_type) continue;

    // If input value is non-resource type, then update the value to reflect
    // local shape.
    llvm::ArrayRef<int64_t> arg_shape = ranked_type.getShape();
    const std::vector<int64_t> arg_local_shape =
        arg_layout->LocalShapeFromGlobalShape(arg_shape);
    mlir::RankedTensorType new_arg_type = mlir::RankedTensorType::get(
        arg_local_shape, ranked_type.getElementType());
    UpdateFunctionInputShape(argument_index, new_arg_type, function);

    // If non-resource value was used for AssignVariable op, then ensure that
    // resource shape of updated/assigned resource is consistent with the
    // local shape of assigned value.
    int assigned_resource_argument_index = -1;
    if (IsValueUsedByAssignVariableOp(function.getArgument(argument_index),
                                      &assigned_resource_argument_index)) {
      // Result deliberately ignored: best-effort propagation of the local
      // shape onto the paired resource argument.
      (void)UpdateResourceArgumentType(assigned_resource_argument_index,
                                       function, new_arg_type);
    }
  }
  return mlir::success();
}
243
244 // Given SPMD expanded `function_operands` to `function`, update the function
245 // signature to reflect the local shape of `function_operands`.
UpdateFunctionWithLocalInputShapes(mlir::MutableArrayRef<mlir::OpOperand> function_operands,mlir::func::FuncOp function)246 mlir::LogicalResult UpdateFunctionWithLocalInputShapes(
247 mlir::MutableArrayRef<mlir::OpOperand> function_operands,
248 mlir::func::FuncOp function) {
249 for (auto& operand : function_operands) {
250 const int index = operand.getOperandNumber();
251 auto arg_type = operand.get().getType().dyn_cast<mlir::RankedTensorType>();
252 if (!arg_type) continue;
253
254 auto arg_local_shape = arg_type.getShape();
255 auto new_arg_type =
256 mlir::RankedTensorType::get(arg_local_shape, arg_type.getElementType());
257 UpdateFunctionInputShape(index, new_arg_type, function);
258 }
259 return mlir::success();
260 }
261
262 // Updates output shapes of enclosing op or function containing `terminator_op`
263 // to local shapes.
UpdateReturnValueShapes(mlir::ModuleOp module,mlir::Operation * terminator_op)264 mlir::LogicalResult UpdateReturnValueShapes(mlir::ModuleOp module,
265 mlir::Operation* terminator_op) {
266 auto parent_op = terminator_op->getBlock()->getParentOp();
267 if (!parent_op) return mlir::success();
268
269 auto output_types = llvm::to_vector<8>(terminator_op->getOperandTypes());
270 if (auto function = llvm::dyn_cast<mlir::func::FuncOp>(parent_op)) {
271 // Update function output type to have local shape.
272 auto new_func_type = mlir::FunctionType::get(
273 function.getContext(), function.getFunctionType().getInputs(),
274 output_types);
275 function.setType(new_func_type);
276
277 // Update function callsite operations to reflect local output shapes.
278 auto function_uses =
279 mlir::SymbolTable::getSymbolUses(function, &module.getBodyRegion());
280 if (!function_uses) return mlir::success();
281
282 // Update function callsite operations to reflect local output shapes.
283 for (auto function_use : *function_uses) {
284 auto callsite_op = function_use.getUser();
285 if (!callsite_op) continue;
286
287 for (auto& output_type_and_index : llvm::enumerate(output_types)) {
288 int index = output_type_and_index.index();
289 const auto& type = output_type_and_index.value();
290 callsite_op->getResult(index).setType(type);
291 }
292 }
293 } else {
294 for (auto& output_type_and_index : llvm::enumerate(output_types)) {
295 int index = output_type_and_index.index();
296 const auto& type = output_type_and_index.value();
297 parent_op->getResult(index).setType(type);
298 }
299 }
300
301 return mlir::success();
302 }
303
304 // Conducts SPMD expansion for all ops in `module`. If function call operation
305 // exists, walk the function in topological order to update inputs/outputs of
306 // functions before SPMD expansion of callsite operations is done.
307 // Note that the iteration won't work with recursive function calls.
mlir::LogicalResult ConductSPMDExpansion(mlir::ModuleOp module) {
  auto main_func = module.lookupSymbol<mlir::func::FuncOp>(kMainFunctionName);
  if (!main_func)
    return module.emitOpError(
        "could not find `main` function in module for SPMD expansion.");

  // Localize `main`'s argument shapes first so every op sees local inputs.
  if (mlir::failed(UpdateFunctionArgsUsingLayout(main_func)))
    return mlir::failure();

  // Visit ops in topological order so callee signatures are updated before
  // their callsites are expanded.
  TopologicalIterator iterator(main_func);
  while (iterator.hasNext()) {
    mlir::Operation* op = iterator.next();
    // If `op` is a call, push the local shapes of its operands into the
    // callee's input signature before expanding.
    absl::optional<mlir::func::FuncOp> func = MaybeFindFunction(op);
    if (func.has_value()) {
      if (mlir::failed(
              UpdateFunctionWithLocalInputShapes(op->getOpOperands(), *func)))
        return mlir::failure();
    }

    // Record terminator-ness before expansion, since `op` may be erased.
    const bool is_terminator_op =
        llvm::isa<mlir::func::ReturnOp, mlir::tf_device::ReturnOp>(op);
    // DTensorLayout is a pass-through; keep its result type equal to its
    // (possibly already localized) input type.
    if (auto layout_op = llvm::dyn_cast<mlir::TF::DTensorLayout>(op))
      layout_op.output().setType(layout_op.input().getType());

    mlir::Operation* expanded_op = nullptr;
    auto status = RunSPMDExpansion(op, &expanded_op);
    if (!status.ok() || expanded_op == nullptr) {
      // Sometimes op may been erased and expanded_op set.
      // In this case we should emit the error on the expanded op.
      mlir::Operation* emit_op = op;
      if (expanded_op != nullptr) emit_op = expanded_op;
      return emit_op->emitError(WithContext(status, __FILE__, __LINE__,
                                            "While computing SPMD expansion")
                                    .error_message());
    }

    // If expanded op is terminator of tf_device.Cluster or a function, then
    // make sure to update the function return value as well as the shape of
    // it's callsite operation.
    if (is_terminator_op)
      if (mlir::failed(UpdateReturnValueShapes(module, expanded_op)))
        return mlir::failure();
  }
  return mlir::success();
}
353
354 // DTensorLayout only conveys layout information of tensors which is no
355 // longer needed after SPMD expansion. As so, remove all layouts from
356 // graph.
RemoveDTensorLayoutOps(mlir::ModuleOp module)357 void RemoveDTensorLayoutOps(mlir::ModuleOp module) {
358 llvm::SmallVector<mlir::TF::DTensorLayout, 4> layout_ops;
359 module.walk(
360 [&](mlir::TF::DTensorLayout layout) { layout_ops.emplace_back(layout); });
361
362 for (auto layout_op : layout_ops) RemoveDTensorLayoutOp(layout_op);
363 }
364
365 // Removes temporary attrs created during SPMD expansion.
RemoveTemporarySPMDAttrs(mlir::ModuleOp module)366 void RemoveTemporarySPMDAttrs(mlir::ModuleOp module) {
367 module.walk([&](mlir::Operation* op) {
368 if (op->hasAttr(kDeviceSeedForMeshDims)) {
369 op->removeAttr(kDeviceSeedForMeshDims);
370 }
371 });
372 }
373
374 // MLIR pass that converts graph in global view into a local view which can be
375 // invoked in parallel on distributed set of devices. This pass removes
376 // all DTensorLayout ops after the expansion is done. Temporary nodes and
377 // attributes are also removed after the pass is done.
378 struct DTensorSPMDExpansion
379 : public DTensorSPMDExpansionBase<DTensorSPMDExpansion> {
getDependentDialectstensorflow::dtensor::__anon860a1f380111::DTensorSPMDExpansion380 void getDependentDialects(mlir::DialectRegistry& registry) const override {
381 registry.insert<mlir::dtensor::DTensorDialect>();
382 }
383
runOnOperationtensorflow::dtensor::__anon860a1f380111::DTensorSPMDExpansion384 void runOnOperation() override {
385 auto module = getOperation();
386 if (failed(ConductSPMDExpansion(module))) return signalPassFailure();
387
388 RemoveDTensorLayoutOps(module);
389
390 RemoveTemporarySPMDAttrs(module);
391 };
392 };
393
394 } // namespace
395
396 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorSPMDExpansion()397 CreateDTensorSPMDExpansion() {
398 return std::make_unique<DTensorSPMDExpansion>();
399 }
400
401 } // namespace dtensor
402 } // namespace tensorflow
403