// Adapted from interp.cpp from Caffe util by Pauline Luc
// Originally developed by George Papandreou
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/native/cuda/UpSample.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/upsample_trilinear3d_native.h>
#include <ATen/ops/upsample_trilinear3d_backward_native.h>
#endif

namespace at::native {
namespace {

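// Flattened offset of element (z, y, x) inside the (n * channels + c) slab of a
// contiguous [N*C, depth, height, width] buffer; used by the backward kernel to
// index grad_input for fastAtomicAdd.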
__device__ __forceinline__ size_t
idx_3d(const size_t nc,
    const size_t depth,
    const size_t height,
    const size_t width,
    const size_t z,
    const size_t y,
    const size_t x) {
  return ((nc * depth + z) * height + y) * width + x;
}

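// Forward kernel: each thread handles one output location (t2, h2, w2) and
// writes the interpolated value for every (n, c) pair.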
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(512)
__global__ void upsample_trilinear3d_out_frame(
    const int n,
    const accscalar_t rdepth,
    const accscalar_t rheight,
    const accscalar_t rwidth,
    const bool align_corners,
    const PackedTensorAccessor64<const scalar_t, 5> idata,
    PackedTensorAccessor64<scalar_t, 5> odata) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;

  const int batchsize = idata.size(0);
  const int channels = idata.size(1);
  const int depth1 = idata.size(2);
  const int height1 = idata.size(3);
  const int width1 = idata.size(4);
  const int depth2 = odata.size(2);
  const int height2 = odata.size(3);
  const int width2 = odata.size(4);

  if (index < n) {
    const int w2 = (index % (height2 * width2)) % width2; // 0:width2-1
    const int h2 = (index % (height2 * width2)) / width2; // 0:height2-1
    const int t2 = index / (height2 * width2); // 0:depth2-1
    // special case: just copy
    if (depth1 == depth2 && height1 == height2 && width1 == width2) {
      const int t1 = t2;
      const int h1 = h2;
      const int w1 = w2;

      for (int n = 0; n < batchsize; n++) {
        for (int c = 0; c < channels; ++c) {
          const scalar_t val = idata[n][c][t1][h1][w1];
          odata[n][c][t2][h2][w2] = val;
        }
      }
      return;
    }
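    // General case: for each dimension, compute the fractional source index
    // and the two linear interpolation weights; *0lambda weights the lower
    // neighbour, *1lambda the upper neighbour, and t1p/h1p/w1p clamp the upper
    // neighbour at the border.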
    //
    const accscalar_t t1r = area_pixel_compute_source_index<accscalar_t>(
        rdepth, t2, align_corners, /*cubic=*/false);
    const int t1 = t1r;
    const int t1p = (t1 < depth1 - 1) ? 1 : 0;
    const accscalar_t t1lambda = t1r - t1;
    const accscalar_t t0lambda = static_cast<accscalar_t>(1) - t1lambda;
    //
    const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
        rheight, h2, align_corners, /*cubic=*/false);
    const int h1 = h1r;
    const int h1p = (h1 < height1 - 1) ? 1 : 0;
    const accscalar_t h1lambda = h1r - h1;
    const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
    //
    const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
        rwidth, w2, align_corners, /*cubic=*/false);
    const int w1 = w1r;
    const int w1p = (w1 < width1 - 1) ? 1 : 0;
    const accscalar_t w1lambda = w1r - w1;
    const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
    //
    for (int n = 0; n < batchsize; n++) {
      for (int c = 0; c < channels; ++c) {
        const accscalar_t val = t0lambda *
                (h0lambda *
                     (w0lambda * idata[n][c][t1][h1][w1] +
                      w1lambda * idata[n][c][t1][h1][w1 + w1p]) +
                 h1lambda *
                     (w0lambda * idata[n][c][t1][h1 + h1p][w1] +
                      w1lambda * idata[n][c][t1][h1 + h1p][w1 + w1p])) +
            t1lambda *
                (h0lambda *
                     (w0lambda * idata[n][c][t1 + t1p][h1][w1] +
                      w1lambda * idata[n][c][t1 + t1p][h1][w1 + w1p]) +
                 h1lambda *
                     (w0lambda * idata[n][c][t1 + t1p][h1 + h1p][w1] +
                      w1lambda * idata[n][c][t1 + t1p][h1 + h1p][w1 + w1p]));
        odata[n][c][t2][h2][w2] = static_cast<scalar_t>(val);
      }
    }
  }
}

// Backward (adjoint) operation 1 <- 2 (accumulates)
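// Each thread reads one grad_output voxel and scatters its weighted
// contribution back to the eight neighbouring grad_input voxels.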
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(256)
__global__ void upsample_trilinear3d_backward_out_frame(
    const int num_kernels,
    const accscalar_t rdepth,
    const accscalar_t rheight,
    const accscalar_t rwidth,
    const bool align_corners,
    PackedTensorAccessor64<scalar_t, 5> idata,
    const PackedTensorAccessor64<const scalar_t, 5> odata,
    scalar_t* idata_ptr) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;

  const int batchsize = idata.size(0);
  const int channels = idata.size(1);
  const int depth1 = idata.size(2);
  const int height1 = idata.size(3);
  const int width1 = idata.size(4);
  const int depth2 = odata.size(2);
  const int height2 = odata.size(3);
  const int width2 = odata.size(4);

  const size_t i_numel = batchsize * channels * depth1 * height1 * width1;

  if (index < num_kernels) {
    const int w2 = (index % (height2 * width2)) % width2; // 0:width2-1
    const int h2 = (index % (height2 * width2)) / width2; // 0:height2-1
    const int t2 = index / (height2 * width2); // 0:depth2-1
    // special case: just copy
    if (depth1 == depth2 && height1 == height2 && width1 == width2) {
      const int t1 = t2;
      const int h1 = h2;
      const int w1 = w2;

      for (int n = 0; n < batchsize; n++) {
        for (int c = 0; c < channels; ++c) {
          const scalar_t val = odata[n][c][t1][h1][w1];
          idata[n][c][t2][h2][w2] = val;
        }
      }
      return;
    }
    //
    const accscalar_t t1r = area_pixel_compute_source_index<accscalar_t>(
        rdepth, t2, align_corners, /*cubic=*/false);
    const int t1 = t1r;
    const int t1p = (t1 < depth1 - 1) ? 1 : 0;
    const accscalar_t t1lambda = t1r - t1;
    const accscalar_t t0lambda = static_cast<accscalar_t>(1) - t1lambda;
    //
    const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
        rheight, h2, align_corners, /*cubic=*/false);
    const int h1 = h1r;
    const int h1p = (h1 < height1 - 1) ? 1 : 0;
    const accscalar_t h1lambda = h1r - h1;
    const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
    //
    const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
        rwidth, w2, align_corners, /*cubic=*/false);
    const int w1 = w1r;
    const int w1p = (w1 < width1 - 1) ? 1 : 0;
    const accscalar_t w1lambda = w1r - w1;
    const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
    //
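    // Scatter the gradient to the 8 surrounding input voxels. fastAtomicAdd is
    // used because several output locations can map to the same input voxel.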
    for (int n = 0; n < batchsize; n++) {
      for (int c = 0; c < channels; ++c) {
        const scalar_t d2val = odata[n][c][t2][h2][w2];
        const size_t nc = n * channels + c;
        fastAtomicAdd(
          idata_ptr,
          idx_3d(nc, depth1, height1, width1, t1, h1, w1),
          i_numel,
          static_cast<scalar_t>(t0lambda * h0lambda * w0lambda * d2val),
          true);
        fastAtomicAdd(
          idata_ptr,
          idx_3d(nc, depth1, height1, width1, t1, h1, w1 + w1p),
          i_numel,
          static_cast<scalar_t>(t0lambda * h0lambda * w1lambda * d2val),
          true);
        fastAtomicAdd(
          idata_ptr,
          idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1),
          i_numel,
          static_cast<scalar_t>(t0lambda * h1lambda * w0lambda * d2val),
          true);
        fastAtomicAdd(
          idata_ptr,
          idx_3d(nc, depth1, height1, width1, t1, h1 + h1p, w1 + w1p),
          i_numel,
          static_cast<scalar_t>(t0lambda * h1lambda * w1lambda * d2val),
          true);
        fastAtomicAdd(
          idata_ptr,
          idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1),
          i_numel,
          static_cast<scalar_t>(t1lambda * h0lambda * w0lambda * d2val),
          true);
        fastAtomicAdd(
          idata_ptr,
          idx_3d(nc, depth1, height1, width1, t1 + t1p, h1, w1 + w1p),
          i_numel,
          static_cast<scalar_t>(t1lambda * h0lambda * w1lambda * d2val),
          true);
        fastAtomicAdd(
          idata_ptr,
          idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1),
          i_numel,
          static_cast<scalar_t>(t1lambda * h1lambda * w0lambda * d2val),
          true);
        fastAtomicAdd(
          idata_ptr,
          idx_3d(nc, depth1, height1, width1, t1 + t1p, h1 + h1p, w1 + w1p),
          i_numel,
          static_cast<scalar_t>(t1lambda * h1lambda * w1lambda * d2val),
          true);
      }
    }
  }
}

static void upsample_trilinear3d_out_cuda_template(
    const Tensor& output,
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2};
  checkAllSameGPU("upsample_trilinear3d_out_cuda", {input_arg, output_arg});

  int output_depth = output_size[0];
  int output_height = output_size[1];
  int output_width = output_size[2];

  int input_depth = input.size(2);
  int input_height = input.size(3);
  int input_width = input.size(4);

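  // Launch one thread per output spatial location; each thread iterates over
  // the batch and channel dimensions itself.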
  const int num_kernels = output_depth * output_height * output_width;
  const int num_threads = std::min(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 512);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      input.scalar_type(), "upsample_trilinear3d_out_frame", [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        auto idata = input.packed_accessor64<const scalar_t, 5>();
        auto odata = output.packed_accessor64<scalar_t, 5>();

        const accscalar_t rdepth = area_pixel_compute_scale<accscalar_t>(
            input_depth, output_depth, align_corners, scales_d);
        const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
            input_height, output_height, align_corners, scales_h);
        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
            input_width, output_width, align_corners, scales_w);

        upsample_trilinear3d_out_frame<scalar_t, accscalar_t>
            <<<ceil_div(num_kernels, num_threads),
               num_threads,
               0,
               stream>>>(
                num_kernels,
                rdepth,
                rheight,
                rwidth,
                align_corners,
                idata,
                odata);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

static void upsample_trilinear3d_backward_out_cuda_template(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w) {
  TensorArg grad_input_arg{grad_input_, "grad_input_", 1},
      grad_output_arg{grad_output_, "grad_output_", 2};
  checkAllSameGPU(
      "upsample_trilinear3d_backward_out_cuda",
      {grad_output_arg, grad_input_arg});

  int output_depth = output_size[0];
  int output_height = output_size[1];
  int output_width = output_size[2];

  int input_depth = input_size[2];
  int input_height = input_size[3];
  int input_width = input_size[4];

  Tensor grad_output = grad_output_.contiguous();

  // A contiguous tensor is required for the kernel launch config
  Tensor grad_input = grad_input_.contiguous();

  // Numbers are added atomically to grad_input tensor from multiple threads,
  // so it has to be initialized to zero.
  grad_input.zero_();

  const int num_kernels = output_depth * output_height * output_width;
  const int num_threads = std::min(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 256);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      grad_output.scalar_type(),
      "upsample_trilinear3d_backward_out_frame",
      [&] {
        using accscalar_t = at::acc_type<scalar_t, true>;

        auto idata = grad_input.packed_accessor64<scalar_t, 5>();
        auto odata = grad_output.packed_accessor64<const scalar_t, 5>();
        scalar_t* idata_ptr = grad_input.mutable_data_ptr<scalar_t>();

        const accscalar_t rdepth = area_pixel_compute_scale<accscalar_t>(
            input_depth, output_depth, align_corners, scales_d);
        const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
            input_height, output_height, align_corners, scales_h);
        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
            input_width, output_width, align_corners, scales_w);

        upsample_trilinear3d_backward_out_frame<scalar_t, accscalar_t>
            <<<ceil_div(num_kernels, num_threads),
               num_threads,
               0,
               stream>>>(
                num_kernels,
                rdepth,
                rheight,
                rwidth,
                align_corners,
                idata,
                odata,
                idata_ptr);
        C10_CUDA_KERNEL_LAUNCH_CHECK();

        if (!grad_input_.is_contiguous()) {
            grad_input_.copy_(grad_input);
        }
      });
}

} // namespace

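// Structured-operator entry points. These kernels are typically reached from
// Python via, e.g.:
//   out = torch.nn.functional.interpolate(
//       x, scale_factor=2, mode='trilinear', align_corners=False)
// for a 5-D CUDA tensor x of shape [N, C, D, H, W].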
TORCH_IMPL_FUNC(upsample_trilinear3d_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    bool align_corners,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& output) {
  upsample_trilinear3d_out_cuda_template(output, input, output_size, align_corners, scales_d, scales_h, scales_w);
}

TORCH_IMPL_FUNC(upsample_trilinear3d_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    bool align_corners,
    std::optional<double> scales_d,
    std::optional<double> scales_h,
    std::optional<double> scales_w,
    const Tensor& grad_input) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("upsample_trilinear3d_backward_out_cuda");
  upsample_trilinear3d_backward_out_cuda_template(
      grad_input, grad_output, output_size, input_size, align_corners, scales_d, scales_h, scales_w);
}

} // namespace at::native