1 #pragma once 2 #include <c10/core/ScalarType.h> 3 #include <torch/csrc/Export.h> 4 #include <torch/csrc/jit/tensorexpr/fwd_decls.h> 5 6 namespace torch::jit::tensorexpr { 7 8 class TORCH_API IRVisitor { 9 public: 10 virtual ~IRVisitor() = default; 11 virtual void visit(const AddPtr& v); 12 virtual void visit(const SubPtr& v); 13 virtual void visit(const MulPtr& v); 14 virtual void visit(const DivPtr& v); 15 virtual void visit(const ModPtr& v); 16 virtual void visit(const MaxPtr& v); 17 virtual void visit(const MinPtr& v); 18 virtual void visit(const AndPtr& v); 19 virtual void visit(const OrPtr& v); 20 virtual void visit(const XorPtr& v); 21 virtual void visit(const LshiftPtr& v); 22 virtual void visit(const RshiftPtr& v); 23 virtual void visit(const CompareSelectPtr& v); 24 25 #define IMM_PRINT_VISIT(Type, Name) virtual void visit(const Name##ImmPtr& v); 26 27 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT) 28 #undef IMM_PRINT_VISIT 29 30 virtual void visit(const CastPtr& v); 31 virtual void visit(const BitCastPtr& v); 32 virtual void visit(const VarPtr& v); 33 virtual void visit(const BufPtr& v); 34 virtual void visit(const RampPtr& v); 35 virtual void visit(const LoadPtr& v); 36 virtual void visit(const ForPtr& v); 37 virtual void visit(const BlockPtr& v); 38 virtual void visit(const StorePtr& v); 39 virtual void visit(const BroadcastPtr& v); 40 virtual void visit(const IfThenElsePtr& v); 41 virtual void visit(const IntrinsicsPtr& v); 42 virtual void visit(const AllocatePtr& v); 43 virtual void visit(const FreePtr& v); 44 virtual void visit(const FreeExtPtr& v); 45 virtual void visit(const PlacementAllocatePtr& v); 46 virtual void visit(const LetPtr& v); 47 virtual void visit(const CondPtr& v); 48 virtual void visit(const TermPtr& v); 49 virtual void visit(const PolynomialPtr& v); 50 virtual void visit(const RoundOffPtr& v); 51 virtual void visit(const MaxTermPtr& v); 52 virtual void visit(const MinTermPtr& v); 53 virtual void visit(const ReduceOpPtr& v); 54 virtual void visit(const AtomicAddPtr& v); 55 virtual void visit(const SyncThreadsPtr& v); 56 virtual void visit(const ExternalCallPtr& v); 57 virtual void visit(const ExternalCallWithAllocPtr& v); 58 }; 59 60 } // namespace torch::jit::tensorexpr 61