xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/propagate_device_id_to_function_args.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/ADT/StringRef.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
19 #include "mlir/IR/Attributes.h"  // from @llvm-project
20 #include "mlir/IR/Builders.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
22 #include "mlir/IR/Operation.h"  // from @llvm-project
23 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
24 #include "mlir/IR/Visitors.h"  // from @llvm-project
25 #include "mlir/Pass/Pass.h"  // from @llvm-project
26 #include "mlir/Pass/PassManager.h"  // from @llvm-project
27 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
28 #include "mlir/Transforms/Passes.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
31 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
32 #include "tensorflow/dtensor/mlir/device_utils.h"
33 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
34 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
35 #include "tensorflow/dtensor/mlir/op_utils.h"
36 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
37 
38 namespace tensorflow {
39 namespace dtensor {
40 namespace {
41 
// Holds information on functions to rewrite. `function` is the function
// definition or function that needs to be updated and `callsite_ops` holds a
// list of ops that call the `function`.
struct FunctionToChangeInfo {
  // Private function whose signature will be rewritten by this pass.
  mlir::func::FuncOp function;
  // All call operations in the module that reference `function`.
  llvm::SmallVector<mlir::Operation*, 4> callsite_ops;
};
49 
50 // Finds all functions in graph that is not a public functions and retrieves
51 // their callsite operations.
FindFunctionsToRewrite(mlir::ModuleOp module)52 llvm::SmallVector<FunctionToChangeInfo, 4> FindFunctionsToRewrite(
53     mlir::ModuleOp module) {
54   llvm::SmallVector<FunctionToChangeInfo, 4> functions_to_change;
55   module.walk([&](mlir::Operation* op) {
56     if (!llvm::isa<mlir::TF::StatefulPartitionedCallOp,
57                    mlir::TF::PartitionedCallOp>(op))
58       return;
59 
60     // Extract function symbol from PartitionedCall or StatefulPartitionedCall
61     // op.
62     llvm::StringRef symbol;
63     if (auto call_op =
64             llvm::dyn_cast<mlir::TF::StatefulPartitionedCallOp>(op)) {
65       symbol = call_op.f();
66     } else {
67       auto symbol_ref = llvm::dyn_cast<mlir::TF::PartitionedCallOp>(op).f();
68       if (!symbol_ref.isa<mlir::FlatSymbolRefAttr>()) return;
69       symbol = symbol_ref.getRootReference().getValue();
70     }
71 
72     // If function definition could be found, then extract all function usages.
73     auto function = MaybeFindFunction(op);
74     if (!function || function->isPublic()) return;
75 
76     auto function_uses = mlir::SymbolTable::getSymbolUses(
77         mlir::StringAttr::get(module.getContext(), symbol),
78         &module.getBodyRegion());
79     if (!function_uses) return;
80 
81     llvm::SmallVector<mlir::Operation*, 4> function_use_ops;
82     for (auto function_use : *function_uses)
83       function_use_ops.emplace_back(function_use.getUser());
84 
85     functions_to_change.emplace_back(
86         FunctionToChangeInfo{function.value(), function_use_ops});
87   });
88 
89   return functions_to_change;
90 }
91 
92 // Rewrites function such that 0th argument of type `type` is added to
93 // `function`.
PrependArgumentToFunction(mlir::func::FuncOp function,mlir::Type type,mlir::OpBuilder * builder)94 void PrependArgumentToFunction(mlir::func::FuncOp function, mlir::Type type,
95                                mlir::OpBuilder* builder) {
96   auto& function_body = function.front();
97   function_body.insertArgument(static_cast<unsigned>(0), type,
98                                function.getLoc());
99   auto new_argument_types =
100       llvm::to_vector<4>(function_body.getArgumentTypes());
101   function.setType(
102       mlir::FunctionType::get(builder->getContext(), new_argument_types,
103                               function.getFunctionType().getResults()));
104 }
105 
106 // Rewrites function callsites ops. As function signatures are already updated,
107 // simply add 0th argument of the parent function to 0th operand of the callsite
108 // operation.
PrependDeviceIdToCallsites(mlir::OpBuilder * builder,mlir::Operation * op)109 mlir::LogicalResult PrependDeviceIdToCallsites(mlir::OpBuilder* builder,
110                                                mlir::Operation* op) {
111   auto device_id_or_status = DeviceId(op);
112   if (!device_id_or_status.ok())
113     return op->emitOpError(
114         "Failed during PropagateDeviceIdToFunctionArgs pass. All functions "
115         "must have device id as 0th argument.");
116 
117   auto new_operands = llvm::to_vector<4>(op->getOperands());
118   new_operands.insert(new_operands.begin(), device_id_or_status.ValueOrDie());
119 
120   builder->setInsertionPoint(op);
121   mlir::Operation* new_call = nullptr;
122   if (auto stateful_partitioned_call =
123           llvm::dyn_cast<mlir::TF::StatefulPartitionedCallOp>(op)) {
124     new_call = builder->create<mlir::TF::StatefulPartitionedCallOp>(
125         op->getLoc(), op->getResultTypes(), new_operands,
126         stateful_partitioned_call.f(), stateful_partitioned_call.config(),
127         stateful_partitioned_call.config_proto(),
128         stateful_partitioned_call.executor_type());
129   } else {
130     auto partitioned_call = llvm::cast<mlir::TF::PartitionedCallOp>(op);
131     new_call = builder->create<mlir::TF::PartitionedCallOp>(
132         op->getLoc(), op->getResultTypes(), new_operands, partitioned_call.f(),
133         partitioned_call.config(), partitioned_call.config_proto(),
134         partitioned_call.executor_type());
135   }
136 
137   for (auto results : llvm::zip(op->getResults(), new_call->getResults()))
138     std::get<0>(results).replaceAllUsesWith(std::get<1>(results));
139 
140   op->erase();
141 
142   return mlir::success();
143 }
144 
145 // Pass that rewrites the functions in graph so that 0th argument of the main
146 // function (i.e. device_id) is present on all functions in the graph.
147 struct DTensorPropagateDeviceIdToFunctionArgs
148     : public DTensorPropagateDeviceIdToFunctionArgsBase<
149           DTensorPropagateDeviceIdToFunctionArgs> {
runOnOperationtensorflow::dtensor::__anonfa7b3f010111::DTensorPropagateDeviceIdToFunctionArgs150   void runOnOperation() override {
151     mlir::MLIRContext& context = getContext();
152     auto module = getOperation();
153     mlir::OpBuilder builder(&context);
154 
155     // Extracts device id argument from main function.
156     mlir::func::FuncOp main_func =
157         module.lookupSymbol<mlir::func::FuncOp>("main");
158     auto device_id_or_status = DeviceId(&main_func.getBody().front().front());
159     if (!device_id_or_status.ok()) {
160       main_func.emitOpError(
161           "Error in PropagateDeviceIdToFunctionArgs pass. Main function must "
162           "have device id as 0th function argument.");
163       return signalPassFailure();
164     }
165     auto device_id_from_main_function = device_id_or_status.ValueOrDie();
166     // First iterate through all functions to rewrite and update the signatures
167     // first.
168     const auto functions_to_update = FindFunctionsToRewrite(module);
169     for (const auto& function_to_update : functions_to_update)
170       PrependArgumentToFunction(function_to_update.function,
171                                 device_id_from_main_function.getType(),
172                                 &builder);
173 
174     // Once all function signatures are updated, rewrite the callsite ops.
175     for (const auto& function_to_update : functions_to_update) {
176       for (auto call_site_op : function_to_update.callsite_ops) {
177         if (mlir::failed(PrependDeviceIdToCallsites(&builder, call_site_op)))
178           return signalPassFailure();
179       }
180     }
181   };
182 };
183 
184 }  // namespace
185 
186 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorPropagateDeviceIdToFunctionArgs()187 CreateDTensorPropagateDeviceIdToFunctionArgs() {
188   return std::make_unique<DTensorPropagateDeviceIdToFunctionArgs>();
189 }
190 
191 }  // namespace dtensor
192 }  // namespace tensorflow
193