1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
20 #include "mlir/IR/Attributes.h" // from @llvm-project
21 #include "mlir/IR/Builders.h" // from @llvm-project
22 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
23 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
24 #include "mlir/IR/Diagnostics.h" // from @llvm-project
25 #include "mlir/IR/Operation.h" // from @llvm-project
26 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
27 #include "mlir/IR/Types.h" // from @llvm-project
28 #include "mlir/Transforms/Passes.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
30 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
31 #include "tensorflow/dtensor/cc/constants.h"
32 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
33 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
34
35 namespace tensorflow {
36 namespace dtensor {
37 namespace {
38
39 // Sets `_global_shape` attributes to argument/return values of `function`.
AnnotateFunctionArgRetvalGlobalShapes(mlir::func::FuncOp function,mlir::OpBuilder * builder)40 void AnnotateFunctionArgRetvalGlobalShapes(mlir::func::FuncOp function,
41 mlir::OpBuilder* builder) {
42 for (const auto& argument_type_and_index :
43 llvm::enumerate(function.getArgumentTypes())) {
44 const int index = argument_type_and_index.index();
45 const auto& argument_type = argument_type_and_index.value();
46 // Extract TensorType from element of resource type to allow setting proper
47 // global shape of resource types.
48 if (auto resource_type = mlir::getElementTypeOrSelf(argument_type)
49 .dyn_cast<mlir::TF::ResourceType>()) {
50 auto subtype = resource_type.getSubtypes();
51 if (subtype.size() == 1) {
52 // subtype returns a Array of TensorType -- if it contains more than one
53 // Tensor type, we give up extracting the single TensorType inside the
54 // subtype.
55 function.setArgAttr(index, kGlobalShapeDialectAttr,
56 ConvertTypeToTensorShapeAttr(subtype[0]));
57 }
58 } else {
59 function.setArgAttr(index, kGlobalShapeDialectAttr,
60 ConvertTypeToTensorShapeAttr(argument_type));
61 }
62 }
63
64 for (const auto& retval_type_and_index :
65 llvm::enumerate(function.getFunctionType().getResults())) {
66 const int index = retval_type_and_index.index();
67 const auto& retval_type = retval_type_and_index.value();
68 function.setResultAttr(index, kGlobalShapeDialectAttr,
69 ConvertTypeToTensorShapeAttr(retval_type));
70 }
71 }
72
73 // Sets `_global_shape` attribute of an `op` with array of ShapeAttr of
74 // `outputs.
AnnotateOperationGlobalShape(mlir::Operation * op,mlir::OpBuilder * builder)75 void AnnotateOperationGlobalShape(mlir::Operation* op,
76 mlir::OpBuilder* builder) {
77 llvm::SmallVector<mlir::Attribute, 4> op_global_shape;
78 op_global_shape.reserve(op->getNumResults());
79
80 for (const auto& result_type : op->getResultTypes())
81 op_global_shape.emplace_back(ConvertTypeToTensorShapeAttr(result_type));
82
83 op->setAttr(kGlobalShape, builder->getArrayAttr(op_global_shape));
84 }
85
86 // Pass that annotates function argument/return values and all operation with
87 // `_global_shape` attribute. This will be used during SPMD expansion to
88 // preserve original global shape of operations in graph after shape has been
89 // modified to local shape.
90 struct DTensorAnnotateGlobalShape
91 : public DTensorAnnotateGlobalShapeBase<DTensorAnnotateGlobalShape> {
runOnOperationtensorflow::dtensor::__anon0c4951ca0111::DTensorAnnotateGlobalShape92 void runOnOperation() override {
93 mlir::MLIRContext& context = getContext();
94 mlir::OpBuilder builder(&context);
95
96 auto module = getOperation();
97 module.walk([&](mlir::func::FuncOp function) {
98 if (function.empty()) return;
99
100 auto* terminator = function.getBody().front().getTerminator();
101 AnnotateFunctionArgRetvalGlobalShapes(function, &builder);
102 function.getBody().walk([&](mlir::Operation* op) {
103 if (op == terminator) return;
104
105 AnnotateOperationGlobalShape(op, &builder);
106 });
107 });
108 }
109 };
110
111 } // namespace
112
113 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorAnnotateGlobalShape()114 CreateDTensorAnnotateGlobalShape() {
115 return std::make_unique<DTensorAnnotateGlobalShape>();
116 }
117
118 } // namespace dtensor
119 } // namespace tensorflow
120