xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/tensorexpr/expr.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /**
2  * This file implements the core classes for Tensor Expressions.
3  *
4  * The structure of the expressions is inspired by Halide/TVM IR.
5  */
6 #pragma once
7 
8 #include <c10/core/MemoryFormat.h>
9 #include <torch/csrc/jit/tensorexpr/fwd_decls.h>
10 #include <torch/csrc/jit/tensorexpr/ir_mutator.h>
11 #include <torch/csrc/jit/tensorexpr/ir_visitor.h>
12 #include <torch/csrc/jit/tensorexpr/types.h>
13 #include <optional>
14 
15 #include <utility>
16 
17 namespace torch::jit::tensorexpr {
18 
// Tag identifying the kind of IR node an Expr represents. Lets passes do a
// cheap kind check without a dynamic cast.
enum IRNodeType {
  kPrimitive,
  // Binary arithmetic ops.
  kAdd,
  kSub,
  kMul,
  kDiv,
  kMod,
  kMax,
  kMin,
  // Bitwise / shift ops.
  kAnd,
  kOr,
  kLshift,
  kRshift,
  kXor,
  kCompareSelect,
  // Type conversions.
  kCast,
  kBitCast,
  // Fallback for nodes with no dedicated tag (also the ctor default).
  kOther,
};
38 
// The common base between all expression nodes.
class TORCH_API Expr : public std::enable_shared_from_this<Expr> {
 public:
  explicit Expr(Dtype dtype, IRNodeType expr_type = kOther)
      : dtype_(dtype), expr_type_(expr_type) {}
  virtual ~Expr() = default;

  // Scalar type this expression evaluates to.
  Dtype dtype() const {
    return dtype_;
  }

  // Double-dispatch hooks: concrete subclasses route the visitor/mutator to
  // the overload for their own type (implemented by ExprNode below).
  virtual void accept(IRVisitor* visitor) = 0;
  virtual ExprPtr accept_mutator(IRMutator* mutator) = 0;

  IRNodeType expr_type() const {
    return expr_type_;
  }
  // Is this a fixed (constant) immediate value.
  virtual bool isConstant() const {
    return false;
  }

  void set_dtype(Dtype dtype) {
    dtype_ = dtype;
  }

  /*
   * Make a deep copy of the given expression.
   *
   * All sub-expressions inside the given expressions are also cloned. Note
   * that the variables are not deep-copied since they are immutable.
   */
  static ExprPtr clone(const ExprPtr& s);

 protected:
  // Shared pointer to this node. Only valid once the node is owned by a
  // shared_ptr (nodes are created via alloc<>, so that holds in practice).
  std::shared_ptr<Expr> getptr() {
    return shared_from_this();
  }

 private:
  Dtype dtype_;
  IRNodeType expr_type_;
};
80 
// A CRTP pattern to accept visitors for children class,
// and dispatch back to the children.
template <class Op, class Base = Expr>
class ExprNode : public Base {
 public:
  using ExprNodeBase = ExprNode<Op>;
  // Statically downcasts this node to the concrete type Op and calls the
  // visitor overload for that type.
  void accept(IRVisitor* visitor) override {
    visitor->visit(static_to<Op>(Base::getptr()));
  }
  // Out-of-line below; mirrors accept() for mutating passes.
  ExprPtr accept_mutator(IRMutator* mutator) override;
  // pass the constructor to the base class
  using Base::Base;
};
94 
95 // A wrapper object to the underlying ExprNode.
96 // Also serves the primary way to build and operate on other expressions.
97 class TORCH_API ExprHandle {
98  public:
99   ExprHandle() = default;
ExprHandle(ExprPtr node)100   explicit ExprHandle(ExprPtr node) : base_expr_node_(std::move(node)) {}
101 
node()102   ExprPtr node() {
103     return base_expr_node_;
104   }
105 
node()106   ExprPtr node() const {
107     return base_expr_node_;
108   }
109 
empty()110   bool empty() const {
111     return base_expr_node_ == nullptr;
112   }
113 
114 #define IMM_EXPR_DECLARE(Type, Name) ExprHandle(Type v);
115   AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE);
116 #undef IMM_EXPR_DECLARE
117 
118   template <class Op>
AsNode()119   NodePtr<Op> AsNode() {
120     return to<Op>(this->node());
121   }
122 
123   template <class Op>
AsNode()124   NodePtr<Op> AsNode() const {
125     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
126     return const_cast<ExprHandle*>(this)->AsNode<Op>();
127   }
128 
dtype()129   Dtype dtype() const {
130     return node()->dtype();
131   }
132 
133   // Handling the math operators.
134   ExprHandle operator+(const ExprHandle& other) const;
135   ExprHandle operator-(const ExprHandle& other) const;
136   ExprHandle operator*(const ExprHandle& other) const;
137   ExprHandle operator/(const ExprHandle& other) const;
138   ExprHandle operator%(const ExprHandle& other) const;
139   ExprHandle operator==(const ExprHandle& other) const;
140   ExprHandle operator!=(const ExprHandle& other) const;
141   ExprHandle operator>(const ExprHandle& other) const;
142   ExprHandle operator>=(const ExprHandle& other) const;
143   ExprHandle operator<(const ExprHandle& other) const;
144   ExprHandle operator<=(const ExprHandle& other) const;
145   ExprHandle operator&(const ExprHandle& other) const;
146   ExprHandle operator|(const ExprHandle& other) const;
147   ExprHandle operator&&(const ExprHandle& other) const;
148   ExprHandle operator||(const ExprHandle& other) const;
149   ExprHandle operator^(const ExprHandle& other) const;
150   ExprHandle operator<<(const ExprHandle& other) const;
151   ExprHandle operator>>(const ExprHandle& other) const;
152 
153  private:
154   ExprPtr base_expr_node_ = nullptr;
155 };
156 
157 // The underlying representation node to a Var.
158 // Currently, each Var object represents a unique variable, even though the
159 // names might be the same. We should consider add a unique_name as well.
160 class TORCH_API Var : public ExprNode<Var> {
161  public:
make(const std::string & name_hint,Dtype dtype)162   static ExprHandle make(const std::string& name_hint, Dtype dtype) {
163     return ExprHandle(alloc<Var>(name_hint, dtype));
164   }
make(Dtype dtype)165   static ExprHandle make(Dtype dtype) {
166     return ExprHandle(alloc<Var>("", dtype));
167   }
168 
169   // TODO: unique_name
name_hint()170   const std::string& name_hint() const {
171     return name_hint_;
172   }
173 
set_name_hint(const std::string & name)174   void set_name_hint(const std::string& name) {
175     name_hint_ = name;
176   }
177 
set_name_hint(std::string && name)178   void set_name_hint(std::string&& name) {
179     name_hint_ = std::move(name);
180   }
181 
Var(std::string name_hint,Dtype dtype)182   Var(std::string name_hint, Dtype dtype)
183       : ExprNodeBase(dtype, kPrimitive), name_hint_(std::move(name_hint)) {}
184 
185  private:
186   std::string name_hint_;
187 };
188 
189 TORCH_API std::vector<ExprPtr> make_contiguous_strides(
190     const std::vector<ExprHandle>& dims);
191 TORCH_API std::vector<ExprPtr> make_channels_last_strides(
192     const std::vector<ExprHandle>& dims);
193 
194 class TORCH_API Buf : public ExprNode<Buf> {
195  public:
196   static BufHandle make(const std::vector<ExprHandle>& dims, Dtype dtype);
197 
198   static BufHandle make(
199       const std::string& name_hint,
200       const std::vector<ExprHandle>& dims,
201       const std::vector<ExprHandle>& strides,
202       Dtype dtype);
203 
204   static BufHandle make(
205       const std::string& name_hint,
206       const std::vector<ExprHandle>& dims,
207       Dtype dtype,
208       std::optional<ExprHandle> initializer = std::nullopt,
209       const std::optional<std::vector<ExprHandle>>& strides = std::nullopt,
210       std::optional<ExprHandle> qscale = std::nullopt,
211       std::optional<ExprHandle> qzero = std::nullopt);
212 
213   // TODO: unique_name
base_handle()214   VarPtr base_handle() const {
215     return base_handle_;
216   }
set_base_handle(VarPtr base_handle)217   void set_base_handle(VarPtr base_handle) {
218     base_handle_ = std::move(base_handle);
219   }
220 
name_hint()221   const std::string& name_hint() const {
222     return base_handle_->name_hint();
223   }
set_name_hint(const std::string & name_hint)224   void set_name_hint(const std::string& name_hint) {
225     base_handle_->set_name_hint(name_hint);
226   }
227 
228   Buf(const std::string& name_hint,
229       const std::vector<ExprPtr>& dims,
230       Dtype dtype,
231       ExprPtr initializer = nullptr,
232       std::optional<std::vector<ExprPtr>> strides = std::nullopt,
233       ExprPtr qscale = nullptr,
234       ExprPtr qzero = nullptr)
Buf(alloc<Var> (name_hint,kHandle),dims,dtype,std::move (initializer),std::move (strides),std::move (qscale),std::move (qzero))235       : Buf(alloc<Var>(name_hint, kHandle),
236             dims,
237             dtype,
238             std::move(initializer),
239             std::move(strides),
240             std::move(qscale),
241             std::move(qzero)) {}
242 
243   Buf(const VarPtr& var,
244       std::vector<ExprPtr> dims,
245       Dtype dtype,
246       ExprPtr initializer = nullptr,
247       std::optional<std::vector<ExprPtr>> strides = std::nullopt,
248       ExprPtr qscale = nullptr,
249       ExprPtr qzero = nullptr);
250 
ndim()251   size_t ndim() const {
252     return dims_.size();
253   }
dim(size_t index)254   ExprPtr dim(size_t index) const {
255     if (index >= ndim()) {
256       throw out_of_range_index();
257     }
258     return dims_[index];
259   }
dims()260   std::vector<ExprPtr> dims() const {
261     return dims_;
262   }
set_dims(std::vector<ExprPtr> dims)263   void set_dims(std::vector<ExprPtr> dims) {
264     dims_ = std::move(dims);
265   }
266 
strides()267   std::vector<ExprPtr> strides() const {
268     return strides_;
269   }
270 
set_strides(std::vector<ExprPtr> strides)271   void set_strides(std::vector<ExprPtr> strides) {
272     strides_ = std::move(strides);
273   }
274 
initializer()275   ExprPtr initializer() const {
276     return initializer_;
277   };
278 
qzero()279   ExprPtr qzero() const {
280     return qzero_;
281   }
282 
qscale()283   ExprPtr qscale() const {
284     return qscale_;
285   }
286 
set_qzero(ExprPtr qzero)287   void set_qzero(ExprPtr qzero) {
288     qzero_ = std::move(qzero);
289   }
290 
set_qscale(ExprPtr qscale)291   void set_qscale(ExprPtr qscale) {
292     qscale_ = std::move(qscale);
293   }
294 
hasConstantDims()295   bool hasConstantDims() const {
296     for (const auto& d : dims_) {
297       if (!d->isConstant()) {
298         return false;
299       }
300     }
301     return true;
302   }
303 
304   bool is_contiguous(
305       at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const;
306 
307   // The channels-last 1d can benefit the performance of some operators like
308   // conv1d. But the MemoryFormat enum has not covered this layout yet. Hence,
309   // we abstract a dedicated function to check channels-last 1d contiguous.
310   //
311   // Channels-last 1d:
312   //   dims:              n   c    l
313   //   strides(nlc):    c*l   1    c
is_channels_last_1d_contiguous()314   bool is_channels_last_1d_contiguous() const {
315     if (dims_.size() != 3) {
316       return false;
317     }
318     return is_stride_one(1) && is_cont_with(2, 1) && is_cont_with(0, 2);
319   }
320 
321  private:
322   bool is_cont_with(int cur_dim, int adjacent_dim) const;
323   bool is_stride_one(int cur_dim) const;
324 
325   VarPtr base_handle_;
326   std::vector<ExprPtr> dims_;
327   std::vector<ExprPtr> strides_;
328   ExprPtr initializer_;
329   // qscale_ and qzero_ are used only for quantized dtypes Bufs: kQUInt8, kQInt8
330   ExprPtr qscale_;
331   ExprPtr qzero_;
332 };
333 
334 class TORCH_API BufHandle : public ExprHandle {
335  public:
BufHandle(const std::string & name_hint,const std::vector<ExprHandle> & dims,Dtype dtype)336   BufHandle(
337       const std::string& name_hint,
338       const std::vector<ExprHandle>& dims,
339       Dtype dtype)
340       : ExprHandle(Buf::make(name_hint, dims, dtype)) {}
341 
BufHandle(const std::string & name_hint,const std::vector<ExprHandle> & dims,const std::vector<ExprHandle> & strides,Dtype dtype)342   BufHandle(
343       const std::string& name_hint,
344       const std::vector<ExprHandle>& dims,
345       const std::vector<ExprHandle>& strides,
346       Dtype dtype)
347       : ExprHandle(Buf::make(name_hint, dims, strides, dtype)) {}
348 
BufHandle(const std::vector<ExprHandle> & dims,Dtype dtype)349   BufHandle(const std::vector<ExprHandle>& dims, Dtype dtype)
350       : ExprHandle(Buf::make("_", dims, dtype)) {}
351 
BufHandle(Dtype dtype)352   explicit BufHandle(Dtype dtype) : ExprHandle(Buf::make("_", {}, dtype)) {}
353 
BufHandle(BufPtr node)354   explicit BufHandle(BufPtr node) : ExprHandle(std::move(node)) {}
node()355   BufPtr node() const {
356     return static_to<Buf>(ExprHandle::node());
357   }
node()358   BufPtr node() {
359     return static_to<Buf>(ExprHandle::node());
360   }
361 
362   template <typename... Ts>
363   inline ExprHandle load(const Ts&... ts) const;
364 
365   template <typename T>
366   inline ExprHandle load(const std::vector<T>& args) const;
367 
368   inline ExprHandle load(const std::vector<ExprHandle>& args) const;
369 
370   StorePtr store(const std::vector<ExprHandle>& args, const ExprHandle& val)
371       const;
372 
373   bool operator==(const BufHandle& other) const {
374     return this->node() == other.node();
375   }
376   bool operator!=(const BufHandle& other) const {
377     return !(*this == other);
378   }
379 
name_hint()380   const std::string& name_hint() const {
381     return this->node()->name_hint();
382   }
383 
empty()384   bool empty() const {
385     return (this->node() == nullptr);
386   }
387 
ndim()388   size_t ndim() const {
389     return node()->ndim();
390   }
391 
392   std::vector<ExprHandle> dims() const;
393 
dim(size_t index)394   ExprHandle dim(size_t index) const {
395     return ExprHandle(node()->dim(index));
396   }
397 
398   bool is_contiguous(
399       at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const {
400     return node()->is_contiguous(memory_format);
401   }
402 
is_channels_last_1d_contiguous()403   bool is_channels_last_1d_contiguous() const {
404     return node()->is_channels_last_1d_contiguous();
405   }
406 };
407 
408 // An expression to construct the underlying variable node.
409 // Note: do not store any info here, since it is often possible to slice this
410 // object. For example: VarHandle x('x'); ExprHandle x2 = x;
411 class TORCH_API VarHandle : public ExprHandle {
412  public:
413   // Creates an empty VarHandle whose base Var is set to nullptr.
VarHandle()414   VarHandle() : ExprHandle() {}
415 
VarHandle(Dtype dtype)416   explicit VarHandle(Dtype dtype) : ExprHandle(Var::make(dtype)) {}
417 
VarHandle(const std::string & name_hint,Dtype dtype)418   VarHandle(const std::string& name_hint, Dtype dtype)
419       : ExprHandle(Var::make(name_hint, dtype)) {}
420 
VarHandle(VarPtr node)421   explicit VarHandle(VarPtr node) : ExprHandle(std::move(node)) {}
422 
node()423   VarPtr node() const {
424     return static_to<Var>(ExprHandle::node());
425   }
426   bool operator==(const VarHandle& other) const {
427     return this->node() == other.node();
428   }
429   bool operator!=(const VarHandle& other) const {
430     return !(*this == other);
431   }
432 
name_hint()433   const std::string& name_hint() const {
434     return this->node()->name_hint();
435   }
empty()436   bool empty() const {
437     return (this->node() == nullptr);
438   }
439 };
440 
// Out-of-line after Op is complete: downcast to the concrete node type and
// invoke the mutator overload for it.
template <class Op, class Base>
ExprPtr ExprNode<Op, Base>::accept_mutator(IRMutator* mutator) {
  return mutator->mutate(static_to<Op>(Base::getptr()));
}
445 
same_node(const ExprHandle & expr1,const ExprHandle & expr2)446 inline bool same_node(const ExprHandle& expr1, const ExprHandle& expr2) {
447   return expr1.AsNode<Expr>() == expr2.AsNode<Expr>();
448 }
449 
450 TORCH_API ExprHandle sin(const ExprHandle& v);
451 TORCH_API ExprHandle cos(const ExprHandle& v);
452 TORCH_API ExprHandle tan(const ExprHandle& v);
453 TORCH_API ExprHandle asin(const ExprHandle& v);
454 TORCH_API ExprHandle acos(const ExprHandle& v);
455 TORCH_API ExprHandle atan(const ExprHandle& v);
456 TORCH_API ExprHandle sinh(const ExprHandle& v);
457 TORCH_API ExprHandle cosh(const ExprHandle& v);
458 TORCH_API ExprHandle tanh(const ExprHandle& v);
459 TORCH_API ExprHandle sigmoid(const ExprHandle& v);
460 TORCH_API ExprHandle exp(const ExprHandle& v);
461 TORCH_API ExprHandle expm1(const ExprHandle& v);
462 TORCH_API ExprHandle abs(const ExprHandle& v);
463 TORCH_API ExprHandle log(const ExprHandle& v);
464 TORCH_API ExprHandle fast_tanh(const ExprHandle& v);
465 TORCH_API ExprHandle fast_sigmoid(const ExprHandle& v);
466 TORCH_API ExprHandle fast_log(const ExprHandle& v);
467 TORCH_API ExprHandle log_vml(const ExprHandle& v);
468 TORCH_API ExprHandle log2(const ExprHandle& v);
469 TORCH_API ExprHandle log10(const ExprHandle& v);
470 TORCH_API ExprHandle log1p(const ExprHandle& v);
471 TORCH_API ExprHandle erf(const ExprHandle& v);
472 TORCH_API ExprHandle erfc(const ExprHandle& v);
473 TORCH_API ExprHandle sqrt(const ExprHandle& v);
474 TORCH_API ExprHandle rsqrt(const ExprHandle& v);
475 TORCH_API ExprHandle ceil(const ExprHandle& v);
476 TORCH_API ExprHandle floor(const ExprHandle& v);
477 TORCH_API ExprHandle round(const ExprHandle& v);
478 TORCH_API ExprHandle trunc(const ExprHandle& v);
479 TORCH_API ExprHandle frac(const ExprHandle& v);
480 TORCH_API ExprHandle lgamma(const ExprHandle& v);
481 TORCH_API ExprHandle atan2(const ExprHandle& v1, const ExprHandle& v2);
482 TORCH_API ExprHandle pow(const ExprHandle& v1, const ExprHandle& v2);
483 TORCH_API ExprHandle fmod(const ExprHandle& v1, const ExprHandle& v2);
484 TORCH_API ExprHandle remainder(const ExprHandle& v1, const ExprHandle& v2);
485 TORCH_API ExprHandle isnan(const ExprHandle& v1);
486 TORCH_API ExprHandle Relu(const ExprHandle& v1);
487 
488 TORCH_API ExprHandle
489 ifThenElse(const ExprHandle& c, const ExprHandle& t, const ExprHandle& f);
490 
491 TORCH_API ExprHandle expr_to_vec(ExprHandle v, int lanes);
492 
493 } // namespace torch::jit::tensorexpr
494