1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" // from @llvm-project
21 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
22 #include "mlir/IR/Builders.h" // from @llvm-project
23 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
24 #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
25 #include "tensorflow/compiler/xla/mlir/ir/runtime/rt_ops.h"
26 #include "tensorflow/compiler/xla/mlir/transforms/runtime/passes.h"
27
28 namespace xla {
29 namespace runtime {
30
31 using namespace mlir; // NOLINT
32
33 #define GEN_PASS_CLASSES
34 #include "tensorflow/compiler/xla/mlir/transforms/runtime/passes.h.inc"
35
36 class ConvertToEntrypointPass
37 : public ConvertToEntrypointBase<ConvertToEntrypointPass> {
38 void runOnOperation() override;
39 };
40
ConvertCustomCallOperations(func::FuncOp func,Value exec_ctx)41 static void ConvertCustomCallOperations(func::FuncOp func, Value exec_ctx) {
42 MLIRContext* ctx = func->getContext();
43
44 SymbolTable sym_table(func->getParentOfType<ModuleOp>());
45
46 struct CustomCall {
47 func::CallOp call;
48 func::FuncOp callee;
49 llvm::StringRef target;
50 bool direct;
51 };
52
53 // Collect function calls that have to be converted to custom calls.
54 llvm::SmallVector<CustomCall> custom_calls;
55 func.walk([&](func::CallOp op) {
56 auto callee = dyn_cast<func::FuncOp>(sym_table.lookup(op.getCallee()));
57 if (!callee) return;
58
59 // Check if the call is an indirect custom call ...
60 StringAttr target = callee->getAttrOfType<StringAttr>("rt.custom_call");
61 if (target) custom_calls.push_back({op, callee, target.strref(), false});
62
63 // ... or a direct custom call.
64 target = callee->getAttrOfType<StringAttr>("rt.direct_custom_call");
65 if (target) custom_calls.push_back({op, callee, target.strref(), true});
66 });
67
68 // After converting to custom call we need to clean up all declarations.
69 llvm::DenseSet<func::FuncOp> erase_declarations;
70
71 // Rewrite function calls to `rt.custom_call` operations.
72 for (CustomCall custom_call : custom_calls) {
73 ImplicitLocOpBuilder b(custom_call.call.getLoc(), custom_call.call);
74
75 // Custom call intrinsic always returns the status flag.
76 llvm::SmallVector<Type> results = {StatusType::get(ctx)};
77 results.append(custom_call.call->getResultTypes().begin(),
78 custom_call.call->getResultTypes().end());
79
80 // Rewrite function call with a custom call, and check the return status.
81 auto call = b.create<CustomCallOp>(results, exec_ctx, custom_call.target,
82 custom_call.direct,
83 custom_call.call.getOperands());
84
85 // Copy optional attributes from the custom call function declaration.
86 llvm::ArrayRef<llvm::StringRef> callee_attrs =
87 custom_call.callee.getAttributeNames();
88 for (auto& attr : custom_call.callee->getAttrs()) {
89 if (isa_and_nonnull<RuntimeDialect>(attr.getNameDialect())) continue;
90 if (llvm::find(callee_attrs, attr.getName()) == callee_attrs.end())
91 call->setAttr(attr.getName(), attr.getValue());
92 }
93
94 // Copy optional attributes from the call operation to the custom call.
95 llvm::ArrayRef<llvm::StringRef> orig_attrs =
96 custom_call.call.getAttributeNames();
97 for (auto& attr : custom_call.call->getAttrs()) {
98 if (llvm::find(orig_attrs, attr.getName()) == orig_attrs.end())
99 call->setAttr(attr.getName(), attr.getValue());
100 }
101
102 b.create<cf::AssertOp>(
103 b.create<IsOkOp>(TypeRange(b.getI1Type()), call.status()),
104 b.getStringAttr("custom call '" + custom_call.target + "' failed"));
105
106 // Forward users of the original results to custom call results.
107 auto rets = llvm::zip(custom_call.call.getResults(),
108 llvm::drop_begin(call.getResults()));
109 llvm::for_each(rets, [](auto ret) {
110 std::get<0>(ret).replaceAllUsesWith(std::get<1>(ret));
111 });
112
113 // Keep track of custom call declaration to erase.
114 erase_declarations.insert(custom_call.callee);
115
116 // Erase the original function call operation.
117 custom_call.call.erase();
118 }
119
120 // Erase all converted custom calls declarations.
121 for (auto func : erase_declarations) sym_table.erase(func);
122 }
123
ConvertReturnOperations(func::FuncOp func,Value exec_ctx)124 static void ConvertReturnOperations(func::FuncOp func, Value exec_ctx) {
125 // Convert all returns to the Runtime API calls.
126 func.walk([&](func::ReturnOp ret) {
127 ImplicitLocOpBuilder b(ret.getLoc(), ret);
128
129 // Return all outputs via the `rt.set_output` operation.
130 for (auto& pair : llvm::enumerate(ret.getOperands())) {
131 b.create<SetOutputOp>(exec_ctx, pair.index(), pair.value());
132 }
133
134 // Replace original return with an empty one.
135 b.create<func::ReturnOp>();
136 ret.erase();
137 });
138
139 // Update function type to the function with empty results.
140 auto type = FunctionType::get(func.getContext(), func.getArgumentTypes(), {});
141 func.setType(type);
142 }
143
ConvertAssertOperations(func::FuncOp func,Value exec_ctx)144 static void ConvertAssertOperations(func::FuncOp func, Value exec_ctx) {
145 // Collect all assert operations in the function body.
146 llvm::SmallVector<cf::AssertOp> asserts;
147 func.walk([&](cf::AssertOp op) { asserts.push_back(op); });
148
149 // Rewrite all asserts to the Runtime API calls.
150 for (cf::AssertOp assert : asserts) {
151 ImplicitLocOpBuilder b(assert.getLoc(), assert);
152
153 // Split the block at the assert operation.
154 Block* block = assert->getBlock();
155 Block* ok = block->splitBlock(assert);
156
157 // Set up block for returning error.
158 Block* err = func.addBlock();
159 b.setInsertionPointToStart(err);
160 b.create<SetErrorOp>(exec_ctx, assert.getMsg());
161 b.create<func::ReturnOp>();
162
163 // Branch into the error block if assertion failed.
164 b.setInsertionPointToEnd(block);
165 b.create<cf::CondBranchOp>(assert.getArg(), ok, err);
166
167 // Erase the original assert operation.
168 assert.erase();
169 }
170 }
171
PrependExecutionContextArgument(func::FuncOp func)172 static Value PrependExecutionContextArgument(func::FuncOp func) {
173 Type new_type = KernelContextType::get(func.getContext());
174 DictionaryAttr attr = DictionaryAttr::get(func.getContext());
175 func.insertArguments({0}, {new_type}, {attr}, {func.getLoc()});
176 return func.getArgument(0);
177 }
178
ConvertToEntrypoint(func::FuncOp func)179 static void ConvertToEntrypoint(func::FuncOp func) {
180 assert(func->hasAttr(kEntrypointAttrName));
181
182 Value exec_ctx = PrependExecutionContextArgument(func);
183 ConvertCustomCallOperations(func, exec_ctx);
184 ConvertReturnOperations(func, exec_ctx);
185 ConvertAssertOperations(func, exec_ctx);
186
187 // After conversion !rt.execution_context is a marker of an entrypoint.
188 func->removeAttr(kEntrypointAttrName);
189 }
190
runOnOperation()191 void ConvertToEntrypointPass::runOnOperation() {
192 llvm::SmallVector<func::FuncOp> entry_points;
193
194 // Collect entrypoint functions.
195 getOperation().walk([&](func::FuncOp op) {
196 if (op->hasAttr(kEntrypointAttrName)) entry_points.push_back(op);
197 });
198
199 llvm::for_each(entry_points, ConvertToEntrypoint);
200 }
201
CreateConvertToEntrypoint()202 std::unique_ptr<OperationPass<ModuleOp>> CreateConvertToEntrypoint() {
203 return std::make_unique<ConvertToEntrypointPass>();
204 }
205
206 } // namespace runtime
207 } // namespace xla
208