// aten/src/ATen/native/quantized/cpu/AveragePool3d.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/native/Pool.h>
#include <ATen/native/quantized/cpu/init_qnnpack.h>
#include <ATen/native/quantized/cpu/QnnpackUtils.h>
#include <ATen/native/quantized/cpu/QuantizedOps.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/avg_pool3d_native.h>
#endif

#include <vector>

namespace at {
namespace native {

DEFINE_DISPATCH(qavg_pool3d_nhwc_stub);

namespace {

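// Unpack kernel_size into (kW, kH, kD). A single int is broadcast to all
// three spatial dimensions; otherwise the three ints are taken as (D, H, W).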
inline std::tuple<int, int, int> get_kernel(IntArrayRef kernel_size) {
  TORCH_CHECK(
      kernel_size.size() == 1 || kernel_size.size() == 3,
      "avg_pool3d: kernel_size must either be a single int, or a tuple of three ints");
  const int kD = safe_downcast<int, int64_t>(kernel_size[0]);
  const int kH = kernel_size.size() == 1
      ? kD
      : safe_downcast<int, int64_t>(kernel_size[1]);
  const int kW = kernel_size.size() == 1
      ? kD
      : safe_downcast<int, int64_t>(kernel_size[2]);
  return std::make_tuple(kW, kH, kD);
}

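// Unpack stride into (dW, dH, dD). An omitted stride defaults each dimension
// to the kernel size (non-overlapping windows); a single int is broadcast to
// all three dimensions.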
inline std::tuple<int, int, int> get_stride(IntArrayRef stride, int kW, int kH, int kD) {
  TORCH_CHECK(
      stride.empty() || stride.size() == 1 || stride.size() == 3,
      "avg_pool3d: stride must either be omitted, a single int, or a tuple of three ints");
  const int dD = stride.empty() ? kD : safe_downcast<int, int64_t>(stride[0]);
  const int dH = stride.empty()
      ? kH
      : stride.size() == 1 ? dD : safe_downcast<int, int64_t>(stride[1]);
  const int dW = stride.empty()
      ? kW
      : stride.size() == 1 ? dD : safe_downcast<int, int64_t>(stride[2]);
  return std::make_tuple(dW, dH, dD);
}

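// Unpack padding into (padW, padH, padD); a single int is broadcast to all
// three spatial dimensions.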
inline std::tuple<int, int, int> get_padding(IntArrayRef padding) {
  TORCH_CHECK(
      padding.size() == 1 || padding.size() == 3,
      "avg_pool3d: padding must either be a single int, or a tuple of three ints");
  const int padD = safe_downcast<int, int64_t>(padding[0]);
  const int padH =
      padding.size() == 1 ? padD : safe_downcast<int, int64_t>(padding[1]);
  const int padW =
      padding.size() == 1 ? padD : safe_downcast<int, int64_t>(padding[2]);
  return std::make_tuple(padW, padH, padD);
}

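// Compute the output sizes for a 4-d (C, D, H, W) or 5-d (N, C, D, H, W)
// input, using the shared pooling_output_shape arithmetic with dilation
// fixed at 1.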
std::vector<int64_t> get_output_shape(
    const Tensor& input_,
    int kW,
    int kH,
    int kD,
    int dW,
    int dH,
    int dD,
    int padW,
    int padH,
    int padD,
    bool ceil_mode) {
  const int64_t nbatch = input_.ndimension() == 5 ? input_.size(-5) : 1;
  const int64_t nInputPlane = input_.size(-4);
  const int64_t inputDepth = input_.size(-3);
  const int64_t inputHeight = input_.size(-2);
  const int64_t inputWidth = input_.size(-1);
  const int64_t outputDepth =
      pooling_output_shape<int64_t>(inputDepth, kD, padD, dD, 1, ceil_mode);
  const int64_t outputHeight =
      pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, 1, ceil_mode);
  const int64_t outputWidth =
      pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
  if (input_.ndimension() == 4) {
    return {nInputPlane, outputDepth, outputHeight, outputWidth};
  }
  return {nbatch, nInputPlane, outputDepth, outputHeight, outputWidth};
}

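// Quantized 3-d average pooling for a single qint type: validate and unpack
// the pooling parameters, pool in channels-last layout, and write the result
// into a new affine-quantized tensor that reuses the input's quantization
// parameters.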
template <typename scalar_t>
Tensor q_avg_pool3d(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    bool ceil_mode,
    bool count_include_pad,
    std::optional<int64_t> divisor_override) {
  auto [kW, kH, kD] = get_kernel(kernel_size);
  auto [dW, dH, dD] = get_stride(stride, kW, kH, kD);
  auto [padW, padH, padD] = get_padding(padding);

  const int64_t nbatch = input.ndimension() == 5 ? input.size(-5) : 1;
  const int64_t nInputPlane = input.size(-4);
  const int64_t inputDepth = input.size(-3);
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);

  TORCH_CHECK(
      !divisor_override.has_value() || divisor_override.value() != 0,
      "divisor must not be zero");

  auto output_shape =
      get_output_shape(input, kW, kH, kD, dW, dH, dD, padW, padH, padD, ceil_mode);
  const int64_t outputDepth = output_shape[output_shape.size() - 3];
  const int64_t outputHeight = output_shape[output_shape.size() - 2];
  const int64_t outputWidth = output_shape[output_shape.size() - 1];

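  // Only a channels-last (NDHWC) kernel exists for this op, so make the
  // input contiguous in ChannelsLast3d (copying if necessary).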
  auto input_nhwc = input.contiguous(MemoryFormat::ChannelsLast3d);

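  // Allocate the output as an affine-quantized tensor that reuses the
  // input's scale and zero point, in the memory format suggested by
  // input_nhwc (channels-last).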
  auto output = at::_empty_affine_quantized(
      output_shape,
      input_nhwc.options().memory_format(input_nhwc.suggest_memory_format()),
      input_nhwc.q_scale(),
      input_nhwc.q_zero_point(),
      std::nullopt);
  // Fast path for channels-last input: qavg_pool3d_nhwc_stub.
  qavg_pool3d_nhwc_stub(
      input_nhwc.device().type(),
      input_nhwc,
      output,
      nbatch,
      nInputPlane,
      inputWidth,
      inputHeight,
      inputDepth,
      outputWidth,
      outputHeight,
      outputDepth,
      kW,
      kH,
      kD,
      dW,
      dH,
      dD,
      padW,
      padH,
      padD,
      count_include_pad,
      divisor_override);
  return output;
}

} // namespace

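// Entry point for quantized CPU tensors: dispatch on the qint dtype
// (qint8/quint8/qint32) and run the pooling kernel for that type.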
Tensor avg_pool3d_quantized_cpu(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    bool ceil_mode,
    bool count_include_pad,
    std::optional<int64_t> divisor_override) {
  Tensor output;
  AT_DISPATCH_QINT_TYPES(input.scalar_type(), "avg_pool3d_quantized_cpu", [&]() {
    output = q_avg_pool3d<scalar_t>(
        input,
        kernel_size,
        stride,
        padding,
        ceil_mode,
        count_include_pad,
        divisor_override);
  });
  return output;
}

} // namespace native
} // namespace at