1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <string>
18
19 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
20 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
21 #include "mlir/Dialect/Linalg/IR/Linalg.h" // from @llvm-project
22 #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project
23 #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
24 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
25 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/utils.h"
26 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
27
28 namespace mlir {
29 namespace kernel_gen {
30 namespace transforms {
31 namespace {
32
// Symbol name of the function that EmitOperationPrint calls with the debug
// message string.
constexpr StringRef kPrintStringFuncName = "printCString";
34
35 #define GEN_PASS_CLASSES
36 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
37
EmitMemRefPrint(Location loc,Type element_type,Value arg,OpBuilder * b)38 Operation* EmitMemRefPrint(Location loc, Type element_type, Value arg,
39 OpBuilder* b) {
40 StringRef func_name;
41 if (element_type.isF32()) {
42 func_name = "printMemrefF32";
43 }
44 if (element_type.isF64()) {
45 func_name = "printMemrefF64";
46 }
47 if (element_type.isInteger(32)) {
48 func_name = "printMemrefI32";
49 }
50 if (element_type.isInteger(64) || element_type.isIndex()) {
51 func_name = "printMemrefI64";
52 }
53 assert(!func_name.empty() &&
54 "Did not find a print function for the element type");
55
56 auto caller_func =
57 b->getInsertionBlock()->getParent()->getParentOfType<func::FuncOp>();
58 auto func_name_attr = b->getStringAttr(func_name);
59
60 auto callee_func = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
61 caller_func, func_name_attr);
62 if (!callee_func) {
63 OpBuilder::InsertionGuard insertGuard(*b);
64
65 auto module = caller_func->getParentOfType<ModuleOp>();
66 b->setInsertionPointToStart(module.getBody());
67 auto func_type = FunctionType::get(b->getContext(), arg.getType(),
68 /*results=*/llvm::None);
69 callee_func =
70 b->create<func::FuncOp>(module.getLoc(), func_name, func_type);
71 callee_func.setPrivate();
72 }
73 return b->create<func::CallOp>(loc, callee_func, arg);
74 }
75
IsElementTypePrintalble(Type element_type)76 bool IsElementTypePrintalble(Type element_type) {
77 return element_type.isF32() || element_type.isF64() ||
78 element_type.isInteger(32) || element_type.isInteger(64) ||
79 element_type.isIndex();
80 }
81
EmitMemRefPrint(Location loc,Value memref,OpBuilder * b)82 void EmitMemRefPrint(Location loc, Value memref, OpBuilder* b) {
83 auto memref_type = memref.getType();
84 if (auto unranked_type = memref_type.dyn_cast<UnrankedMemRefType>()) {
85 Type element_type = unranked_type.getElementType();
86 if (!IsElementTypePrintalble(element_type)) return;
87
88 EmitMemRefPrint(loc, element_type, memref, b);
89 }
90 if (auto ranked_type = memref_type.dyn_cast<MemRefType>()) {
91 Type element_type = ranked_type.getElementType();
92 if (!IsElementTypePrintalble(element_type)) return;
93
94 if (element_type.isIndex()) {
95 element_type = b->getI64Type();
96 ranked_type = MemRefType::get(ranked_type.getShape(), element_type,
97 ranked_type.getLayout(),
98 ranked_type.getMemorySpace());
99 memref = b->create<arith::IndexCastOp>(loc, ranked_type, memref);
100 }
101
102 auto unranked_type = UnrankedMemRefType::get(
103 element_type, ranked_type.getMemorySpaceAsInt());
104 Value unranked_memref =
105 b->create<memref::CastOp>(loc, unranked_type, memref);
106 EmitMemRefPrint(loc, element_type, unranked_memref, b);
107 }
108 }
109
ExtractValuesToPrint(Operation * op)110 SmallVector<Value> ExtractValuesToPrint(Operation* op) {
111 if (isa<memref::ReinterpretCastOp>(op) || isa<memref::ReshapeOp>(op) ||
112 isa<memref::ExpandShapeOp>(op) || isa<memref::CollapseShapeOp>(op)) {
113 return {op->getResult(0)};
114 }
115 if (auto linalg = dyn_cast<linalg::LinalgOp>(op)) {
116 return linalg.getOutputBufferOperands();
117 }
118 if (auto loop = dyn_cast<gml_st::LoopOp>(op)) {
119 return loop.outputs();
120 }
121 if (auto loop = dyn_cast<scf::ForOp>(op)) {
122 return loop.getIterOperands();
123 }
124 if (auto copy = dyn_cast<memref::CopyOp>(op)) {
125 return {copy.getTarget()};
126 }
127 return {};
128 }
129
EmitOperationPrint(Operation * op,OpBuilder * b)130 void EmitOperationPrint(Operation* op, OpBuilder* b) {
131 std::string debug_str = "\n\nPrint memref content after the following op\n";
132 llvm::raw_string_ostream output_stream(debug_str);
133
134 mlir::OpPrintingFlags flags;
135 op->print(output_stream, flags);
136 output_stream << "\n\n";
137
138 Location loc = op->getLoc();
139 Value message_constant = CreateOrFindGlobalStringConstant(
140 loc, GetGlobalName("debug_op", debug_str), debug_str, b);
141
142 // Insert function call.
143 MLIRContext* ctx = op->getContext();
144 auto func_type = LLVM::LLVMFunctionType::get(
145 LLVM::LLVMVoidType::get(op->getContext()),
146 {LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8))});
147 FlatSymbolRefAttr tf_func_ref =
148 GetOrInsertLLVMFunction(kPrintStringFuncName, func_type, op, b);
149 b->create<LLVM::CallOp>(loc, llvm::None, tf_func_ref,
150 llvm::makeArrayRef({message_constant}));
151 }
152
153 // The pass inserts printing on every mutation of memrefs.
154 struct EmbedMemRefPrintsPass
155 : public EmbedMemRefPrintsPassBase<EmbedMemRefPrintsPass> {
runOnOperationmlir::kernel_gen::transforms::__anon28f2d82a0111::EmbedMemRefPrintsPass156 void runOnOperation() override {
157 ModuleOp module = getOperation();
158 module.walk([&](func::FuncOp func) {
159 if (func.isDeclaration()) return;
160 Block* body = &func.getBody().front();
161
162 // Print arguments.
163 OpBuilder b(&getContext());
164 b.setInsertionPointToStart(body);
165 Location loc = func.getLoc();
166 auto args = func.getArguments();
167 if (!args.empty()) {
168 EmitOperationPrint(func, &b);
169 }
170 for (auto arg : args) {
171 EmitMemRefPrint(loc, arg, &b);
172 }
173 // Print buffers after every change.
174 for (auto& op : func.getBody().front().getOperations()) {
175 b.setInsertionPointAfter(&op);
176 auto memrefs = ExtractValuesToPrint(&op);
177 if (!memrefs.empty()) {
178 EmitOperationPrint(&op, &b);
179 }
180 for (auto memref : memrefs) {
181 EmitMemRefPrint(op.getLoc(), memref, &b);
182 }
183 }
184 });
185 }
186 };
187
188 } // namespace
189
CreateEmbedMemRefPrintsPass()190 std::unique_ptr<OperationPass<ModuleOp>> CreateEmbedMemRefPrintsPass() {
191 return std::make_unique<EmbedMemRefPrintsPass>();
192 }
193
194 } // namespace transforms
195 } // namespace kernel_gen
196 } // namespace mlir
197