1 #pragma once 2 3 #include <ATen/core/Tensor.h> 4 #include <c10/core/Scalar.h> 5 6 #ifndef AT_PER_OPERATOR_HEADERS 7 #include <ATen/Functions.h> 8 #else 9 #include <ATen/ops/empty_like.h> 10 #endif 11 12 namespace at { 13 14 #define AT_FORALL_BINARY_OPS(_) \ 15 _(+, x.add(y), y.add(x)) \ 16 _(*, x.mul(y), y.mul(x)) \ 17 _(-, \ 18 x.sub(y), \ 19 ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).sub_(y)) \ 20 _(/, \ 21 x.div(y), \ 22 ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).div_(y)) \ 23 _(%, \ 24 x.remainder(y), \ 25 ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).remainder_(y)) \ 26 _(&, x.bitwise_and(y), y.bitwise_and(x)) \ 27 _(|, x.bitwise_or(y), y.bitwise_or(x)) \ 28 _(^, x.bitwise_xor(y), y.bitwise_xor(x)) \ 29 _(<, x.lt(y), y.gt(x)) \ 30 _(<=, x.le(y), y.ge(x)) \ 31 _(>, x.gt(y), y.lt(x)) \ 32 _(>=, x.ge(y), y.le(x)) \ 33 _(==, x.eq(y), y.eq(x)) \ 34 _(!=, x.ne(y), y.ne(x)) 35 36 #define DEFINE_OPERATOR(op, body, reverse_scalar_body) \ 37 inline Tensor operator op(const Tensor& x, const Tensor& y) { \ 38 return body; \ 39 } \ 40 inline Tensor operator op(const Tensor& x, const Scalar& y) { \ 41 return body; \ 42 } \ 43 inline Tensor operator op(const Scalar& x, const Tensor& y) { \ 44 return reverse_scalar_body; \ 45 } 46 47 AT_FORALL_BINARY_OPS(DEFINE_OPERATOR) 48 #undef DEFINE_OPERATOR 49 #undef AT_FORALL_BINARY_OPS 50 51 } // namespace at 52