1 #pragma once 2 #include <c10/core/ScalarType.h> 3 #include <torch/csrc/Export.h> 4 #include <torch/csrc/jit/tensorexpr/fwd_decls.h> 5 6 namespace torch::jit::tensorexpr { 7 8 class TORCH_API IRMutator { 9 public: 10 virtual ~IRMutator() = default; 11 virtual ExprPtr mutate(const AddPtr& v); 12 virtual ExprPtr mutate(const SubPtr& v); 13 virtual ExprPtr mutate(const MulPtr& v); 14 virtual ExprPtr mutate(const DivPtr& v); 15 virtual ExprPtr mutate(const ModPtr& v); 16 virtual ExprPtr mutate(const MaxPtr& v); 17 virtual ExprPtr mutate(const MinPtr& v); 18 virtual ExprPtr mutate(const AndPtr& v); 19 virtual ExprPtr mutate(const OrPtr& v); 20 virtual ExprPtr mutate(const XorPtr& v); 21 virtual ExprPtr mutate(const LshiftPtr& v); 22 virtual ExprPtr mutate(const RshiftPtr& v); 23 virtual ExprPtr mutate(const CompareSelectPtr& v); 24 #define IMM_MUTATE_DECLARE(Type, Name) \ 25 virtual ExprPtr mutate(const Name##ImmPtr& v); 26 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_MUTATE_DECLARE); 27 #undef IMM_MUTATE_DECLARE 28 virtual ExprPtr mutate(const CastPtr& v); 29 virtual ExprPtr mutate(const BitCastPtr& v); 30 virtual ExprPtr mutate(const VarPtr& v); 31 virtual ExprPtr mutate(const BufPtr& v); 32 virtual ExprPtr mutate(const RampPtr& v); 33 virtual ExprPtr mutate(const LoadPtr& v); 34 virtual ExprPtr mutate(const BroadcastPtr& v); 35 virtual ExprPtr mutate(const IfThenElsePtr& v); 36 virtual ExprPtr mutate(const IntrinsicsPtr& v); 37 38 virtual ExprPtr mutate(const TermPtr& v); 39 virtual ExprPtr mutate(const PolynomialPtr& v); 40 virtual ExprPtr mutate(const RoundOffPtr& v); 41 virtual ExprPtr mutate(const MaxTermPtr& v); 42 virtual ExprPtr mutate(const MinTermPtr& v); 43 44 virtual ExprPtr mutate(const ReduceOpPtr& v); 45 46 virtual StmtPtr mutate(const ForPtr& v); 47 virtual StmtPtr mutate(const BlockPtr& v); 48 virtual StmtPtr mutate(const StorePtr& v); 49 virtual StmtPtr mutate(const AtomicAddPtr& v); 50 virtual StmtPtr mutate(const SyncThreadsPtr& v); 51 virtual StmtPtr mutate(const ExternalCallPtr& v); 52 virtual StmtPtr mutate(const ExternalCallWithAllocPtr& v); 53 54 virtual StmtPtr mutate(const AllocatePtr& v); 55 virtual StmtPtr mutate(const FreePtr& v); 56 virtual StmtPtr mutate(const FreeExtPtr& v); 57 virtual StmtPtr mutate(const PlacementAllocatePtr& v); 58 virtual StmtPtr mutate(const LetPtr& v); 59 virtual StmtPtr mutate(const CondPtr& v); 60 }; 61 62 } // namespace torch::jit::tensorexpr 63