#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/UpSample.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/upsample_nearest1d_native.h>
#include <ATen/ops/upsample_nearest1d_backward_native.h>
#include <ATen/ops/_upsample_nearest_exact1d_native.h>
#include <ATen/ops/_upsample_nearest_exact1d_backward_native.h>
#endif

namespace at::native {
namespace {

#define MAX_THREADS 512

// Define a typedef to dispatch to nearest_neighbor_compute_source_index or
// nearest_neighbor_exact_compute_source_index
typedef int (*nn_compute_source_index_fn_t)(const float, int, int);

// Define a typedef to dispatch to nearest_neighbor_bw_compute_source_index or
// nearest_neighbor_exact_bw_compute_source_index
typedef int (*nn_bw_compute_source_index_fn_t)(const float, int, int);

// see NOTE [ Nearest neighbor upsampling kernel implementation ]
template <typename scalar_t, nn_compute_source_index_fn_t nn_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest1d_out_frame(
    const scalar_t* input,
    size_t dim_b,
    size_t dim_c,
    size_t src_dim_w,
    size_t dst_dim_w,
    scalar_t* output,
    float scale_factor) {
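  // Each thread handles one (channel, output position) pair and walks the
  // batch dimension in the loop below.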
  int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (dst_idx >= dim_c * dst_dim_w)
    return;

  int c = (dst_idx / dst_dim_w) % dim_c;

  int dst_x = dst_idx % dst_dim_w;
  int src_x = nn_compute_source_index_fn(scale_factor, dst_x, src_dim_w);

  int src_idx = c * src_dim_w + src_x;
  int src_stride = dim_c * src_dim_w;
  int dst_stride = dim_c * dst_dim_w;

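  // Both tensors are contiguous, so advancing by the per-batch stride lands on
  // the same (channel, x) location of the next batch element.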
  for (int b = 0; b < dim_b; b++) {
    output[dst_idx] = input[src_idx];
    src_idx += src_stride;
    dst_idx += dst_stride;
  }
}

// see NOTE [ Nearest neighbor upsampling kernel implementation ]
// Backward operation
template <typename scalar_t, typename accscalar_t, nn_bw_compute_source_index_fn_t nn_bw_compute_source_index_fn>
C10_LAUNCH_BOUNDS_1(1024)
__global__ void upsample_nearest1d_backward_out_frame(
    const scalar_t* grad_o,
    size_t dim_b,
    size_t dim_c,
    size_t src_dim_w,
    size_t dst_dim_w,
    scalar_t* grad_i,
    float scale_factor) {
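  // Here `dst` refers to grad_input and `src` to grad_output: each thread owns
  // one grad_input element per channel/position and gathers from grad_output.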
  int dst_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (dst_idx >= dim_c * dst_dim_w)
    return;

  int c = (dst_idx / dst_dim_w) % dim_c;

  int dst_x = dst_idx % dst_dim_w;
  // note that we do not want to clamp src_x to src_dim_w, since we might
  // intentionally want to skip in case of scale_factor < 1.0
  int src_x = nn_bw_compute_source_index_fn(scale_factor, dst_x, src_dim_w);
  int src_x_up = nn_bw_compute_source_index_fn(scale_factor, dst_x + 1, src_dim_w);

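  // Sum every grad_output element in [src_x, src_x_up) that maps onto this
  // grad_input element, accumulating in accscalar_t for low-precision types.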
  for (int b = 0; b < dim_b; b++) {
    accscalar_t grad = 0;
    int src_idx = b * dim_c * src_dim_w + c * src_dim_w + src_x;
    for (int x = src_x; x < src_x_up; x++) {
      grad += grad_o[src_idx++];
    }
    grad_i[dst_idx] = grad;
    dst_idx += dim_c * dst_dim_w;
  }
}

template<nn_compute_source_index_fn_t nn_compute_source_index_fn>
static void upsample_nearest1d_out_cuda_template(
    const Tensor& output,
    const Tensor& input_,
    IntArrayRef output_size,
    std::optional<double> scales) {
  TensorArg input_arg{input_, "input_", 1}, output_arg{output, "output", 2};
  checkAllSameGPU("upsample_nearest1d_out_cuda", {input_arg, output_arg});

  int output_width = output_size[0];

  int nbatch = input_.size(0);
  int channels = input_.size(1);
  int input_width = input_.size(2);

  Tensor input = input_.contiguous();

  if (input.numel() == 0) {
    return;
  }

  // upsample_nearest1d meta call makes sure `nbatch != 0`
  unsigned int n = output.numel() / nbatch;
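  // Launch one thread per output element of a single batch sample; the kernel
  // iterates over the batch dimension internally.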
  dim3 bdim{std::min<unsigned int>(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
  dim3 gdim{ceil_div(n, bdim.x)};
  // safe check for int32 indexing; implicitly restrict launch config for kernel
  TORCH_CHECK(output.numel() <= std::numeric_limits<int32_t>::max(),
      "upsample_nearest1d only supports output tensors with less than INT_MAX elements, but got ", output.sizes());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, input.scalar_type(), "upsample_nearest1d_out_frame", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;

    auto idata = input.const_data_ptr<scalar_t>();
    auto odata = output.mutable_data_ptr<scalar_t>();

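    // scale_factor is the input/output size ratio (the reciprocal of the
    // user-supplied scale), which the source-index helpers expect.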
    const float scale_factor = compute_scales_value<float>(scales, input_width, output_width);

    upsample_nearest1d_out_frame<scalar_t, nn_compute_source_index_fn><<<gdim, bdim, 0, stream>>>(
        idata, nbatch, channels, input_width, output_width, odata, scale_factor);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}

template<nn_compute_source_index_fn_t nn_bw_compute_source_index_fn>
static void upsample_nearest1d_backward_out_cuda_template(
    const Tensor& grad_input,
    const Tensor& grad_output_,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales) {
  TensorArg grad_input_arg{grad_input, "grad_input", 1},
      grad_output_arg{grad_output_, "grad_output_", 2};
  checkAllSameGPU(
      "upsample_nearest1d_backward_out_cuda_template",
      {grad_output_arg, grad_input_arg});

  int output_width = output_size[0];

  int nbatch = input_size[0];
  int channels = input_size[1];
  int input_width = input_size[2];

  Tensor grad_output = grad_output_.contiguous();

  if (grad_input.numel() == 0) {
    return;
  }

  // upsample_nearest1d meta call makes sure `nbatch != 0`
  unsigned int n = grad_input.numel() / nbatch;
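  // Launch one thread per grad_input element of a single batch sample; the
  // kernel iterates over the batch dimension internally.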
  dim3 bdim{std::min<unsigned int>(
      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, MAX_THREADS)};
  dim3 gdim{ceil_div(n, bdim.x)};
  // safe check for int32 indexing; implicitly restrict launch config for kernel
  TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int32_t>::max(),
      "upsample_nearest1d_backward only supports input tensors with less than INT_MAX elements, but got ", grad_input.sizes());
  TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int32_t>::max(),
      "upsample_nearest1d_backward only supports output tensors with less than INT_MAX elements, but got ", grad_output.sizes());

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Byte, grad_output.scalar_type(), "upsample_nearest1d_backward_out_frame", [&] {
    using accscalar_t = at::acc_type<scalar_t, true>;

    auto idata = grad_input.mutable_data_ptr<scalar_t>();
    auto odata = grad_output.const_data_ptr<scalar_t>();

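    // The backward mapping runs from grad_input positions to grad_output
    // positions, so the scale here is the output/input size ratio.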
    const float scale_factor = compute_scales_value_backwards<float>(scales, output_width, input_width);

    upsample_nearest1d_backward_out_frame<scalar_t, accscalar_t, nn_bw_compute_source_index_fn>
        <<<gdim, bdim, 0, stream>>>(
            odata, nbatch, channels, output_width, input_width, idata, scale_factor);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}

} // namespace

TORCH_IMPL_FUNC(upsample_nearest1d_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales,
    const Tensor& output
) {
  upsample_nearest1d_out_cuda_template<nearest_neighbor_compute_source_index>(
      output, input, output_size, scales);
}

TORCH_IMPL_FUNC(_upsample_nearest_exact1d_out_cuda) (
    const Tensor& input,
    IntArrayRef output_size,
    std::optional<double> scales,
    const Tensor& output
) {
  upsample_nearest1d_out_cuda_template<nearest_neighbor_exact_compute_source_index>(output, input, output_size, scales);
}

TORCH_IMPL_FUNC(upsample_nearest1d_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales,
    const Tensor& grad_input
) {
  upsample_nearest1d_backward_out_cuda_template<nearest_neighbor_bw_compute_source_index>(
      grad_input, grad_output, output_size, input_size, scales);
}

TORCH_IMPL_FUNC(_upsample_nearest_exact1d_backward_out_cuda) (
    const Tensor& grad_output,
    IntArrayRef output_size,
    IntArrayRef input_size,
    std::optional<double> scales,
    const Tensor& grad_input
) {
  upsample_nearest1d_backward_out_cuda_template<nearest_neighbor_exact_bw_compute_source_index>(
      grad_input, grad_output, output_size, input_size, scales);
}

} // namespace at::native