xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/PointwiseOps.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Ternary and higher-order pointwise operations
2 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
3 #include <ATen/native/PointwiseOps.h>
4 
5 #include <ATen/core/Tensor.h>
6 #include <ATen/TensorMeta.h>
7 
8 #ifndef AT_PER_OPERATOR_HEADERS
9 #include <ATen/NativeFunctions.h>
10 #else
11 #include <ATen/ops/addcdiv_native.h>
12 #include <ATen/ops/addcmul_native.h>
13 #endif
14 
15 namespace at::meta {
16 
// Meta (shape/dtype inference) for addcmul. Only the three tensor inputs
// participate in broadcasting/type promotion; `value` is ignored here and
// is consumed later by the device kernel (see addcmul_out below).
TORCH_META_FUNC(addcmul)
(const Tensor& self,
 const Tensor& tensor1,
 const Tensor& tensor2,
 const Scalar& value) {
  build_ternary_op(maybe_get_output(), self, tensor1, tensor2);
}
24 
// Meta (shape/dtype inference) for addcdiv. Rejects the case where both
// division operands are integral (incl. bool), since integer division
// semantics for addcdiv were removed; otherwise sets up the ternary op.
// `value` is ignored here and consumed by the device kernel.
TORCH_META_FUNC(addcdiv)
(const Tensor& self,
 const Tensor& tensor1,
 const Tensor& tensor2,
 const Scalar& value) {
  // isIntegralType is a pure predicate, so evaluating both up front is
  // observably identical to short-circuiting.
  const bool tensor1_integral =
      isIntegralType(tensor1.scalar_type(), /*includeBool=*/true);
  const bool tensor2_integral =
      isIntegralType(tensor2.scalar_type(), /*includeBool=*/true);
  if (tensor1_integral && tensor2_integral) {
    TORCH_CHECK(
        false,
        "Integer division with addcdiv is no longer supported, and in a future  ",
        "release addcdiv will perform a true division of tensor1 and tensor2. ",
        "The historic addcdiv behavior can be implemented as ",
        "(input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) ",
        "for integer inputs and as ",
        "(input + value * tensor1 / tensor2) for float inputs. ",
        "The future addcdiv behavior is just the latter implementation: ",
        "(input + value * tensor1 / tensor2), for all dtypes.");
  }
  build_ternary_op(maybe_get_output(), self, tensor1, tensor2);
}
45 
46 } // namespace at::meta
47 namespace at::native {
48 
// Kernel entry for addcmul: the tensor arguments were already folded into
// `*this` by the meta function's build_ternary_op, so this just forwards
// the structured-op state plus the scalar `value` to the per-device stub.
TORCH_IMPL_FUNC(addcmul_out)
(const Tensor& self,
 const Tensor& tensor1,
 const Tensor& tensor2,
 const Scalar& value,
 const Tensor& result) {
  addcmul_stub(device_type(), *this, value);
}
57 
// Kernel entry for addcdiv: mirrors addcmul_out — the meta function has
// already built the op over (self, tensor1, tensor2), so only the
// dispatch to the device-specific stub (with `value`) remains.
TORCH_IMPL_FUNC(addcdiv_out)
(const Tensor& self,
 const Tensor& tensor1,
 const Tensor& tensor2,
 const Scalar& value,
 const Tensor& result) {
  addcdiv_stub(device_type(), *this, value);
}
66 
67 DEFINE_DISPATCH(addcmul_stub);
68 DEFINE_DISPATCH(addcdiv_stub);
69 
70 } // namespace at::native
71