1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/native/Pool.h>
5 #include <ATen/native/quantized/cpu/init_qnnpack.h>
6 #include <ATen/native/quantized/cpu/QnnpackUtils.h>
7 #include <ATen/native/quantized/cpu/QuantizedOps.h>
8
9 #ifndef AT_PER_OPERATOR_HEADERS
10 #include <ATen/Functions.h>
11 #include <ATen/NativeFunctions.h>
12 #else
13 #include <ATen/ops/_empty_affine_quantized.h>
14 #include <ATen/ops/avg_pool3d_native.h>
15 #endif
16
17 #include <vector>
18
19 namespace at {
20 namespace native {
21
// Dispatch-table entry for the channels-last (NDHWC) quantized 3-d
// average-pool kernel; backend-specific implementations register into
// this stub elsewhere.
DEFINE_DISPATCH(qavg_pool3d_nhwc_stub);
23
24 namespace {
25
get_kernel(IntArrayRef kernel_size)26 inline std::tuple<int, int, int> get_kernel(IntArrayRef kernel_size) {
27 TORCH_CHECK(
28 kernel_size.size() == 1 || kernel_size.size() == 3,
29 "avg_pool3d: kernel_size must either be a single int, or a tuple of three ints");
30 const int kD = safe_downcast<int, int64_t>(kernel_size[0]);
31 const int kH = kernel_size.size() == 1
32 ? kD
33 : safe_downcast<int, int64_t>(kernel_size[1]);
34 const int kW = kernel_size.size() == 1
35 ? kD
36 : safe_downcast<int, int64_t>(kernel_size[2]);
37 return std::make_tuple(kW, kH, kD);
38 }
39
get_stride(IntArrayRef stride,int kW,int kH,int kD)40 inline std::tuple<int, int, int> get_stride(IntArrayRef stride, int kW, int kH, int kD) {
41 TORCH_CHECK(
42 stride.empty() || stride.size() == 1 || stride.size() == 3,
43 "avg_pool3d: stride must either be omitted, a single int, or a tuple of three ints");
44 const int dD = stride.empty() ? kD : safe_downcast<int, int64_t>(stride[0]);
45 const int dH = stride.empty()
46 ? kH
47 : stride.size() == 1 ? dD : safe_downcast<int, int64_t>(stride[1]);
48 const int dW = stride.empty()
49 ? kW
50 : stride.size() == 1 ? dD : safe_downcast<int, int64_t>(stride[2]);
51 return std::make_tuple(dW, dH, dD);
52 }
53
get_padding(IntArrayRef padding)54 inline std::tuple<int, int, int> get_padding(IntArrayRef padding) {
55 TORCH_CHECK(
56 padding.size() == 1 || padding.size() == 3,
57 "avg_pool3d: padding must either be a single int, or a tuple of three ints");
58 const int padD = safe_downcast<int, int64_t>(padding[0]);
59 const int padH =
60 padding.size() == 1 ? padD : safe_downcast<int, int64_t>(padding[1]);
61 const int padW =
62 padding.size() == 1 ? padD : safe_downcast<int, int64_t>(padding[2]);
63 return std::make_tuple(padW, padH, padD);
64 }
65
get_output_shape(const Tensor & input_,int kW,int kH,int kD,int dW,int dH,int dD,int padW,int padH,int padD,bool ceil_mode)66 std::vector<int64_t> get_output_shape(
67 const Tensor& input_,
68 int kW,
69 int kH,
70 int kD,
71 int dW,
72 int dH,
73 int dD,
74 int padW,
75 int padH,
76 int padD,
77 bool ceil_mode) {
78 const int64_t nbatch = input_.ndimension() == 5 ? input_.size(-5) : 1;
79 const int64_t nInputPlane = input_.size(-4);
80 const int64_t inputDepth = input_.size(-3);
81 const int64_t inputHeight = input_.size(-2);
82 const int64_t inputWidth = input_.size(-1);
83 const int64_t outputDepth =
84 pooling_output_shape<int64_t>(inputDepth, kD, padD, dD, 1, ceil_mode);
85 const int64_t outputHeight =
86 pooling_output_shape<int64_t>(inputHeight, kH, padH, dH, 1, ceil_mode);
87 const int64_t outputWidth =
88 pooling_output_shape<int64_t>(inputWidth, kW, padW, dW, 1, ceil_mode);
89 if (input_.ndimension() == 4) {
90 return {nInputPlane, outputDepth, outputHeight, outputWidth};
91 }
92 return {nbatch, nInputPlane, outputDepth, outputHeight, outputWidth};
93 }
94
// Quantized 3-d average pooling on CPU.
//
// Normalizes kernel_size/stride/padding into per-dimension ints, converts
// the input to a channels-last-3d layout, allocates an affine-quantized
// output carrying the input's scale/zero-point, and runs the NHWC kernel
// via qavg_pool3d_nhwc_stub.
//
// NOTE: the template parameter scalar_t is not referenced in this body;
// the AT_DISPATCH_QINT_TYPES caller instantiates it per quantized dtype.
template <typename scalar_t>
Tensor q_avg_pool3d(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    bool ceil_mode,
    bool count_include_pad,
    std::optional<int64_t> divisor_override) {
  auto [kW, kH, kD] = get_kernel(kernel_size);
  auto [dW, dH, dD] = get_stride(stride, kW, kH, kD);
  auto [padW, padH, padD] = get_padding(padding);

  // Negative indexing handles both batched (5-d) and unbatched (4-d) inputs.
  const int64_t nbatch = input.ndimension() == 5 ? input.size(-5) : 1;
  const int64_t nInputPlane = input.size(-4);
  const int64_t inputDepth = input.size(-3);
  const int64_t inputHeight = input.size(-2);
  const int64_t inputWidth = input.size(-1);

  TORCH_CHECK(
      !divisor_override.has_value() || divisor_override.value() != 0,
      "divisor must be not zero");

  auto output_shape =
      get_output_shape(input, kW, kH, kD, dW, dH, dD, padW, padH, padD, ceil_mode);
  // Index from the back so both 4-d and 5-d output shapes are handled.
  const int64_t outputDepth = output_shape[output_shape.size() - 3];
  const int64_t outputHeight = output_shape[output_shape.size() - 2];
  const int64_t outputWidth = output_shape[output_shape.size() - 1];

  // The pooling kernel expects a channels-last-3d (NDHWC) memory layout.
  auto input_nhwc = input.contiguous(MemoryFormat::ChannelsLast3d);

  // Average pooling keeps the input's quantization parameters on the output.
  auto output = at::_empty_affine_quantized(
      output_shape,
      input_nhwc.options().memory_format(input_nhwc.suggest_memory_format()),
      input_nhwc.q_scale(),
      input_nhwc.q_zero_point(),
      std::nullopt);
  // fast path for channel last: qavg_pool3d_nhwc_stub
  qavg_pool3d_nhwc_stub(
      input_nhwc.device().type(),
      input_nhwc,
      output,
      nbatch,
      nInputPlane,
      inputWidth,
      inputHeight,
      inputDepth,
      outputWidth,
      outputHeight,
      outputDepth,
      kW,
      kH,
      kD,
      dW,
      dH,
      dD,
      padW,
      padH,
      padD,
      count_include_pad,
      divisor_override);
  return output;
}
158
159 } // namespace
160
// Native entry point for avg_pool3d on quantized CPU tensors. Dispatches on
// the quantized dtype and forwards every pooling option unchanged to the
// templated implementation.
Tensor avg_pool3d_quantized_cpu(
    const Tensor& input,
    IntArrayRef kernel_size,
    IntArrayRef stride,
    IntArrayRef padding,
    bool ceil_mode,
    bool count_include_pad,
    std::optional<int64_t> divisor_override) {
  Tensor result;
  AT_DISPATCH_QINT_TYPES(
      input.scalar_type(), "avg_pool3d_quantized_cpu", [&]() {
        result = q_avg_pool3d<scalar_t>(
            input,
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override);
      });
  return result;
}
182
183 } // namespace native
184 } // namespace at
185