1 /**
2 * This file implements the core classes for Tensor Expressions.
3 *
4 * The structure of the expressions is inspired by Halide/TVM IR.
5 */
6 #pragma once
7
8 #include <c10/core/MemoryFormat.h>
9 #include <torch/csrc/jit/tensorexpr/fwd_decls.h>
10 #include <torch/csrc/jit/tensorexpr/ir_mutator.h>
11 #include <torch/csrc/jit/tensorexpr/ir_visitor.h>
12 #include <torch/csrc/jit/tensorexpr/types.h>
13 #include <optional>
14
15 #include <utility>
16
17 namespace torch::jit::tensorexpr {
18
// Tag describing the kind of an expression node. Every Expr stores one of
// these and exposes it through Expr::expr_type().
//
// The enumerators are numbered consecutively; only the first value is spelled
// out, the rest follow implicitly in declaration order.
enum IRNodeType {
  // Tag passed by Var to its base class.
  kPrimitive = 0,
  kAdd,
  kSub,
  kMul,
  kDiv,
  kMod,
  kMax,
  kMin,
  kAnd,
  kOr,
  kLshift,
  kRshift,
  kXor,
  kCompareSelect,
  kCast,
  kBitCast,
  // Default tag used by the Expr constructor when none is supplied.
  kOther,
};
38
39 // The common base between all expression node.
40 class TORCH_API Expr : public std::enable_shared_from_this<Expr> {
41 public:
42 explicit Expr(Dtype dtype, IRNodeType expr_type = kOther)
dtype_(dtype)43 : dtype_(dtype), expr_type_(expr_type) {}
44 virtual ~Expr() = default;
dtype()45 Dtype dtype() const {
46 return dtype_;
47 }
48 virtual void accept(IRVisitor* visitor) = 0;
49 virtual ExprPtr accept_mutator(IRMutator* mutator) = 0;
50
expr_type()51 IRNodeType expr_type() const {
52 return expr_type_;
53 }
54 // Is this a fixed (constant) immediate value.
isConstant()55 virtual bool isConstant() const {
56 return false;
57 }
58
set_dtype(Dtype dtype)59 void set_dtype(Dtype dtype) {
60 dtype_ = dtype;
61 }
62
63 /*
64 * Make a deep copy of the given expression.
65 *
66 * All sub-expressions inside the given expressions are also cloned. Note
67 * that the variables are not deep-copied since they are immutable.
68 */
69 static ExprPtr clone(const ExprPtr& s);
70
71 protected:
getptr()72 std::shared_ptr<Expr> getptr() {
73 return shared_from_this();
74 }
75
76 private:
77 Dtype dtype_;
78 IRNodeType expr_type_;
79 };
80
81 // A CRTP pattern to accept visitors for children class,
82 // and dispatch back to the children.
83 template <class Op, class Base = Expr>
84 class ExprNode : public Base {
85 public:
86 using ExprNodeBase = ExprNode<Op>;
accept(IRVisitor * visitor)87 void accept(IRVisitor* visitor) override {
88 visitor->visit(static_to<Op>(Base::getptr()));
89 }
90 ExprPtr accept_mutator(IRMutator* mutator) override;
91 // pass the constructor to the base class
92 using Base::Base;
93 };
94
95 // A wrapper object to the underlying ExprNode.
96 // Also serves the primary way to build and operate on other expressions.
97 class TORCH_API ExprHandle {
98 public:
99 ExprHandle() = default;
ExprHandle(ExprPtr node)100 explicit ExprHandle(ExprPtr node) : base_expr_node_(std::move(node)) {}
101
node()102 ExprPtr node() {
103 return base_expr_node_;
104 }
105
node()106 ExprPtr node() const {
107 return base_expr_node_;
108 }
109
empty()110 bool empty() const {
111 return base_expr_node_ == nullptr;
112 }
113
114 #define IMM_EXPR_DECLARE(Type, Name) ExprHandle(Type v);
115 AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_EXPR_DECLARE);
116 #undef IMM_EXPR_DECLARE
117
118 template <class Op>
AsNode()119 NodePtr<Op> AsNode() {
120 return to<Op>(this->node());
121 }
122
123 template <class Op>
AsNode()124 NodePtr<Op> AsNode() const {
125 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
126 return const_cast<ExprHandle*>(this)->AsNode<Op>();
127 }
128
dtype()129 Dtype dtype() const {
130 return node()->dtype();
131 }
132
133 // Handling the math operators.
134 ExprHandle operator+(const ExprHandle& other) const;
135 ExprHandle operator-(const ExprHandle& other) const;
136 ExprHandle operator*(const ExprHandle& other) const;
137 ExprHandle operator/(const ExprHandle& other) const;
138 ExprHandle operator%(const ExprHandle& other) const;
139 ExprHandle operator==(const ExprHandle& other) const;
140 ExprHandle operator!=(const ExprHandle& other) const;
141 ExprHandle operator>(const ExprHandle& other) const;
142 ExprHandle operator>=(const ExprHandle& other) const;
143 ExprHandle operator<(const ExprHandle& other) const;
144 ExprHandle operator<=(const ExprHandle& other) const;
145 ExprHandle operator&(const ExprHandle& other) const;
146 ExprHandle operator|(const ExprHandle& other) const;
147 ExprHandle operator&&(const ExprHandle& other) const;
148 ExprHandle operator||(const ExprHandle& other) const;
149 ExprHandle operator^(const ExprHandle& other) const;
150 ExprHandle operator<<(const ExprHandle& other) const;
151 ExprHandle operator>>(const ExprHandle& other) const;
152
153 private:
154 ExprPtr base_expr_node_ = nullptr;
155 };
156
157 // The underlying representation node to a Var.
158 // Currently, each Var object represents a unique variable, even though the
159 // names might be the same. We should consider add a unique_name as well.
160 class TORCH_API Var : public ExprNode<Var> {
161 public:
make(const std::string & name_hint,Dtype dtype)162 static ExprHandle make(const std::string& name_hint, Dtype dtype) {
163 return ExprHandle(alloc<Var>(name_hint, dtype));
164 }
make(Dtype dtype)165 static ExprHandle make(Dtype dtype) {
166 return ExprHandle(alloc<Var>("", dtype));
167 }
168
169 // TODO: unique_name
name_hint()170 const std::string& name_hint() const {
171 return name_hint_;
172 }
173
set_name_hint(const std::string & name)174 void set_name_hint(const std::string& name) {
175 name_hint_ = name;
176 }
177
set_name_hint(std::string && name)178 void set_name_hint(std::string&& name) {
179 name_hint_ = std::move(name);
180 }
181
Var(std::string name_hint,Dtype dtype)182 Var(std::string name_hint, Dtype dtype)
183 : ExprNodeBase(dtype, kPrimitive), name_hint_(std::move(name_hint)) {}
184
185 private:
186 std::string name_hint_;
187 };
188
189 TORCH_API std::vector<ExprPtr> make_contiguous_strides(
190 const std::vector<ExprHandle>& dims);
191 TORCH_API std::vector<ExprPtr> make_channels_last_strides(
192 const std::vector<ExprHandle>& dims);
193
194 class TORCH_API Buf : public ExprNode<Buf> {
195 public:
196 static BufHandle make(const std::vector<ExprHandle>& dims, Dtype dtype);
197
198 static BufHandle make(
199 const std::string& name_hint,
200 const std::vector<ExprHandle>& dims,
201 const std::vector<ExprHandle>& strides,
202 Dtype dtype);
203
204 static BufHandle make(
205 const std::string& name_hint,
206 const std::vector<ExprHandle>& dims,
207 Dtype dtype,
208 std::optional<ExprHandle> initializer = std::nullopt,
209 const std::optional<std::vector<ExprHandle>>& strides = std::nullopt,
210 std::optional<ExprHandle> qscale = std::nullopt,
211 std::optional<ExprHandle> qzero = std::nullopt);
212
213 // TODO: unique_name
base_handle()214 VarPtr base_handle() const {
215 return base_handle_;
216 }
set_base_handle(VarPtr base_handle)217 void set_base_handle(VarPtr base_handle) {
218 base_handle_ = std::move(base_handle);
219 }
220
name_hint()221 const std::string& name_hint() const {
222 return base_handle_->name_hint();
223 }
set_name_hint(const std::string & name_hint)224 void set_name_hint(const std::string& name_hint) {
225 base_handle_->set_name_hint(name_hint);
226 }
227
228 Buf(const std::string& name_hint,
229 const std::vector<ExprPtr>& dims,
230 Dtype dtype,
231 ExprPtr initializer = nullptr,
232 std::optional<std::vector<ExprPtr>> strides = std::nullopt,
233 ExprPtr qscale = nullptr,
234 ExprPtr qzero = nullptr)
Buf(alloc<Var> (name_hint,kHandle),dims,dtype,std::move (initializer),std::move (strides),std::move (qscale),std::move (qzero))235 : Buf(alloc<Var>(name_hint, kHandle),
236 dims,
237 dtype,
238 std::move(initializer),
239 std::move(strides),
240 std::move(qscale),
241 std::move(qzero)) {}
242
243 Buf(const VarPtr& var,
244 std::vector<ExprPtr> dims,
245 Dtype dtype,
246 ExprPtr initializer = nullptr,
247 std::optional<std::vector<ExprPtr>> strides = std::nullopt,
248 ExprPtr qscale = nullptr,
249 ExprPtr qzero = nullptr);
250
ndim()251 size_t ndim() const {
252 return dims_.size();
253 }
dim(size_t index)254 ExprPtr dim(size_t index) const {
255 if (index >= ndim()) {
256 throw out_of_range_index();
257 }
258 return dims_[index];
259 }
dims()260 std::vector<ExprPtr> dims() const {
261 return dims_;
262 }
set_dims(std::vector<ExprPtr> dims)263 void set_dims(std::vector<ExprPtr> dims) {
264 dims_ = std::move(dims);
265 }
266
strides()267 std::vector<ExprPtr> strides() const {
268 return strides_;
269 }
270
set_strides(std::vector<ExprPtr> strides)271 void set_strides(std::vector<ExprPtr> strides) {
272 strides_ = std::move(strides);
273 }
274
initializer()275 ExprPtr initializer() const {
276 return initializer_;
277 };
278
qzero()279 ExprPtr qzero() const {
280 return qzero_;
281 }
282
qscale()283 ExprPtr qscale() const {
284 return qscale_;
285 }
286
set_qzero(ExprPtr qzero)287 void set_qzero(ExprPtr qzero) {
288 qzero_ = std::move(qzero);
289 }
290
set_qscale(ExprPtr qscale)291 void set_qscale(ExprPtr qscale) {
292 qscale_ = std::move(qscale);
293 }
294
hasConstantDims()295 bool hasConstantDims() const {
296 for (const auto& d : dims_) {
297 if (!d->isConstant()) {
298 return false;
299 }
300 }
301 return true;
302 }
303
304 bool is_contiguous(
305 at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const;
306
307 // The channels-last 1d can benefit the performance of some operators like
308 // conv1d. But the MemoryFormat enum has not covered this layout yet. Hence,
309 // we abstract a dedicated function to check channels-last 1d contiguous.
310 //
311 // Channels-last 1d:
312 // dims: n c l
313 // strides(nlc): c*l 1 c
is_channels_last_1d_contiguous()314 bool is_channels_last_1d_contiguous() const {
315 if (dims_.size() != 3) {
316 return false;
317 }
318 return is_stride_one(1) && is_cont_with(2, 1) && is_cont_with(0, 2);
319 }
320
321 private:
322 bool is_cont_with(int cur_dim, int adjacent_dim) const;
323 bool is_stride_one(int cur_dim) const;
324
325 VarPtr base_handle_;
326 std::vector<ExprPtr> dims_;
327 std::vector<ExprPtr> strides_;
328 ExprPtr initializer_;
329 // qscale_ and qzero_ are used only for quantized dtypes Bufs: kQUInt8, kQInt8
330 ExprPtr qscale_;
331 ExprPtr qzero_;
332 };
333
334 class TORCH_API BufHandle : public ExprHandle {
335 public:
BufHandle(const std::string & name_hint,const std::vector<ExprHandle> & dims,Dtype dtype)336 BufHandle(
337 const std::string& name_hint,
338 const std::vector<ExprHandle>& dims,
339 Dtype dtype)
340 : ExprHandle(Buf::make(name_hint, dims, dtype)) {}
341
BufHandle(const std::string & name_hint,const std::vector<ExprHandle> & dims,const std::vector<ExprHandle> & strides,Dtype dtype)342 BufHandle(
343 const std::string& name_hint,
344 const std::vector<ExprHandle>& dims,
345 const std::vector<ExprHandle>& strides,
346 Dtype dtype)
347 : ExprHandle(Buf::make(name_hint, dims, strides, dtype)) {}
348
BufHandle(const std::vector<ExprHandle> & dims,Dtype dtype)349 BufHandle(const std::vector<ExprHandle>& dims, Dtype dtype)
350 : ExprHandle(Buf::make("_", dims, dtype)) {}
351
BufHandle(Dtype dtype)352 explicit BufHandle(Dtype dtype) : ExprHandle(Buf::make("_", {}, dtype)) {}
353
BufHandle(BufPtr node)354 explicit BufHandle(BufPtr node) : ExprHandle(std::move(node)) {}
node()355 BufPtr node() const {
356 return static_to<Buf>(ExprHandle::node());
357 }
node()358 BufPtr node() {
359 return static_to<Buf>(ExprHandle::node());
360 }
361
362 template <typename... Ts>
363 inline ExprHandle load(const Ts&... ts) const;
364
365 template <typename T>
366 inline ExprHandle load(const std::vector<T>& args) const;
367
368 inline ExprHandle load(const std::vector<ExprHandle>& args) const;
369
370 StorePtr store(const std::vector<ExprHandle>& args, const ExprHandle& val)
371 const;
372
373 bool operator==(const BufHandle& other) const {
374 return this->node() == other.node();
375 }
376 bool operator!=(const BufHandle& other) const {
377 return !(*this == other);
378 }
379
name_hint()380 const std::string& name_hint() const {
381 return this->node()->name_hint();
382 }
383
empty()384 bool empty() const {
385 return (this->node() == nullptr);
386 }
387
ndim()388 size_t ndim() const {
389 return node()->ndim();
390 }
391
392 std::vector<ExprHandle> dims() const;
393
dim(size_t index)394 ExprHandle dim(size_t index) const {
395 return ExprHandle(node()->dim(index));
396 }
397
398 bool is_contiguous(
399 at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const {
400 return node()->is_contiguous(memory_format);
401 }
402
is_channels_last_1d_contiguous()403 bool is_channels_last_1d_contiguous() const {
404 return node()->is_channels_last_1d_contiguous();
405 }
406 };
407
408 // An expression to construct the underlying variable node.
409 // Note: do not store any info here, since it is often possible to slice this
410 // object. For example: VarHandle x('x'); ExprHandle x2 = x;
411 class TORCH_API VarHandle : public ExprHandle {
412 public:
413 // Creates an empty VarHandle whose base Var is set to nullptr.
VarHandle()414 VarHandle() : ExprHandle() {}
415
VarHandle(Dtype dtype)416 explicit VarHandle(Dtype dtype) : ExprHandle(Var::make(dtype)) {}
417
VarHandle(const std::string & name_hint,Dtype dtype)418 VarHandle(const std::string& name_hint, Dtype dtype)
419 : ExprHandle(Var::make(name_hint, dtype)) {}
420
VarHandle(VarPtr node)421 explicit VarHandle(VarPtr node) : ExprHandle(std::move(node)) {}
422
node()423 VarPtr node() const {
424 return static_to<Var>(ExprHandle::node());
425 }
426 bool operator==(const VarHandle& other) const {
427 return this->node() == other.node();
428 }
429 bool operator!=(const VarHandle& other) const {
430 return !(*this == other);
431 }
432
name_hint()433 const std::string& name_hint() const {
434 return this->node()->name_hint();
435 }
empty()436 bool empty() const {
437 return (this->node() == nullptr);
438 }
439 };
440
441 template <class Op, class Base>
accept_mutator(IRMutator * mutator)442 ExprPtr ExprNode<Op, Base>::accept_mutator(IRMutator* mutator) {
443 return mutator->mutate(static_to<Op>(Base::getptr()));
444 }
445
same_node(const ExprHandle & expr1,const ExprHandle & expr2)446 inline bool same_node(const ExprHandle& expr1, const ExprHandle& expr2) {
447 return expr1.AsNode<Expr>() == expr2.AsNode<Expr>();
448 }
449
450 TORCH_API ExprHandle sin(const ExprHandle& v);
451 TORCH_API ExprHandle cos(const ExprHandle& v);
452 TORCH_API ExprHandle tan(const ExprHandle& v);
453 TORCH_API ExprHandle asin(const ExprHandle& v);
454 TORCH_API ExprHandle acos(const ExprHandle& v);
455 TORCH_API ExprHandle atan(const ExprHandle& v);
456 TORCH_API ExprHandle sinh(const ExprHandle& v);
457 TORCH_API ExprHandle cosh(const ExprHandle& v);
458 TORCH_API ExprHandle tanh(const ExprHandle& v);
459 TORCH_API ExprHandle sigmoid(const ExprHandle& v);
460 TORCH_API ExprHandle exp(const ExprHandle& v);
461 TORCH_API ExprHandle expm1(const ExprHandle& v);
462 TORCH_API ExprHandle abs(const ExprHandle& v);
463 TORCH_API ExprHandle log(const ExprHandle& v);
464 TORCH_API ExprHandle fast_tanh(const ExprHandle& v);
465 TORCH_API ExprHandle fast_sigmoid(const ExprHandle& v);
466 TORCH_API ExprHandle fast_log(const ExprHandle& v);
467 TORCH_API ExprHandle log_vml(const ExprHandle& v);
468 TORCH_API ExprHandle log2(const ExprHandle& v);
469 TORCH_API ExprHandle log10(const ExprHandle& v);
470 TORCH_API ExprHandle log1p(const ExprHandle& v);
471 TORCH_API ExprHandle erf(const ExprHandle& v);
472 TORCH_API ExprHandle erfc(const ExprHandle& v);
473 TORCH_API ExprHandle sqrt(const ExprHandle& v);
474 TORCH_API ExprHandle rsqrt(const ExprHandle& v);
475 TORCH_API ExprHandle ceil(const ExprHandle& v);
476 TORCH_API ExprHandle floor(const ExprHandle& v);
477 TORCH_API ExprHandle round(const ExprHandle& v);
478 TORCH_API ExprHandle trunc(const ExprHandle& v);
479 TORCH_API ExprHandle frac(const ExprHandle& v);
480 TORCH_API ExprHandle lgamma(const ExprHandle& v);
481 TORCH_API ExprHandle atan2(const ExprHandle& v1, const ExprHandle& v2);
482 TORCH_API ExprHandle pow(const ExprHandle& v1, const ExprHandle& v2);
483 TORCH_API ExprHandle fmod(const ExprHandle& v1, const ExprHandle& v2);
484 TORCH_API ExprHandle remainder(const ExprHandle& v1, const ExprHandle& v2);
485 TORCH_API ExprHandle isnan(const ExprHandle& v1);
486 TORCH_API ExprHandle Relu(const ExprHandle& v1);
487
488 TORCH_API ExprHandle
489 ifThenElse(const ExprHandle& c, const ExprHandle& t, const ExprHandle& f);
490
491 TORCH_API ExprHandle expr_to_vec(ExprHandle v, int lanes);
492
493 } // namespace torch::jit::tensorexpr
494