xref: /aosp_15_r20/external/pytorch/aten/src/ATen/TensorOperators.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <c10/core/Scalar.h>
5 
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty_like.h>
#include <ATen/ops/result_type.h>
#include <ATen/ops/scalar_tensor.h>
#endif
11 
12 namespace at {
13 
14 #define AT_FORALL_BINARY_OPS(_)                                             \
15   _(+, x.add(y), y.add(x))                                                  \
16   _(*, x.mul(y), y.mul(x))                                                  \
17   _(-,                                                                      \
18     x.sub(y),                                                               \
19     ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).sub_(y))       \
20   _(/,                                                                      \
21     x.div(y),                                                               \
22     ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).div_(y))       \
23   _(%,                                                                      \
24     x.remainder(y),                                                         \
25     ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).remainder_(y)) \
26   _(&, x.bitwise_and(y), y.bitwise_and(x))                                  \
27   _(|, x.bitwise_or(y), y.bitwise_or(x))                                    \
28   _(^, x.bitwise_xor(y), y.bitwise_xor(x))                                  \
29   _(<, x.lt(y), y.gt(x))                                                    \
30   _(<=, x.le(y), y.ge(x))                                                   \
31   _(>, x.gt(y), y.lt(x))                                                    \
32   _(>=, x.ge(y), y.le(x))                                                   \
33   _(==, x.eq(y), y.eq(x))                                                   \
34   _(!=, x.ne(y), y.ne(x))
35 
36 #define DEFINE_OPERATOR(op, body, reverse_scalar_body)          \
37   inline Tensor operator op(const Tensor& x, const Tensor& y) { \
38     return body;                                                \
39   }                                                             \
40   inline Tensor operator op(const Tensor& x, const Scalar& y) { \
41     return body;                                                \
42   }                                                             \
43   inline Tensor operator op(const Scalar& x, const Tensor& y) { \
44     return reverse_scalar_body;                                 \
45   }
46 
47 AT_FORALL_BINARY_OPS(DEFINE_OPERATOR)
48 #undef DEFINE_OPERATOR
49 #undef AT_FORALL_BINARY_OPS
50 
51 } // namespace at
52