xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/ir_verifier.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/tensorexpr/fwd_decls.h>
4 #include <torch/csrc/jit/tensorexpr/ir_visitor.h>
5 
6 namespace torch {
7 namespace jit {
8 namespace tensorexpr {
9 
10 class Expr;
11 class ExprHandle;
12 class Mod;
13 class And;
14 class Or;
15 class Xor;
16 class Lshift;
17 class Rshift;
18 class CompareSelect;
19 class Ramp;
20 class Load;
21 class IfThenElse;
22 class Intrinsics;
23 
24 class Stmt;
25 class ExternalCall;
26 class Store;
27 class For;
28 class Block;
29 
30 class TORCH_API IRVerifier : public IRVisitor {
31  public:
32   IRVerifier() = default;
33 
34   void visit(const ModPtr& v) override;
35   void visit(const AndPtr& v) override;
36   void visit(const OrPtr& v) override;
37   void visit(const XorPtr& v) override;
38   void visit(const LshiftPtr& v) override;
39   void visit(const RshiftPtr& v) override;
40   void visit(const CompareSelectPtr& v) override;
41   void visit(const RampPtr& v) override;
42   void visit(const LoadPtr& v) override;
43   void visit(const IfThenElsePtr& v) override;
44   void visit(const IntrinsicsPtr& v) override;
45 
46   void visit(const ExternalCallPtr& v) override;
47   void visit(const StorePtr& v) override;
48   void visit(const ForPtr& v) override;
49   void visit(const BlockPtr& v) override;
50 };
51 
52 TORCH_API void verify(const StmtPtr&);
53 TORCH_API void verify(const ExprPtr&);
54 TORCH_API void verify(const ExprHandle&);
55 
56 } // namespace tensorexpr
57 } // namespace jit
58 } // namespace torch
59