1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 
19 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
20 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
21 #include "mlir/Dialect/Linalg/IR/Linalg.h"  // from @llvm-project
22 #include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
23 #include "mlir/Dialect/SCF/IR/SCF.h"  // from @llvm-project
24 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
25 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/utils.h"
26 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
27 
28 namespace mlir {
29 namespace kernel_gen {
30 namespace transforms {
31 namespace {
32 
// Symbol name of the LLVM function that EmitOperationPrint calls with the
// debug message string.
constexpr StringRef kPrintStringFuncName = "printCString";
34 
35 #define GEN_PASS_CLASSES
36 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc"
37 
EmitMemRefPrint(Location loc,Type element_type,Value arg,OpBuilder * b)38 Operation* EmitMemRefPrint(Location loc, Type element_type, Value arg,
39                            OpBuilder* b) {
40   StringRef func_name;
41   if (element_type.isF32()) {
42     func_name = "printMemrefF32";
43   }
44   if (element_type.isF64()) {
45     func_name = "printMemrefF64";
46   }
47   if (element_type.isInteger(32)) {
48     func_name = "printMemrefI32";
49   }
50   if (element_type.isInteger(64) || element_type.isIndex()) {
51     func_name = "printMemrefI64";
52   }
53   assert(!func_name.empty() &&
54          "Did not find a print function for the element type");
55 
56   auto caller_func =
57       b->getInsertionBlock()->getParent()->getParentOfType<func::FuncOp>();
58   auto func_name_attr = b->getStringAttr(func_name);
59 
60   auto callee_func = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
61       caller_func, func_name_attr);
62   if (!callee_func) {
63     OpBuilder::InsertionGuard insertGuard(*b);
64 
65     auto module = caller_func->getParentOfType<ModuleOp>();
66     b->setInsertionPointToStart(module.getBody());
67     auto func_type = FunctionType::get(b->getContext(), arg.getType(),
68                                        /*results=*/llvm::None);
69     callee_func =
70         b->create<func::FuncOp>(module.getLoc(), func_name, func_type);
71     callee_func.setPrivate();
72   }
73   return b->create<func::CallOp>(loc, callee_func, arg);
74 }
75 
IsElementTypePrintalble(Type element_type)76 bool IsElementTypePrintalble(Type element_type) {
77   return element_type.isF32() || element_type.isF64() ||
78          element_type.isInteger(32) || element_type.isInteger(64) ||
79          element_type.isIndex();
80 }
81 
EmitMemRefPrint(Location loc,Value memref,OpBuilder * b)82 void EmitMemRefPrint(Location loc, Value memref, OpBuilder* b) {
83   auto memref_type = memref.getType();
84   if (auto unranked_type = memref_type.dyn_cast<UnrankedMemRefType>()) {
85     Type element_type = unranked_type.getElementType();
86     if (!IsElementTypePrintalble(element_type)) return;
87 
88     EmitMemRefPrint(loc, element_type, memref, b);
89   }
90   if (auto ranked_type = memref_type.dyn_cast<MemRefType>()) {
91     Type element_type = ranked_type.getElementType();
92     if (!IsElementTypePrintalble(element_type)) return;
93 
94     if (element_type.isIndex()) {
95       element_type = b->getI64Type();
96       ranked_type = MemRefType::get(ranked_type.getShape(), element_type,
97                                     ranked_type.getLayout(),
98                                     ranked_type.getMemorySpace());
99       memref = b->create<arith::IndexCastOp>(loc, ranked_type, memref);
100     }
101 
102     auto unranked_type = UnrankedMemRefType::get(
103         element_type, ranked_type.getMemorySpaceAsInt());
104     Value unranked_memref =
105         b->create<memref::CastOp>(loc, unranked_type, memref);
106     EmitMemRefPrint(loc, element_type, unranked_memref, b);
107   }
108 }
109 
ExtractValuesToPrint(Operation * op)110 SmallVector<Value> ExtractValuesToPrint(Operation* op) {
111   if (isa<memref::ReinterpretCastOp>(op) || isa<memref::ReshapeOp>(op) ||
112       isa<memref::ExpandShapeOp>(op) || isa<memref::CollapseShapeOp>(op)) {
113     return {op->getResult(0)};
114   }
115   if (auto linalg = dyn_cast<linalg::LinalgOp>(op)) {
116     return linalg.getOutputBufferOperands();
117   }
118   if (auto loop = dyn_cast<gml_st::LoopOp>(op)) {
119     return loop.outputs();
120   }
121   if (auto loop = dyn_cast<scf::ForOp>(op)) {
122     return loop.getIterOperands();
123   }
124   if (auto copy = dyn_cast<memref::CopyOp>(op)) {
125     return {copy.getTarget()};
126   }
127   return {};
128 }
129 
EmitOperationPrint(Operation * op,OpBuilder * b)130 void EmitOperationPrint(Operation* op, OpBuilder* b) {
131   std::string debug_str = "\n\nPrint memref content after the following op\n";
132   llvm::raw_string_ostream output_stream(debug_str);
133 
134   mlir::OpPrintingFlags flags;
135   op->print(output_stream, flags);
136   output_stream << "\n\n";
137 
138   Location loc = op->getLoc();
139   Value message_constant = CreateOrFindGlobalStringConstant(
140       loc, GetGlobalName("debug_op", debug_str), debug_str, b);
141 
142   // Insert function call.
143   MLIRContext* ctx = op->getContext();
144   auto func_type = LLVM::LLVMFunctionType::get(
145       LLVM::LLVMVoidType::get(op->getContext()),
146       {LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8))});
147   FlatSymbolRefAttr tf_func_ref =
148       GetOrInsertLLVMFunction(kPrintStringFuncName, func_type, op, b);
149   b->create<LLVM::CallOp>(loc, llvm::None, tf_func_ref,
150                           llvm::makeArrayRef({message_constant}));
151 }
152 
153 // The pass inserts printing on every mutation of memrefs.
154 struct EmbedMemRefPrintsPass
155     : public EmbedMemRefPrintsPassBase<EmbedMemRefPrintsPass> {
runOnOperationmlir::kernel_gen::transforms::__anon28f2d82a0111::EmbedMemRefPrintsPass156   void runOnOperation() override {
157     ModuleOp module = getOperation();
158     module.walk([&](func::FuncOp func) {
159       if (func.isDeclaration()) return;
160       Block* body = &func.getBody().front();
161 
162       // Print arguments.
163       OpBuilder b(&getContext());
164       b.setInsertionPointToStart(body);
165       Location loc = func.getLoc();
166       auto args = func.getArguments();
167       if (!args.empty()) {
168         EmitOperationPrint(func, &b);
169       }
170       for (auto arg : args) {
171         EmitMemRefPrint(loc, arg, &b);
172       }
173       // Print buffers after every change.
174       for (auto& op : func.getBody().front().getOperations()) {
175         b.setInsertionPointAfter(&op);
176         auto memrefs = ExtractValuesToPrint(&op);
177         if (!memrefs.empty()) {
178           EmitOperationPrint(&op, &b);
179         }
180         for (auto memref : memrefs) {
181           EmitMemRefPrint(op.getLoc(), memref, &b);
182         }
183       }
184     });
185   }
186 };
187 
188 }  // namespace
189 
CreateEmbedMemRefPrintsPass()190 std::unique_ptr<OperationPass<ModuleOp>> CreateEmbedMemRefPrintsPass() {
191   return std::make_unique<EmbedMemRefPrintsPass>();
192 }
193 
194 }  // namespace transforms
195 }  // namespace kernel_gen
196 }  // namespace mlir
197