#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/UpSample.cuh>

#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAContext.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/upsample_nearest3d.h>
#include <ATen/ops/upsample_nearest3d_native.h>
#include <ATen/ops/upsample_nearest3d_backward.h>
#include <ATen/ops/upsample_nearest3d_backward_native.h>
#include <ATen/ops/_upsample_nearest_exact3d.h>
#include <ATen/ops/_upsample_nearest_exact3d_native.h>
#include <ATen/ops/_upsample_nearest_exact3d_backward.h>
#include <ATen/ops/_upsample_nearest_exact3d_backward_native.h>
#endif

namespace at::native {
namespace {

#define MAX_THREADS 512

// Define a typedef to dispatch to nearest_neighbor_compute_source_index or
// nearest_neighbor_exact_compute_source_index
typedef int (*nn_compute_source_index_fn_t)(const float, int, int);

// Define a typedef to dispatch to nearest_neighbor_bw_compute_source_index or
// nearest_neighbor_exact_bw_compute_source_index
typedef int (*nn_bw_compute_source_index_fn_t)(const float, int, int);
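// Rough behavior of these helpers (defined in ATen/native/cuda/UpSample.cuh):
// the forward variants map one destination index to one source index,
// approximately
//   src_index = min(static_cast<int>(dst_index * scale), src_size - 1)
// (the "exact" variant adds a half-pixel shift before truncating). The
// backward variants are used to build half-open ranges: evaluating them at
// dst_index and dst_index + 1 gives the [begin, end) span of grad_output
// elements that map onto that destination element.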
// see NOTE [ Nearest neighbor upsampling kernel implementation ]
template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest3d_out_frame(
    const scalar_t* input,
    size_t dim_b,
    size_t dim_c,
    size_t src_dim_d,
    size_t src_dim_h,
    size_t src_dim_w,
    size_t dst_dim_d,
    size_t dst_dim_h,
    size_t dst_dim_w,
    scalar_t* output,
    float depth_scale,
    float height_scale,
    float width_scale) {

  int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (dst_idx >= dim_c * dst_dim_d * dst_dim_h * dst_dim_w)
    return;

  int dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
  int src_c_stride = src_dim_d * src_dim_h * src_dim_w;

  int c = (dst_idx / (dst_c_stride)) % dim_c;

  int dst_z = (dst_idx / dst_dim_h / dst_dim_w) % dst_dim_d;
  int src_z = nn_compute_source_index_fn(depth_scale, dst_z, src_dim_d);
  int dst_y = (dst_idx / dst_dim_w) % dst_dim_h;
  int src_y = nn_compute_source_index_fn(height_scale, dst_y, src_dim_h);

  int dst_x = dst_idx % dst_dim_w;
  int src_x = nn_compute_source_index_fn(width_scale, dst_x, src_dim_w);

  int src_idx = c * src_c_stride + src_z * src_dim_h * src_dim_w +
      src_y * src_dim_w + src_x;
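  // Example of the flat-index decomposition above (hypothetical sizes): with
  // dim_c = 2 and dst_dim_d = dst_dim_h = dst_dim_w = 4, dst_c_stride is 64,
  // so dst_idx = 100 decomposes into c = 1, dst_z = 2, dst_y = 1, dst_x = 0.
  // The loop below walks the batch dimension: each iteration advances both
  // indices by a full (C, D, H, W) slice, so one thread handles the same
  // spatial location for every batch element.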
  for (int b = 0; b < dim_b; b++) {
    output[dst_idx] = input[src_idx];
    src_idx += dim_c * src_c_stride;
    dst_idx += dim_c * dst_c_stride;
  }
}

// see NOTE [ Nearest neighbor upsampling kernel implementation ]
// Backward operation
template <typename scalar_t, typename accscalar_t, nn_bw_compute_source_index_fn_t nn_bw_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest3d_backward_out_frame(
    const scalar_t* grad_o,
    size_t dim_b,
    size_t dim_c,
    size_t src_dim_d,
    size_t src_dim_h,
    size_t src_dim_w,
    size_t dst_dim_d,
    size_t dst_dim_h,
    size_t dst_dim_w,
    scalar_t* grad_i,
    float depth_scale,
    float height_scale,
    float width_scale) {

  int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (dst_idx >= dim_c * dst_dim_d * dst_dim_h * dst_dim_w)
    return;

  int dst_c_stride = dst_dim_d * dst_dim_h * dst_dim_w;
  int src_c_stride = src_dim_d * src_dim_h * src_dim_w;

  int c = (dst_idx / (dst_c_stride)) % dim_c;

  int dst_z = (dst_idx / dst_dim_h / dst_dim_w) % dst_dim_d;
  // note that we do not want to clamp src_z to src_dim_d, since we might
  // intentionally want to skip in case of scale_factor < 1.0
  int src_z = nn_bw_compute_source_index_fn(depth_scale, dst_z, src_dim_d);
  int src_z_up = nn_bw_compute_source_index_fn(depth_scale, dst_z+1, src_dim_d);

  int dst_y = (dst_idx / dst_dim_w) % dst_dim_h;
  // note that we do not want to clamp src_y to src_dim_h, since we might
  // intentionally want to skip in case of scale_factor < 1.0
  int src_y = nn_bw_compute_source_index_fn(height_scale, dst_y, src_dim_h);
  int src_y_up = nn_bw_compute_source_index_fn(height_scale, dst_y+1, src_dim_h);

  int dst_x = dst_idx % dst_dim_w;
  // note that we do not want to clamp src_x to src_dim_w, since we might
  // intentionally want to skip in case of scale_factor < 1.0
  int src_x = nn_bw_compute_source_index_fn(width_scale, dst_x, src_dim_w);
  int src_x_up = nn_bw_compute_source_index_fn(width_scale, dst_x+1, src_dim_w);
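  // The [src_*, src_*_up) ranges are half-open, and consecutive destination
  // indices produce adjacent, non-overlapping ranges, so each grad_output
  // element is accumulated into at most one grad_input element and no atomics
  // are needed. Example (hypothetical sizes, non-exact variant, depth_scale = 2,
  // i.e. the forward pass doubled the depth): dst_z = 1 yields src_z = 2 and
  // src_z_up = 4, so grad_output depths 2 and 3 are summed into grad_input
  // depth 1.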

  for (int b = 0; b < dim_b; b++) {
    accscalar_t grad = 0;
    for (int z = src_z; z < src_z_up; z++) {
      for (int y = src_y; y < src_y_up; y++) {
        for (int x = src_x; x < src_x_up; x++) {
          int src_idx = b * dim_c * src_c_stride + c * src_c_stride +
              z * src_dim_h * src_dim_w + y * src_dim_w + x;
          grad += grad_o[src_idx];
        }
      }
    }
    grad_i[dst_idx] = grad;
    dst_idx += dim_c * dst_c_stride;
  }
}

template<nn_compute_source_index_fn_t nn_compute_source_index_fn>
static void upsample_nearest3d_out_cuda_template(
    const Tensor& output,
    const Tensor& input_,
    IntArrayRef output_size,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TensorArg input_arg{input_, "input_", 1}, output_arg{output, "output", 2};
  checkAllSameGPU(__func__, {input_arg, output_arg});

  // TODO: remove this when the cuda kernel is updated to support the channels_last memory format.
  // This is a temporary hack to prevent a silent correctness issue when calling this kernel
  // with tensors in channels_last format.
  auto output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options());

  int output_depth = output_size[0];
  int output_height = output_size[1];
  int output_width = output_size[2];

  int nbatch = input_.size(0);
  int channels = input_.size(1);
  int input_depth = input_.size(2);
  int input_height = input_.size(3);
  int input_width = input_.size(4);

  Tensor input = input_.contiguous();

  if (input.numel() == 0) {
    return;
  }

  // upsample_nearest3d meta call makes sure `nbatch != 0`
  unsigned int n = output.numel() / nbatch;
  dim3 bdim{std::min<unsigned int>(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
  dim3 gdim{ceil_div(n, bdim.x)};
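  // Launch-shape example (hypothetical sizes): a single-batch slice of
  // n = 1'000'000 output elements with bdim.x = 512 gives
  // gdim.x = ceil_div(n, 512) = 1954 blocks; the batch dimension is covered
  // by the loop inside the kernel rather than by extra blocks.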
  // safe check for int32 indexing; implicitly restrict launch config for kernel
  TORCH_CHECK(output.numel() <= std::numeric_limits<int32_t>::max(),
      "upsample_nearest3d only supports output tensors with less than INT_MAX elements, but got ", output.sizes());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest3d_out_frame", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;

    auto idata = input.const_data_ptr<scalar_t>();
    auto odata = output_c.mutable_data_ptr<scalar_t>();

    const float depth_scale = compute_scales_value<float>(scales_d, input_depth, output_depth);
    const float height_scale = compute_scales_value<float>(scales_h, input_height, output_height);
    const float width_scale = compute_scales_value<float>(scales_w, input_width, output_width);
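    // compute_scales_value (defined in UpSample.cuh) roughly returns
    // 1 / scale when an explicit scale is given and input_size / output_size
    // otherwise; e.g. doubling the depth gives depth_scale ~= 0.5, so dst_z
    // maps back to src_z = dst_z / 2 in the kernel.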

    upsample_nearest3d_out_frame<scalar_t, nn_compute_source_index_fn>
        <<<gdim, bdim, 0, stream>>>(
            idata,
            nbatch,
            channels,
            input_depth,
            input_height,
            input_width,
            output_depth,
            output_height,
            output_width,
            odata,
            depth_scale,
            height_scale,
            width_scale);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });

  if (!output.is_contiguous()) {
    output.copy_(output_c);
  }
}

template<nn_bw_compute_source_index_fn_t nn_bw_compute_source_index_fn>
static void upsample_nearest3d_backward_out_cuda_template(
    const Tensor& grad_input,
    const Tensor& grad_output_,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TensorArg grad_input_arg{grad_input, "grad_input", 1},
      grad_output_arg{grad_output_, "grad_output_", 2};
  checkAllSameGPU(
      __func__,
      {grad_output_arg, grad_input_arg});

  int output_depth = output_size[0];
  int output_height = output_size[1];
  int output_width = output_size[2];

  int nbatch = input_size[0];
  int channels = input_size[1];
  int input_depth = input_size[2];
  int input_height = input_size[3];
  int input_width = input_size[4];

  Tensor grad_output = grad_output_.contiguous();

  if (grad_input.numel() == 0) {
    return;
  }

  // upsample_nearest3d meta call makes sure `nbatch != 0`
  unsigned int n = grad_input.numel() / nbatch;
  dim3 bdim{std::min<unsigned int>(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
  dim3 gdim{ceil_div(n, bdim.x)};
  // safe check for int32 indexing; implicitly restrict launch config for kernel
  TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int32_t>::max(),
      "upsample_nearest3d_backward only supports input tensors with less than INT_MAX elements, but got ", grad_input.sizes());
  TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int32_t>::max(),
      "upsample_nearest3d_backward only supports output tensors with less than INT_MAX elements, but got ", grad_output.sizes());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest3d_backward_out_frame", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;

    auto idata = grad_input.mutable_data_ptr<scalar_t>();
    auto odata = grad_output.const_data_ptr<scalar_t>();

    float depth_scale = compute_scales_value_backwards<float>(scales_d, output_depth, input_depth);
    float height_scale = compute_scales_value_backwards<float>(scales_h, output_height, input_height);
    float width_scale = compute_scales_value_backwards<float>(scales_w, output_width, input_width);
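    // compute_scales_value_backwards (defined in UpSample.cuh) roughly returns
    // the explicit scale when one is given and output_size / input_size
    // otherwise, i.e. the reciprocal of the forward scale, since this kernel
    // maps grad_input indices onto grad_output ranges.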

    upsample_nearest3d_backward_out_frame<scalar_t, accscalar_t, nn_bw_compute_source_index_fn>
        <<<gdim, bdim, 0, stream>>>(
            odata,
            nbatch,
            channels,
            output_depth,
            output_height,
            output_width,
            input_depth,
            input_height,
            input_width,
            idata,
            depth_scale,
            height_scale,
            width_scale);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}

} // namespace

TORCH_IMPL_FUNC(upsample_nearest3d_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output) {
  upsample_nearest3d_out_cuda_template<nearest_neighbor_compute_source_index>(
      output, input, output_size, scales_d, scales_h, scales_w);
}
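// Usage sketch (hypothetical call, not part of this file's dispatch wiring):
// the structured kernel above is what ultimately runs for a call such as
//   auto out = at::upsample_nearest3d(
//       input, /*output_size=*/{16, 32, 32},
//       /*scales_d=*/std::nullopt, /*scales_h=*/std::nullopt, /*scales_w=*/std::nullopt);
// with a 5-D CUDA tensor `input` of shape (N, C, D, H, W).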

TORCH_IMPL_FUNC(_upsample_nearest_exact3d_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output) {
  upsample_nearest3d_out_cuda_template<nearest_neighbor_exact_compute_source_index>(
      output, input, output_size, scales_d, scales_h, scales_w);
}

TORCH_IMPL_FUNC(upsample_nearest3d_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input) {
  upsample_nearest3d_backward_out_cuda_template<nearest_neighbor_bw_compute_source_index>(
      grad_input, grad_output, output_size, input_size, scales_d, scales_h, scales_w);
}

TORCH_IMPL_FUNC(_upsample_nearest_exact3d_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input) {
  upsample_nearest3d_backward_out_cuda_template<nearest_neighbor_exact_bw_compute_source_index>(
      grad_input, grad_output, output_size, input_size, scales_d, scales_h, scales_w);
}

using at::native::upsample::compute_output_size;
using at::native::upsample_cuda::get_scale_value;

} // namespace at::native