/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/dtensor/mlir/device_utils.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {
namespace {
// Holds information on functions to rewrite. `function` is the function
// definition that needs to be updated, and `callsite_ops` holds the list of
// ops that call `function`.
struct FunctionToChangeInfo {
  mlir::func::FuncOp function;
  llvm::SmallVector<mlir::Operation*, 4> callsite_ops;
};

// Finds all non-public functions in the module that are called from
// (Stateful)PartitionedCall ops and retrieves their callsite operations.
llvm::SmallVector<FunctionToChangeInfo, 4> FindFunctionsToRewrite(
    mlir::ModuleOp module) {
  llvm::SmallVector<FunctionToChangeInfo, 4> functions_to_change;
  module.walk([&](mlir::Operation* op) {
    if (!llvm::isa<mlir::TF::StatefulPartitionedCallOp,
                   mlir::TF::PartitionedCallOp>(op))
      return;

    // Extract function symbol from PartitionedCall or StatefulPartitionedCall
    // op.
    llvm::StringRef symbol;
    if (auto call_op =
            llvm::dyn_cast<mlir::TF::StatefulPartitionedCallOp>(op)) {
      symbol = call_op.f();
    } else {
      auto symbol_ref = llvm::dyn_cast<mlir::TF::PartitionedCallOp>(op).f();
      if (!symbol_ref.isa<mlir::FlatSymbolRefAttr>()) return;
      symbol = symbol_ref.getRootReference().getValue();
    }
    // If the function definition can be found and the function is not public,
    // extract all of its usages.
    auto function = MaybeFindFunction(op);
    if (!function || function->isPublic()) return;

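    // Collect every use of the function's symbol within the module so that
    // each callsite can be rewritten once the signature has been updated.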
    auto function_uses = mlir::SymbolTable::getSymbolUses(
        mlir::StringAttr::get(module.getContext(), symbol),
        &module.getBodyRegion());
    if (!function_uses) return;

    llvm::SmallVector<mlir::Operation*, 4> function_use_ops;
    for (auto function_use : *function_uses)
      function_use_ops.emplace_back(function_use.getUser());

    functions_to_change.emplace_back(
        FunctionToChangeInfo{function.value(), function_use_ops});
  });

  return functions_to_change;
}

// Rewrites `function` so that a new 0th argument of type `type` is prepended
// to its signature.
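// For example, a signature (tensor<f32>) -> tensor<f32> would become
// (tensor<i32>, tensor<f32>) -> tensor<f32>; the concrete types here are
// illustrative only.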
void PrependArgumentToFunction(mlir::func::FuncOp function, mlir::Type type,
                               mlir::OpBuilder* builder) {
  auto& function_body = function.front();
  function_body.insertArgument(static_cast<unsigned>(0), type,
                               function.getLoc());
  auto new_argument_types =
      llvm::to_vector<4>(function_body.getArgumentTypes());
  function.setType(
      mlir::FunctionType::get(builder->getContext(), new_argument_types,
                              function.getFunctionType().getResults()));
}

// Rewrites a function callsite op. As the callee's signature has already been
// updated, simply prepend the device id (the 0th argument of the enclosing
// function) as the 0th operand of the callsite operation.
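// For example, a callsite "tf.PartitionedCall"(%x) {f = @fn} would be
// recreated as "tf.PartitionedCall"(%device_id, %x) {f = @fn}; the operand
// and function names are illustrative only.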
mlir::LogicalResult PrependDeviceIdToCallsites(mlir::OpBuilder* builder,
                                               mlir::Operation* op) {
  auto device_id_or_status = DeviceId(op);
  if (!device_id_or_status.ok())
    return op->emitOpError(
        "Failed during PropagateDeviceIdToFunctionArgs pass. All functions "
        "must have device id as 0th argument.");

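  // Prepend the device id of the enclosing function to the operand list of
  // the callsite.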
  auto new_operands = llvm::to_vector<4>(op->getOperands());
  new_operands.insert(new_operands.begin(), device_id_or_status.ValueOrDie());

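  // Recreate the callsite as the same kind of call op, copying its attributes
  // and using the extended operand list.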
  builder->setInsertionPoint(op);
  mlir::Operation* new_call = nullptr;
  if (auto stateful_partitioned_call =
          llvm::dyn_cast<mlir::TF::StatefulPartitionedCallOp>(op)) {
    new_call = builder->create<mlir::TF::StatefulPartitionedCallOp>(
        op->getLoc(), op->getResultTypes(), new_operands,
        stateful_partitioned_call.f(), stateful_partitioned_call.config(),
        stateful_partitioned_call.config_proto(),
        stateful_partitioned_call.executor_type());
  } else {
    auto partitioned_call = llvm::cast<mlir::TF::PartitionedCallOp>(op);
    new_call = builder->create<mlir::TF::PartitionedCallOp>(
        op->getLoc(), op->getResultTypes(), new_operands, partitioned_call.f(),
        partitioned_call.config(), partitioned_call.config_proto(),
        partitioned_call.executor_type());
  }

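  // Forward all results of the original callsite to the new call, then erase
  // the original op.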
  for (auto results : llvm::zip(op->getResults(), new_call->getResults()))
    std::get<0>(results).replaceAllUsesWith(std::get<1>(results));

  op->erase();

  return mlir::success();
}

// Pass that rewrites the functions in the graph so that the 0th argument of
// the main function (i.e. the device id) is present on all functions in the
// graph.
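//
// Illustrative sketch only (the function names and the tensor<i32> device id
// type are assumptions for the example, not taken from this pass):
//
//   func.func @main(%device_id: tensor<i32>, %x: tensor<f32>) {
//     "tf.StatefulPartitionedCall"(%x) {f = @fn, ...} : ...
//   }
//   func.func private @fn(%x: tensor<f32>) { ... }
//
// is rewritten to
//
//   func.func @main(%device_id: tensor<i32>, %x: tensor<f32>) {
//     "tf.StatefulPartitionedCall"(%device_id, %x) {f = @fn, ...} : ...
//   }
//   func.func private @fn(%device_id: tensor<i32>, %x: tensor<f32>) { ... }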
struct DTensorPropagateDeviceIdToFunctionArgs
    : public DTensorPropagateDeviceIdToFunctionArgsBase<
          DTensorPropagateDeviceIdToFunctionArgs> {
  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    auto module = getOperation();
    mlir::OpBuilder builder(&context);

    // Extracts device id argument from main function.
    mlir::func::FuncOp main_func =
        module.lookupSymbol<mlir::func::FuncOp>("main");
    auto device_id_or_status = DeviceId(&main_func.getBody().front().front());
    if (!device_id_or_status.ok()) {
      main_func.emitOpError(
          "Error in PropagateDeviceIdToFunctionArgs pass. Main function must "
          "have device id as 0th function argument.");
      return signalPassFailure();
    }
    auto device_id_from_main_function = device_id_or_status.ValueOrDie();
    // First, iterate through all the functions to rewrite and update their
    // signatures.
    const auto functions_to_update = FindFunctionsToRewrite(module);
    for (const auto& function_to_update : functions_to_update)
      PrependArgumentToFunction(function_to_update.function,
                                device_id_from_main_function.getType(),
                                &builder);

    // Once all function signatures are updated, rewrite the callsite ops.
    for (const auto& function_to_update : functions_to_update) {
      for (auto call_site_op : function_to_update.callsite_ops) {
        if (mlir::failed(PrependDeviceIdToCallsites(&builder, call_site_op)))
          return signalPassFailure();
      }
    }
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorPropagateDeviceIdToFunctionArgs() {
  return std::make_unique<DTensorPropagateDeviceIdToFunctionArgs>();
}

}  // namespace dtensor
}  // namespace tensorflow