xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/analysis.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/tensorexpr/ir.h>
4 #include <torch/csrc/jit/tensorexpr/ir_visitor.h>
5 #include <torch/csrc/jit/tensorexpr/stmt.h>
6 #include <torch/csrc/jit/tensorexpr/tensor.h>
7 
8 #include <utility>
9 
10 namespace torch::jit::tensorexpr {
11 class HasRand : public IRVisitor {
12  public:
HasRand(StmtPtr stmt)13   HasRand(StmtPtr stmt) : stmt_(std::move(stmt)) {
14     stmt_->accept(this);
15   }
16 
has_rand()17   bool has_rand() const {
18     return has_rand_;
19   }
20 
21  private:
visit(const IntrinsicsPtr & v)22   void visit(const IntrinsicsPtr& v) override {
23     if (v->op_type() == IntrinsicsOp::kRand) {
24       has_rand_ = true;
25     } else {
26       IRVisitor::visit(v);
27     }
28   }
29   StmtPtr stmt_;
30   bool has_rand_ = false;
31 };
32 
33 template <typename Op>
34 class NodeFinder : public IRVisitor {
35  public:
visit(const NodePtr<Op> & v)36   void visit(const NodePtr<Op>& v) override {
37     nodes.push_back((NodePtr<Op>)v);
38     IRVisitor::visit(v);
39   }
40 
find(const StmtPtr & s)41   static std::vector<NodePtr<Op>> find(const StmtPtr& s) {
42     NodeFinder<Op> nf;
43     s->accept(&nf);
44     return nf.nodes;
45   }
46 
find(const ExprPtr & e)47   static std::vector<NodePtr<Op>> find(const ExprPtr& e) {
48     NodeFinder<Op> nf;
49     e->accept(&nf);
50     return nf.nodes;
51   }
52 
53   std::vector<NodePtr<Op>> nodes;
54 };
55 
56 class VarFinder : public IRVisitor {
57  public:
visit(const VarPtr & v)58   void visit(const VarPtr& v) override {
59     vars_.insert(v);
60     IRVisitor::visit(v);
61   }
62 
find(const StmtPtr & s)63   static std::unordered_set<VarPtr> find(const StmtPtr& s) {
64     VarFinder nf;
65     s->accept(&nf);
66     return nf.vars();
67   }
68 
find(const ExprPtr & e)69   static std::unordered_set<VarPtr> find(const ExprPtr& e) {
70     VarFinder nf;
71     e->accept(&nf);
72     return nf.vars();
73   }
74 
vars()75   const std::unordered_set<VarPtr>& vars() {
76     return vars_;
77   }
78 
79  private:
80   std::unordered_set<VarPtr> vars_;
81 };
82 
83 class BufFinder : public IRVisitor {
84  public:
visit(const BufPtr & v)85   void visit(const BufPtr& v) override {
86     bufs_.insert(v);
87     IRVisitor::visit(v);
88   }
89 
find(const StmtPtr & s)90   static std::unordered_set<BufPtr> find(const StmtPtr& s) {
91     BufFinder nf;
92     s->accept(&nf);
93     return nf.bufs();
94   }
95 
find(const ExprPtr & e)96   static std::unordered_set<BufPtr> find(const ExprPtr& e) {
97     BufFinder nf;
98     e->accept(&nf);
99     return nf.bufs();
100   }
101 
bufs()102   const std::unordered_set<BufPtr>& bufs() {
103     return bufs_;
104   }
105 
106  private:
107   std::unordered_set<BufPtr> bufs_;
108 };
109 
110 // Finds all kinds of write operations to the provided Buf.
111 class WritesToBuf : public IRVisitor {
112  public:
WritesToBuf(BufPtr target)113   WritesToBuf(BufPtr target) : target_(std::move(target)) {}
114 
writes()115   std::vector<StmtPtr> writes() {
116     return writes_;
117   }
118 
find(const StmtPtr & s,BufPtr b)119   static std::vector<StmtPtr> find(const StmtPtr& s, BufPtr b) {
120     WritesToBuf finder(std::move(b));
121     s->accept(&finder);
122     return finder.writes();
123   }
124 
125  private:
visit(const StorePtr & v)126   void visit(const StorePtr& v) override {
127     if (v->buf() == target_) {
128       writes_.push_back(v);
129     }
130   }
131 
visit(const AtomicAddPtr & v)132   void visit(const AtomicAddPtr& v) override {
133     if (v->buf() == target_) {
134       writes_.push_back(v);
135     }
136   }
137 
138   BufPtr target_;
139   std::vector<StmtPtr> writes_;
140 };
141 
142 class StmtsReadingBuf : public IRVisitor {
143  public:
StmtsReadingBuf(BufPtr target)144   StmtsReadingBuf(BufPtr target) : target_(std::move(target)) {}
145 
reads()146   std::vector<StmtPtr> reads() {
147     return reads_;
148   }
149 
find(const StmtPtr & s,BufPtr b)150   static std::vector<StmtPtr> find(const StmtPtr& s, BufPtr b) {
151     StmtsReadingBuf finder(std::move(b));
152     s->accept(&finder);
153     return finder.reads();
154   }
155 
156  private:
readsBuffer(const StmtPtr & s)157   bool readsBuffer(const StmtPtr& s) {
158     auto loads = NodeFinder<Load>::find(s);
159     for (const auto& l : loads) {
160       if (l->buf() == target_) {
161         return true;
162       }
163     }
164     return false;
165   }
166 
visit(const StorePtr & v)167   void visit(const StorePtr& v) override {
168     if (readsBuffer(v)) {
169       reads_.push_back(v);
170     }
171   }
172 
visit(const LetPtr & v)173   void visit(const LetPtr& v) override {
174     if (readsBuffer(v)) {
175       reads_.push_back(v);
176     }
177   }
178 
visit(const CondPtr & v)179   void visit(const CondPtr& v) override {
180     if (readsBuffer(v)) {
181       reads_.push_back(v);
182     }
183   }
184 
visit(const AtomicAddPtr & v)185   void visit(const AtomicAddPtr& v) override {
186     if (readsBuffer(v)) {
187       reads_.push_back(v);
188     }
189   }
190 
191   BufPtr target_;
192   std::vector<StmtPtr> reads_;
193 };
194 
195 class ExternalAllocBufFinder : public IRVisitor {
196  public:
visit(const ExternalCallWithAllocPtr & v)197   void visit(const ExternalCallWithAllocPtr& v) override {
198     const auto& bufs_out = v->buf_out_args();
199     bufs_.insert(bufs_out.begin(), bufs_out.end());
200     IRVisitor::visit(v);
201   }
202 
find(const StmtPtr & s)203   static std::unordered_set<BufPtr> find(const StmtPtr& s) {
204     ExternalAllocBufFinder f;
205     s->accept(&f);
206     return f.bufs();
207   }
208 
find(const ExprPtr & e)209   static std::unordered_set<BufPtr> find(const ExprPtr& e) {
210     ExternalAllocBufFinder f;
211     e->accept(&f);
212     return f.bufs();
213   }
214 
bufs()215   const std::unordered_set<BufPtr>& bufs() {
216     return bufs_;
217   }
218 
219  private:
220   std::unordered_set<BufPtr> bufs_;
221 };
222 
223 // Traverses the IR to determine if a particular Var is modified within it.
224 class ModifiesVarChecker : public IRVisitor {
225  public:
ModifiesVarChecker(VarPtr v)226   ModifiesVarChecker(VarPtr v) : var_(std::move(v)) {}
227 
check(const StmtPtr & s,VarPtr v)228   static bool check(const StmtPtr& s, VarPtr v) {
229     ModifiesVarChecker checker(std::move(v));
230     s->accept(&checker);
231     return checker.found();
232   }
233 
found()234   bool found() {
235     return found_;
236   }
237 
238  private:
visit(const StorePtr & v)239   void visit(const StorePtr& v) override {
240     if (v->buf()->base_handle() == var_) {
241       found_ = true;
242       return;
243     }
244     IRVisitor::visit(v);
245   }
246 
visit(const AtomicAddPtr & v)247   void visit(const AtomicAddPtr& v) override {
248     if (v->buf()->base_handle() == var_) {
249       found_ = true;
250       return;
251     }
252     IRVisitor::visit(v);
253   }
254 
visit(const LetPtr & v)255   void visit(const LetPtr& v) override {
256     if (v->var() == var_) {
257       found_ = true;
258       return;
259     }
260     IRVisitor::visit(v);
261   }
262 
visit(const ForPtr & v)263   void visit(const ForPtr& v) override {
264     if (v->var() == var_) {
265       found_ = true;
266       return;
267     }
268     IRVisitor::visit(v);
269   }
270 
271   VarPtr var_;
272   bool found_{false};
273 };
274 
275 // Traverse the Block stmt to identify the live range of the specified buf. The
276 // live range, indicated by a pair of integers, specifies the first and last
277 // stmt in block stmts that access to the buf.
278 class BufLiveRange : public IRVisitor {
279  public:
BufLiveRange(BufPtr b)280   BufLiveRange(BufPtr b) : buf_(std::move(b)) {}
281 
liveRange(const StmtPtr & s,BufPtr b)282   static std::tuple<int32_t, int32_t> liveRange(const StmtPtr& s, BufPtr b) {
283     BlockPtr block = to<Block>(s);
284     // We Only analyze buffer live ranges for block stmts.
285     if (!block) {
286       return std::make_tuple(0, 0);
287     }
288 
289     BufLiveRange analyzer(std::move(b));
290     block->accept(&analyzer);
291     return analyzer.getLiveRange();
292   }
293 
294  private:
getLiveRange()295   std::tuple<int32_t, int32_t> getLiveRange() {
296     return std::make_tuple(begin_, end_);
297   }
298 
hasBufReads(const StmtPtr & s)299   bool hasBufReads(const StmtPtr& s) {
300     auto loads1 = NodeFinder<Load>::find(s);
301     for (const auto& l : loads1) {
302       if (l->buf() == buf_) {
303         return true;
304       }
305     }
306     auto loads2 = NodeFinder<ExternalCall>::find(s);
307     for (const auto& l : loads2) {
308       for (const auto& lb : l->buf_args()) {
309         if (lb == buf_) {
310           return true;
311         }
312       }
313     }
314     auto loads3 = NodeFinder<ExternalCallWithAlloc>::find(s);
315     for (const auto& l : loads3) {
316       for (const auto& lb : l->buf_args()) {
317         if (lb == buf_) {
318           return true;
319         }
320       }
321     }
322     return false;
323   }
324 
hasBufWrites(const StmtPtr & s)325   bool hasBufWrites(const StmtPtr& s) {
326     auto writes1 = NodeFinder<Store>::find(s);
327     for (const auto& w : writes1) {
328       if (w->buf() == buf_) {
329         return true;
330       }
331     }
332     auto writes2 = NodeFinder<ExternalCall>::find(s);
333     for (const auto& w : writes2) {
334       if (w->buf() == buf_) {
335         return true;
336       }
337     }
338     auto writes3 = NodeFinder<ExternalCallWithAlloc>::find(s);
339     for (const auto& w : writes3) {
340       for (const auto& wb : w->buf_out_args()) {
341         if (wb == buf_) {
342           return true;
343         }
344       }
345     }
346     return false;
347   }
348 
findAccAndUpdateLiveRange(const StmtPtr & s)349   void findAccAndUpdateLiveRange(const StmtPtr& s) {
350     bool has_reads = hasBufReads(s), has_writes = hasBufWrites(s);
351     if (has_reads || has_writes) {
352       if (begin_ == -1) {
353         begin_ = curr_index_;
354       };
355       end_ = curr_index_;
356     }
357   }
358 
visit(const BlockPtr & v)359   void visit(const BlockPtr& v) override {
360     for (const StmtPtr& s : *v) {
361       curr_index_ += 1;
362       findAccAndUpdateLiveRange(s);
363     }
364   }
365 
366   BufPtr buf_;
367   int32_t begin_ = -1;
368   int32_t end_ = -1;
369   int32_t curr_index_ = -1;
370 };
371 
372 // A class that analyzes the given program relevant for Block backend
373 // It creates a map of multi dim buffers and their flat versions
374 class CreateBufferMap : public IRVisitor {
375  public:
getBufferMap()376   const std::unordered_map<std::string, BufPtr>& getBufferMap() const {
377     return map_input_to_tensor_bufs_;
378   }
379 
380  private:
visit(const StorePtr & v)381   void visit(const StorePtr& v) override {
382     auto load_node = to<Load>(v->value());
383     if (load_node) {
384       auto t_buf = load_node->buf();
385       map_input_to_tensor_bufs_.emplace(t_buf->name_hint(), v->buf());
386     } else {
387       auto add_node = to<Add>(v->value());
388       auto mul_node = to<Mul>(v->value());
389       // This means for now, v->value() can be Add or Mul
390       TORCH_INTERNAL_ASSERT(add_node || mul_node, buildErrorMessage());
391       map_input_to_tensor_bufs_.emplace(v->buf()->name_hint(), v->buf());
392     }
393     v->value()->accept(this);
394   }
395   std::unordered_map<std::string, BufPtr> map_input_to_tensor_bufs_;
396 };
397 
398 } // namespace torch::jit::tensorexpr
399