#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAConfig.h>

#ifdef __HIP_PLATFORM_AMD__
#include <ATen/native/cudnn/hip/BatchNorm.h>
#else
#include <ATen/native/cudnn/BatchNorm.h>
#endif

#if !AT_CUDNN_ENABLED()

namespace at {
namespace native {

// See Note [ATen preprocessor philosophy]
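// When ATen is built without cuDNN, these stubs keep the operators defined
// and fail with a clear error if they are ever called.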

std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
    const Tensor& input,
    const Tensor& weight,
    const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& running_mean_opt,
    const std::optional<Tensor>& running_var_opt,
    bool training,
    double exponential_average_factor,
    double epsilon) {
  AT_ERROR("cudnn_batch_norm: ATen not compiled with cuDNN support");
}

std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    const std::optional<Tensor>& running_mean_opt,
    const std::optional<Tensor>& running_var_opt,
    const std::optional<Tensor>& save_mean_opt,
    const std::optional<Tensor>& save_var_opt,
    double epsilon,
    const Tensor& reservedSpace) {
  AT_ERROR("cudnn_batch_norm_backward: ATen not compiled with cuDNN support");
}

size_t _get_cudnn_batch_norm_reserve_space_size(
    const Tensor& input_t,
    bool training) {
  AT_ERROR(
      "_get_cudnn_batch_norm_reserve_space_size: ATen not compiled with cuDNN support");
}

} // namespace native
} // namespace at

#else // AT_CUDNN_ENABLED

#include <ATen/TensorUtils.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cudnn_batch_norm_backward_native.h>
#include <ATen/ops/cudnn_batch_norm_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#endif

namespace at {
namespace native {

namespace {

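// Reshape a 1-D per-channel tensor (weight, bias, running stats) to
// {1, C, 1, ...} so its cuDNN descriptor matches the input's dimensionality;
// e.g. a weight of shape {C} with a 4-d input is viewed as {1, C, 1, 1}.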
Tensor expandScale(const Tensor& t, int64_t dim) {
  std::vector<int64_t> size{1, t.numel()};
  while (static_cast<int64_t>(size.size()) < dim) {
    size.emplace_back(1);
  }
  return t.view(size);
}

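// Pick the cuDNN batch norm mode: per-activation for 2-d inputs, the
// spatial-persistent mode for channels-last training, and the plain spatial
// mode otherwise (see the accuracy note below).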
cudnnBatchNormMode_t getCudnnBatchNormMode(
    bool training,
    at::MemoryFormat memory_format,
    int64_t dim) {
  if (dim == 2) {
    return CUDNN_BATCHNORM_PER_ACTIVATION;
  } else if (training && memory_format == at::MemoryFormat::ChannelsLast) {
    return CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
  } else if (training && memory_format == at::MemoryFormat::ChannelsLast3d) {
    return CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
  } else {
    // TODO: The new CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode was
    // introduced in CuDNN 7 for performance optimization, but it results in
    // accuracy losses in convolution models such as ResNeXt-101 and
    // video R(2+1)D. We will fall back to the normal CUDNN_BATCHNORM_SPATIAL
    return CUDNN_BATCHNORM_SPATIAL;
  }
}

} // namespace

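// Ask cuDNN how large the opaque reserve space buffer must be for a training
// forward pass; the same buffer is later handed to the backward pass.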
size_t _get_cudnn_batch_norm_reserve_space_size(
    const Tensor& input_t,
    bool training) {
  size_t reserve_size;
  TensorArg input{input_t, "input", 1};
  TensorDescriptor idesc{*input, 4};
  auto handle = getCudnnHandle();
  cudnnBatchNormMode_t mode = getCudnnBatchNormMode(
      training, input->suggest_memory_format(), input->dim());
  auto op = CUDNN_BATCHNORM_OPS_BN;
  AT_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
      handle, mode, op, nullptr, idesc.desc(), &reserve_size));
  return reserve_size;
}

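// Forward pass. Returns (output, save_mean, save_var, reserve): save_mean and
// save_var hold the batch statistics cuDNN saves for the backward pass, and
// reserve is the opaque buffer cuDNN fills during training.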
std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_t_opt,
    const std::optional<Tensor>& running_mean_t_opt,
    const std::optional<Tensor>& running_var_t_opt,
    bool training,
    double exponential_average_factor,
    double epsilon) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_t_maybe_owned =
      at::borrow_from_optional_tensor(bias_t_opt);
  const Tensor& bias_t = *bias_t_maybe_owned;
  const Tensor& running_mean_t =
      c10::value_or_else(running_mean_t_opt, [] { return Tensor(); });
  const Tensor& running_var_t =
      c10::value_or_else(running_var_t_opt, [] { return Tensor(); });

  TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2},
      bias{bias_t, "bias", 3}, running_mean{running_mean_t, "running_mean", 4},
      running_var{running_var_t, "running_var", 5};
  CheckedFrom c = "cudnn_batch_norm";

  checkAllDefined(c, {input, weight, bias});
  if (!training) {
    checkAllDefined(c, {running_mean, running_var});
  }
  checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
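  // cuDNN requires float scale/bias/statistics when the input is half precision.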
  if (input->scalar_type() == ScalarType::Half) {
    checkScalarType(c, weight, ScalarType::Float);
  } else {
    checkAllSameType(c, {input, weight});
  }
  checkAllSameType(c, {weight, bias, running_mean, running_var});
  // TODO: is weight required to be contiguous?
  checkAllContiguous(c, {weight, bias, running_mean, running_var});
  // TODO: TensorArg checks should start handling memory format
  TORCH_CHECK(input->is_contiguous(input->suggest_memory_format()));

  checkDimRange(c, input, 2, 6 /* exclusive */);
  auto num_features = input->size(1);
  for (auto t : {weight, bias, running_mean, running_var}) {
    if (t->defined()) {
      checkNumel(c, t, num_features);
    }
  }

  cudnnBatchNormMode_t mode = getCudnnBatchNormMode(
      training, input->suggest_memory_format(), input->dim());

  auto output_t =
      at::empty_like(*input, input->options(), input->suggest_memory_format());

  TensorArg output{output_t, "output", 0};

  auto handle = getCudnnHandle();
  auto dataType = getCudnnDataType(*input);
  TensorDescriptor idesc{*input, 4}; // input descriptor
  TensorDescriptor wdesc{
      expandScale(*weight, input->dim()),
      4}; // descriptor for weight, bias, running_mean, etc.

  Constant one(dataType, 1);
  Constant zero(dataType, 0);
  Tensor save_mean, save_var;

  Tensor reserve;

  if (training) {
    int64_t num_features = input_t.size(1);
    save_mean = at::empty({num_features}, weight_t.options());
    save_var = at::empty({num_features}, weight_t.options());

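    // Workspace is scratch memory used only within this call; the reserve
    // space must be kept alive and passed to the backward pass.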
    auto op = CUDNN_BATCHNORM_OPS_BN;
    size_t workspace_size;
    AT_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
        handle,
        mode,
        op,
        idesc.desc(),
        idesc.desc(),
        idesc.desc(),
        wdesc.desc(),
        nullptr,
        &workspace_size));
    Tensor workspace = at::empty(workspace_size, input->options().dtype(kByte));

    // Get the reserve space size and allocate it as a byte tensor
    size_t reserve_size =
        _get_cudnn_batch_norm_reserve_space_size(input_t, true /* training */);
    reserve = at::empty(reserve_size, input->options().dtype(kByte));

    AT_CUDNN_CHECK(cudnnBatchNormalizationForwardTrainingEx(
        handle,
        mode,
        op,
        &one,
        &zero,
        idesc.desc(),
        input->const_data_ptr(),
        nullptr, // z descriptor for BN-Add-ReLU
        nullptr, // z for BN-Add-ReLU
        idesc.desc(),
        output->data_ptr(),
        wdesc.desc(),
        weight->const_data_ptr(),
        bias->const_data_ptr(),
        exponential_average_factor,
        at::maybe_data_ptr(running_mean),
        at::maybe_data_ptr(running_var),
        epsilon,
        save_mean.mutable_data_ptr(),
        save_var.mutable_data_ptr(),
        nullptr,
        workspace.data_ptr(),
        workspace_size,
        reserve.mutable_data_ptr(),
        reserve_size));
  } else {
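    // Inference path: no workspace or reserve space is required.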
    reserve = at::empty({0}, input->options().dtype(kByte));
    // This keeps a consistent output with native_batch_norm
    save_mean = at::empty({0}, weight_t.options());
    save_var = at::empty({0}, weight_t.options());
    AT_CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
        handle,
        mode,
        &one,
        &zero,
        idesc.desc(),
        input->const_data_ptr(),
        idesc.desc(),
        output->data_ptr(),
        wdesc.desc(),
        weight->const_data_ptr(),
        bias->const_data_ptr(),
        running_mean->const_data_ptr(),
        running_var->const_data_ptr(),
        epsilon));
  }

  // save_mean and save_var can be undefined
  // If this causes problems, we can initialize them to empty tensors
  // of the correct type
  return std::tuple<Tensor, Tensor, Tensor, Tensor>{
      output_t, save_mean, save_var, reserve};
}

// NB: CuDNN only implements the backward algorithm for batchnorm
// in training mode (evaluation mode batchnorm has a different algorithm),
// which is why this doesn't accept a 'training' parameter.
std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
    const Tensor& input_t,
    const Tensor& grad_output_t,
    const Tensor& weight_t,
    // Unused: but we require them to be passed so that double backwards
    // has access
    const std::optional<Tensor>& running_mean_opt,
    const std::optional<Tensor>& running_var_opt,
    const std::optional<Tensor>& save_mean_t_opt,
    const std::optional<Tensor>& save_var_t_opt,
    double epsilon,
    const Tensor& reserveSpace) {
  // See [Note: hacky wrapper removal for optional tensor]
  const Tensor& save_mean_t =
      c10::value_or_else(save_mean_t_opt, [] { return Tensor(); });
  const Tensor& save_var_t =
      c10::value_or_else(save_var_t_opt, [] { return Tensor(); });

  // TODO: Is the contiguous() call worth it, or should we just use whatever
  // memory format grad_output comes in?

  auto grad_output_contig =
      grad_output_t.contiguous(input_t.suggest_memory_format());
  TensorArg input{input_t, "input", 1},
      grad_output{grad_output_contig, "grad_output", 2},
      weight{weight_t, "weight", 3}, save_mean{save_mean_t, "save_mean", 4},
      save_var{save_var_t, "save_var", 5},
      reserve{reserveSpace, "reserve_space", 6};
  CheckedFrom c = "cudnn_batch_norm_backward";

  checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
  checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var});
  if (input->scalar_type() == ScalarType::Half) {
    checkScalarType(c, weight, ScalarType::Float);
  } else {
    checkAllSameType(c, {input, weight});
  }
  checkAllSameType(c, {input, grad_output});
  checkAllSameType(c, {weight, save_mean, save_var});
  // TODO: is weight required to be contiguous?
  checkAllContiguous(c, {save_mean, save_var});
  // TODO: TensorArg checks should start handling memory format
  TORCH_CHECK(input->is_contiguous(input->suggest_memory_format()));
  TORCH_CHECK(grad_output->is_contiguous(input->suggest_memory_format()));
  checkDimRange(c, input, 2, 6 /* exclusive */);
  checkSameSize(c, input, grad_output);
  auto num_features = input->size(1);
  for (auto t : {weight, save_mean, save_var}) {
    checkNumel(c, t, num_features);
  }

  cudnnBatchNormMode_t mode = getCudnnBatchNormMode(
      true, // training
      input->suggest_memory_format(),
      input->dim());

  auto grad_input_t = at::empty(
      input->sizes(), input->options(), input->suggest_memory_format());
  auto grad_weight_t = at::empty(weight->sizes(), weight->options());
  auto grad_bias_t = at::empty(weight->sizes(), weight->options());

  auto handle = getCudnnHandle();
  auto dataType = getCudnnDataType(*input);

  TensorDescriptor idesc{*input, 4}; // input, grad_output descriptor
  TensorDescriptor odesc{*grad_output, 4}; // input, grad_output descriptor
  TensorDescriptor wdesc{
      expandScale(*weight, input->dim()),
      4}; // descriptor for weight, save_mean, etc.

  Constant one(dataType, 1);
  Constant zero(dataType, 0);

  auto op = CUDNN_BATCHNORM_OPS_BN;

  size_t workspace_size;
  AT_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
      handle,
      mode,
      op,
      idesc.desc(),
      idesc.desc(),
      idesc.desc(),
      nullptr,
      odesc.desc(),
      wdesc.desc(),
      nullptr,
      &workspace_size));
  Tensor workspace = at::empty(workspace_size, input->options().dtype(kByte));

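  // The two alpha/beta pairs scale the data gradient and the parameter
  // gradients respectively; the reserve space produced by the forward
  // training call is passed back in at the end.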
  AT_CUDNN_CHECK(cudnnBatchNormalizationBackwardEx(
      handle,
      mode,
      op,
      &one,
      &zero,
      &one,
      &zero,
      idesc.desc(),
      input->const_data_ptr(),
      nullptr,
      nullptr,
      odesc.desc(),
      grad_output->const_data_ptr(),
      nullptr,
      nullptr,
      idesc.desc(),
      grad_input_t.data_ptr(),
      wdesc.desc(),
      weight->const_data_ptr(),
      nullptr,
      grad_weight_t.data_ptr(),
      grad_bias_t.data_ptr(),
      epsilon,
      save_mean->const_data_ptr(),
      save_var->const_data_ptr(),
      nullptr,
      workspace.data_ptr(),
      workspace_size,
      reserve->data_ptr(),
      reserve->numel()));

  return std::tuple<Tensor, Tensor, Tensor>{
      grad_input_t, grad_weight_t, grad_bias_t};
}

} // namespace native
} // namespace at

#endif