1 // Ternary and higher-order pointwise operations
2 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
3 #include <ATen/native/PointwiseOps.h>
4
5 #include <ATen/core/Tensor.h>
6 #include <ATen/TensorMeta.h>
7
8 #ifndef AT_PER_OPERATOR_HEADERS
9 #include <ATen/NativeFunctions.h>
10 #else
11 #include <ATen/ops/addcdiv_native.h>
12 #include <ATen/ops/addcmul_native.h>
13 #endif
14
15 namespace at::meta {
16
TORCH_META_FUNC(addcmul)17 TORCH_META_FUNC(addcmul)
18 (const Tensor& self,
19 const Tensor& tensor1,
20 const Tensor& tensor2,
21 const Scalar& value) {
22 build_ternary_op(maybe_get_output(), self, tensor1, tensor2);
23 }
24
TORCH_META_FUNC(addcdiv)25 TORCH_META_FUNC(addcdiv)
26 (const Tensor& self,
27 const Tensor& tensor1,
28 const Tensor& tensor2,
29 const Scalar& value) {
30 if (isIntegralType(tensor1.scalar_type(), /*includeBool=*/true) &&
31 isIntegralType(tensor2.scalar_type(), /*includeBool=*/true)) {
32 TORCH_CHECK(
33 false,
34 "Integer division with addcdiv is no longer supported, and in a future ",
35 "release addcdiv will perform a true division of tensor1 and tensor2. ",
36 "The historic addcdiv behavior can be implemented as ",
37 "(input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) ",
38 "for integer inputs and as ",
39 "(input + value * tensor1 / tensor2) for float inputs. ",
40 "The future addcdiv behavior is just the latter implementation: ",
41 "(input + value * tensor1 / tensor2), for all dtypes.");
42 }
43 build_ternary_op(maybe_get_output(), self, tensor1, tensor2);
44 }
45
46 } // namespace at::meta
47 namespace at::native {
48
TORCH_IMPL_FUNC(addcmul_out)49 TORCH_IMPL_FUNC(addcmul_out)
50 (const Tensor& self,
51 const Tensor& tensor1,
52 const Tensor& tensor2,
53 const Scalar& value,
54 const Tensor& result) {
55 addcmul_stub(device_type(), *this, value);
56 }
57
TORCH_IMPL_FUNC(addcdiv_out)58 TORCH_IMPL_FUNC(addcdiv_out)
59 (const Tensor& self,
60 const Tensor& tensor1,
61 const Tensor& tensor2,
62 const Scalar& value,
63 const Tensor& result) {
64 addcdiv_stub(device_type(), *this, value);
65 }
66
67 DEFINE_DISPATCH(addcmul_stub);
68 DEFINE_DISPATCH(addcdiv_stub);
69
70 } // namespace at::native
71