1 #pragma once 2 3 #include <torch/csrc/jit/tensorexpr/ir.h> 4 #include <torch/csrc/jit/tensorexpr/ir_visitor.h> 5 #include <torch/csrc/jit/tensorexpr/stmt.h> 6 #include <torch/csrc/jit/tensorexpr/tensor.h> 7 8 #include <utility> 9 10 namespace torch::jit::tensorexpr { 11 class HasRand : public IRVisitor { 12 public: HasRand(StmtPtr stmt)13 HasRand(StmtPtr stmt) : stmt_(std::move(stmt)) { 14 stmt_->accept(this); 15 } 16 has_rand()17 bool has_rand() const { 18 return has_rand_; 19 } 20 21 private: visit(const IntrinsicsPtr & v)22 void visit(const IntrinsicsPtr& v) override { 23 if (v->op_type() == IntrinsicsOp::kRand) { 24 has_rand_ = true; 25 } else { 26 IRVisitor::visit(v); 27 } 28 } 29 StmtPtr stmt_; 30 bool has_rand_ = false; 31 }; 32 33 template <typename Op> 34 class NodeFinder : public IRVisitor { 35 public: visit(const NodePtr<Op> & v)36 void visit(const NodePtr<Op>& v) override { 37 nodes.push_back((NodePtr<Op>)v); 38 IRVisitor::visit(v); 39 } 40 find(const StmtPtr & s)41 static std::vector<NodePtr<Op>> find(const StmtPtr& s) { 42 NodeFinder<Op> nf; 43 s->accept(&nf); 44 return nf.nodes; 45 } 46 find(const ExprPtr & e)47 static std::vector<NodePtr<Op>> find(const ExprPtr& e) { 48 NodeFinder<Op> nf; 49 e->accept(&nf); 50 return nf.nodes; 51 } 52 53 std::vector<NodePtr<Op>> nodes; 54 }; 55 56 class VarFinder : public IRVisitor { 57 public: visit(const VarPtr & v)58 void visit(const VarPtr& v) override { 59 vars_.insert(v); 60 IRVisitor::visit(v); 61 } 62 find(const StmtPtr & s)63 static std::unordered_set<VarPtr> find(const StmtPtr& s) { 64 VarFinder nf; 65 s->accept(&nf); 66 return nf.vars(); 67 } 68 find(const ExprPtr & e)69 static std::unordered_set<VarPtr> find(const ExprPtr& e) { 70 VarFinder nf; 71 e->accept(&nf); 72 return nf.vars(); 73 } 74 vars()75 const std::unordered_set<VarPtr>& vars() { 76 return vars_; 77 } 78 79 private: 80 std::unordered_set<VarPtr> vars_; 81 }; 82 83 class BufFinder : public IRVisitor { 84 public: visit(const BufPtr & v)85 void visit(const BufPtr& v) override { 86 bufs_.insert(v); 87 IRVisitor::visit(v); 88 } 89 find(const StmtPtr & s)90 static std::unordered_set<BufPtr> find(const StmtPtr& s) { 91 BufFinder nf; 92 s->accept(&nf); 93 return nf.bufs(); 94 } 95 find(const ExprPtr & e)96 static std::unordered_set<BufPtr> find(const ExprPtr& e) { 97 BufFinder nf; 98 e->accept(&nf); 99 return nf.bufs(); 100 } 101 bufs()102 const std::unordered_set<BufPtr>& bufs() { 103 return bufs_; 104 } 105 106 private: 107 std::unordered_set<BufPtr> bufs_; 108 }; 109 110 // Finds all kinds of write operations to the provided Buf. 111 class WritesToBuf : public IRVisitor { 112 public: WritesToBuf(BufPtr target)113 WritesToBuf(BufPtr target) : target_(std::move(target)) {} 114 writes()115 std::vector<StmtPtr> writes() { 116 return writes_; 117 } 118 find(const StmtPtr & s,BufPtr b)119 static std::vector<StmtPtr> find(const StmtPtr& s, BufPtr b) { 120 WritesToBuf finder(std::move(b)); 121 s->accept(&finder); 122 return finder.writes(); 123 } 124 125 private: visit(const StorePtr & v)126 void visit(const StorePtr& v) override { 127 if (v->buf() == target_) { 128 writes_.push_back(v); 129 } 130 } 131 visit(const AtomicAddPtr & v)132 void visit(const AtomicAddPtr& v) override { 133 if (v->buf() == target_) { 134 writes_.push_back(v); 135 } 136 } 137 138 BufPtr target_; 139 std::vector<StmtPtr> writes_; 140 }; 141 142 class StmtsReadingBuf : public IRVisitor { 143 public: StmtsReadingBuf(BufPtr target)144 StmtsReadingBuf(BufPtr target) : target_(std::move(target)) {} 145 reads()146 std::vector<StmtPtr> reads() { 147 return reads_; 148 } 149 find(const StmtPtr & s,BufPtr b)150 static std::vector<StmtPtr> find(const StmtPtr& s, BufPtr b) { 151 StmtsReadingBuf finder(std::move(b)); 152 s->accept(&finder); 153 return finder.reads(); 154 } 155 156 private: readsBuffer(const StmtPtr & s)157 bool readsBuffer(const StmtPtr& s) { 158 auto loads = NodeFinder<Load>::find(s); 159 for (const auto& l : loads) { 160 if (l->buf() == target_) { 161 return true; 162 } 163 } 164 return false; 165 } 166 visit(const StorePtr & v)167 void visit(const StorePtr& v) override { 168 if (readsBuffer(v)) { 169 reads_.push_back(v); 170 } 171 } 172 visit(const LetPtr & v)173 void visit(const LetPtr& v) override { 174 if (readsBuffer(v)) { 175 reads_.push_back(v); 176 } 177 } 178 visit(const CondPtr & v)179 void visit(const CondPtr& v) override { 180 if (readsBuffer(v)) { 181 reads_.push_back(v); 182 } 183 } 184 visit(const AtomicAddPtr & v)185 void visit(const AtomicAddPtr& v) override { 186 if (readsBuffer(v)) { 187 reads_.push_back(v); 188 } 189 } 190 191 BufPtr target_; 192 std::vector<StmtPtr> reads_; 193 }; 194 195 class ExternalAllocBufFinder : public IRVisitor { 196 public: visit(const ExternalCallWithAllocPtr & v)197 void visit(const ExternalCallWithAllocPtr& v) override { 198 const auto& bufs_out = v->buf_out_args(); 199 bufs_.insert(bufs_out.begin(), bufs_out.end()); 200 IRVisitor::visit(v); 201 } 202 find(const StmtPtr & s)203 static std::unordered_set<BufPtr> find(const StmtPtr& s) { 204 ExternalAllocBufFinder f; 205 s->accept(&f); 206 return f.bufs(); 207 } 208 find(const ExprPtr & e)209 static std::unordered_set<BufPtr> find(const ExprPtr& e) { 210 ExternalAllocBufFinder f; 211 e->accept(&f); 212 return f.bufs(); 213 } 214 bufs()215 const std::unordered_set<BufPtr>& bufs() { 216 return bufs_; 217 } 218 219 private: 220 std::unordered_set<BufPtr> bufs_; 221 }; 222 223 // Traverses the IR to determine if a particular Var is modified within it. 224 class ModifiesVarChecker : public IRVisitor { 225 public: ModifiesVarChecker(VarPtr v)226 ModifiesVarChecker(VarPtr v) : var_(std::move(v)) {} 227 check(const StmtPtr & s,VarPtr v)228 static bool check(const StmtPtr& s, VarPtr v) { 229 ModifiesVarChecker checker(std::move(v)); 230 s->accept(&checker); 231 return checker.found(); 232 } 233 found()234 bool found() { 235 return found_; 236 } 237 238 private: visit(const StorePtr & v)239 void visit(const StorePtr& v) override { 240 if (v->buf()->base_handle() == var_) { 241 found_ = true; 242 return; 243 } 244 IRVisitor::visit(v); 245 } 246 visit(const AtomicAddPtr & v)247 void visit(const AtomicAddPtr& v) override { 248 if (v->buf()->base_handle() == var_) { 249 found_ = true; 250 return; 251 } 252 IRVisitor::visit(v); 253 } 254 visit(const LetPtr & v)255 void visit(const LetPtr& v) override { 256 if (v->var() == var_) { 257 found_ = true; 258 return; 259 } 260 IRVisitor::visit(v); 261 } 262 visit(const ForPtr & v)263 void visit(const ForPtr& v) override { 264 if (v->var() == var_) { 265 found_ = true; 266 return; 267 } 268 IRVisitor::visit(v); 269 } 270 271 VarPtr var_; 272 bool found_{false}; 273 }; 274 275 // Traverse the Block stmt to identify the live range of the specified buf. The 276 // live range, indicated by a pair of integers, specifies the first and last 277 // stmt in block stmts that access to the buf. 278 class BufLiveRange : public IRVisitor { 279 public: BufLiveRange(BufPtr b)280 BufLiveRange(BufPtr b) : buf_(std::move(b)) {} 281 liveRange(const StmtPtr & s,BufPtr b)282 static std::tuple<int32_t, int32_t> liveRange(const StmtPtr& s, BufPtr b) { 283 BlockPtr block = to<Block>(s); 284 // We Only analyze buffer live ranges for block stmts. 285 if (!block) { 286 return std::make_tuple(0, 0); 287 } 288 289 BufLiveRange analyzer(std::move(b)); 290 block->accept(&analyzer); 291 return analyzer.getLiveRange(); 292 } 293 294 private: getLiveRange()295 std::tuple<int32_t, int32_t> getLiveRange() { 296 return std::make_tuple(begin_, end_); 297 } 298 hasBufReads(const StmtPtr & s)299 bool hasBufReads(const StmtPtr& s) { 300 auto loads1 = NodeFinder<Load>::find(s); 301 for (const auto& l : loads1) { 302 if (l->buf() == buf_) { 303 return true; 304 } 305 } 306 auto loads2 = NodeFinder<ExternalCall>::find(s); 307 for (const auto& l : loads2) { 308 for (const auto& lb : l->buf_args()) { 309 if (lb == buf_) { 310 return true; 311 } 312 } 313 } 314 auto loads3 = NodeFinder<ExternalCallWithAlloc>::find(s); 315 for (const auto& l : loads3) { 316 for (const auto& lb : l->buf_args()) { 317 if (lb == buf_) { 318 return true; 319 } 320 } 321 } 322 return false; 323 } 324 hasBufWrites(const StmtPtr & s)325 bool hasBufWrites(const StmtPtr& s) { 326 auto writes1 = NodeFinder<Store>::find(s); 327 for (const auto& w : writes1) { 328 if (w->buf() == buf_) { 329 return true; 330 } 331 } 332 auto writes2 = NodeFinder<ExternalCall>::find(s); 333 for (const auto& w : writes2) { 334 if (w->buf() == buf_) { 335 return true; 336 } 337 } 338 auto writes3 = NodeFinder<ExternalCallWithAlloc>::find(s); 339 for (const auto& w : writes3) { 340 for (const auto& wb : w->buf_out_args()) { 341 if (wb == buf_) { 342 return true; 343 } 344 } 345 } 346 return false; 347 } 348 findAccAndUpdateLiveRange(const StmtPtr & s)349 void findAccAndUpdateLiveRange(const StmtPtr& s) { 350 bool has_reads = hasBufReads(s), has_writes = hasBufWrites(s); 351 if (has_reads || has_writes) { 352 if (begin_ == -1) { 353 begin_ = curr_index_; 354 }; 355 end_ = curr_index_; 356 } 357 } 358 visit(const BlockPtr & v)359 void visit(const BlockPtr& v) override { 360 for (const StmtPtr& s : *v) { 361 curr_index_ += 1; 362 findAccAndUpdateLiveRange(s); 363 } 364 } 365 366 BufPtr buf_; 367 int32_t begin_ = -1; 368 int32_t end_ = -1; 369 int32_t curr_index_ = -1; 370 }; 371 372 // A class that analyzes the given program relevant for Block backend 373 // It creates a map of multi dim buffers and their flat versions 374 class CreateBufferMap : public IRVisitor { 375 public: getBufferMap()376 const std::unordered_map<std::string, BufPtr>& getBufferMap() const { 377 return map_input_to_tensor_bufs_; 378 } 379 380 private: visit(const StorePtr & v)381 void visit(const StorePtr& v) override { 382 auto load_node = to<Load>(v->value()); 383 if (load_node) { 384 auto t_buf = load_node->buf(); 385 map_input_to_tensor_bufs_.emplace(t_buf->name_hint(), v->buf()); 386 } else { 387 auto add_node = to<Add>(v->value()); 388 auto mul_node = to<Mul>(v->value()); 389 // This means for now, v->value() can be Add or Mul 390 TORCH_INTERNAL_ASSERT(add_node || mul_node, buildErrorMessage()); 391 map_input_to_tensor_bufs_.emplace(v->buf()->name_hint(), v->buf()); 392 } 393 v->value()->accept(this); 394 } 395 std::unordered_map<std::string, BufPtr> map_input_to_tensor_bufs_; 396 }; 397 398 } // namespace torch::jit::tensorexpr 399