// /aosp_15_r20/external/pytorch/aten/src/ATen/native/cudnn/BatchNorm.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAConfig.h>

#ifdef __HIP_PLATFORM_AMD__
#include <ATen/native/cudnn/hip/BatchNorm.h>
#else
#include <ATen/native/cudnn/BatchNorm.h>
#endif

#if !AT_CUDNN_ENABLED()

namespace at {
namespace native {

// See Note [ATen preprocessor philosophy]

std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
    const Tensor& input,
    const Tensor& weight,
    const std::optional<Tensor>& bias_opt,
    const std::optional<Tensor>& running_mean_opt,
    const std::optional<Tensor>& running_var_opt,
    bool training,
    double exponential_average_factor,
    double epsilon) {
  AT_ERROR("cudnn_batch_norm: ATen not compiled with cuDNN support");
}

std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
    const Tensor& input,
    const Tensor& grad_output,
    const Tensor& weight,
    const std::optional<Tensor>& running_mean_opt,
    const std::optional<Tensor>& running_var_opt,
    const std::optional<Tensor>& save_mean_opt,
    const std::optional<Tensor>& save_var_opt,
    double epsilon,
    const Tensor& reservedSpace) {
  AT_ERROR("cudnn_batch_norm_backward: ATen not compiled with cuDNN support");
}

size_t _get_cudnn_batch_norm_reserve_space_size(
    const Tensor& input_t,
    bool training) {
  AT_ERROR(
      "_get_cudnn_batch_norm_reserve_space_size: ATen not compiled with cuDNN support");
}

} // namespace native
} // namespace at

#else // AT_CUDNN_ENABLED

#include <ATen/TensorUtils.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cudnn_batch_norm_backward_native.h>
#include <ATen/ops/cudnn_batch_norm_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#endif

namespace at {
namespace native {

namespace {

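// expandScale reshapes a per-channel parameter (weight, bias, running stats)
// into the layout used for the cuDNN scale/bias/mean/var descriptor below.
// For example, a weight of shape [C] paired with a 4-d input is returned as a
// view of shape [1, C, 1, 1].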
Tensor expandScale(const Tensor& t, int64_t dim) {
  std::vector<int64_t> size{1, t.numel()};
  while (static_cast<int64_t>(size.size()) < dim) {
    size.emplace_back(1);
  }
  return t.view(size);
}

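// getCudnnBatchNormMode picks the cuDNN batch norm mode from the input rank,
// the suggested memory format, and whether we are training. In summary:
//   2-d input                                       -> CUDNN_BATCHNORM_PER_ACTIVATION
//   training + ChannelsLast / ChannelsLast3d input  -> CUDNN_BATCHNORM_SPATIAL_PERSISTENT
//   everything else (including inference)           -> CUDNN_BATCHNORM_SPATIAL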
cudnnBatchNormMode_t getCudnnBatchNormMode(
    bool training,
    at::MemoryFormat memory_format,
    int64_t dim) {
  if (dim == 2) {
    return CUDNN_BATCHNORM_PER_ACTIVATION;
  } else if (training && memory_format == at::MemoryFormat::ChannelsLast) {
    return CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
  } else if (training && memory_format == at::MemoryFormat::ChannelsLast3d) {
    return CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
  } else {
    // TODO: The new CUDNN_BATCHNORM_SPATIAL_PERSISTENT mode was
    // introduced in CuDNN 7 for performance optimization, but it results in
    // accuracy losses in convolution models such as ResNeXt-101 and
    // video R(2+1)D. We will fall back to the normal CUDNN_BATCHNORM_SPATIAL
    return CUDNN_BATCHNORM_SPATIAL;
  }
}

} // namespace

size_t _get_cudnn_batch_norm_reserve_space_size(
    const Tensor& input_t,
    bool training) {
  size_t reserve_size;
  TensorArg input{input_t, "input", 1};
  TensorDescriptor idesc{*input, 4};
  auto handle = getCudnnHandle();
  cudnnBatchNormMode_t mode = getCudnnBatchNormMode(
      training, input->suggest_memory_format(), input->dim());
  auto op = CUDNN_BATCHNORM_OPS_BN;
  AT_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
      handle, mode, op, nullptr, idesc.desc(), &reserve_size));
  return reserve_size;
}
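// Illustrative use of the query above (a minimal sketch mirroring the training
// path in cudnn_batch_norm below): the returned size is used to allocate the
// byte tensor that cuDNN fills as reserve space in the forward pass and reads
// back in the backward pass, e.g.
//   size_t sz = _get_cudnn_batch_norm_reserve_space_size(input, /*training=*/true);
//   Tensor reserve = at::empty(sz, input.options().dtype(kByte));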

std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_t_opt,
    const std::optional<Tensor>& running_mean_t_opt,
    const std::optional<Tensor>& running_var_t_opt,
    bool training,
    double exponential_average_factor,
    double epsilon) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_t_maybe_owned =
      at::borrow_from_optional_tensor(bias_t_opt);
  const Tensor& bias_t = *bias_t_maybe_owned;
  const Tensor& running_mean_t =
      c10::value_or_else(running_mean_t_opt, [] { return Tensor(); });
  const Tensor& running_var_t =
      c10::value_or_else(running_var_t_opt, [] { return Tensor(); });

  TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2},
      bias{bias_t, "bias", 3}, running_mean{running_mean_t, "running_mean", 4},
      running_var{running_var_t, "running_var", 5};
  CheckedFrom c = "cudnn_batch_norm";

  checkAllDefined(c, {input, weight, bias});
  if (!training) {
    checkAllDefined(c, {running_mean, running_var});
  }
  checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
  if (input->scalar_type() == ScalarType::Half) {
    checkScalarType(c, weight, ScalarType::Float);
  } else {
    checkAllSameType(c, {input, weight});
  }
  checkAllSameType(c, {weight, bias, running_mean, running_var});
  // TODO: is weight required to be contiguous?
  checkAllContiguous(c, {weight, bias, running_mean, running_var});
  // TODO: the TensorArg checks should start handling memory format
  TORCH_CHECK(input->is_contiguous(input->suggest_memory_format()));

  checkDimRange(c, input, 2, 6 /* exclusive */);
  auto num_features = input->size(1);
  for (auto t : {weight, bias, running_mean, running_var}) {
    if (t->defined()) {
      checkNumel(c, t, num_features);
    }
  }

  cudnnBatchNormMode_t mode = getCudnnBatchNormMode(
      training, input->suggest_memory_format(), input->dim());

  auto output_t =
      at::empty_like(*input, input->options(), input->suggest_memory_format());

  TensorArg output{output_t, "output", 0};

  auto handle = getCudnnHandle();
  auto dataType = getCudnnDataType(*input);
  TensorDescriptor idesc{*input, 4}; // input descriptor
  TensorDescriptor wdesc{
      expandScale(*weight, input->dim()),
      4}; // descriptor for weight, bias, running_mean, etc.

  Constant one(dataType, 1);
  Constant zero(dataType, 0);
  Tensor save_mean, save_var;

  Tensor reserve;

  if (training) {
    int64_t num_features = input_t.size(1);
    save_mean = at::empty({num_features}, weight_t.options());
    save_var = at::empty({num_features}, weight_t.options());

    auto op = CUDNN_BATCHNORM_OPS_BN;
    size_t workspace_size;
    AT_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
        handle,
        mode,
        op,
        idesc.desc(),
        idesc.desc(),
        idesc.desc(),
        wdesc.desc(),
        nullptr,
        &workspace_size));
    Tensor workspace = at::empty(workspace_size, input->options().dtype(kByte));

    // get the reserved size and allocate as tensor
    size_t reserve_size =
        _get_cudnn_batch_norm_reserve_space_size(input_t, true /* training */);
    reserve = at::empty(reserve_size, input->options().dtype(kByte));

    AT_CUDNN_CHECK(cudnnBatchNormalizationForwardTrainingEx(
        handle,
        mode,
        op,
        &one,
        &zero,
        idesc.desc(),
        input->const_data_ptr(),
        nullptr, // z descriptor for BN-Add-ReLU
        nullptr, // z for BN-Add-ReLU
        idesc.desc(),
        output->data_ptr(),
        wdesc.desc(),
        weight->const_data_ptr(),
        bias->const_data_ptr(),
        exponential_average_factor,
        at::maybe_data_ptr(running_mean),
        at::maybe_data_ptr(running_var),
        epsilon,
        save_mean.mutable_data_ptr(),
        save_var.mutable_data_ptr(),
        nullptr,
        workspace.data_ptr(),
        workspace_size,
        reserve.mutable_data_ptr(),
        reserve_size));
  } else {
    reserve = at::empty({0}, input->options().dtype(kByte));
    // This keeps a consistent output with native_batch_norm
    save_mean = at::empty({0}, weight_t.options());
    save_var = at::empty({0}, weight_t.options());
    AT_CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
        handle,
        mode,
        &one,
        &zero,
        idesc.desc(),
        input->const_data_ptr(),
        idesc.desc(),
        output->data_ptr(),
        wdesc.desc(),
        weight->const_data_ptr(),
        bias->const_data_ptr(),
        running_mean->const_data_ptr(),
        running_var->const_data_ptr(),
        epsilon));
  }

  // save_mean and save_var can be undefined
  // If this causes problems, we can initialize them to empty tensors
  // of the correct type
  return std::tuple<Tensor, Tensor, Tensor, Tensor>{
      output_t, save_mean, save_var, reserve};
}
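// Note (illustrative): in training mode the tuple returned above is
// (output, save_mean, save_var, reserve); the last three, together with the
// original input, weight, and running stats, are what autograd is expected to
// hand to cudnn_batch_norm_backward below.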

// NB: CuDNN only implements the backward algorithm for batchnorm
// in training mode (evaluation mode batchnorm has a different algorithm),
// which is why this doesn't accept a 'training' parameter.
std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
    const Tensor& input_t,
    const Tensor& grad_output_t,
    const Tensor& weight_t,
    // Unused: but we require them to be passed so that double backwards
    // has access
    const std::optional<Tensor>& running_mean_opt,
    const std::optional<Tensor>& running_var_opt,
    const std::optional<Tensor>& save_mean_t_opt,
    const std::optional<Tensor>& save_var_t_opt,
    double epsilon,
    const Tensor& reserveSpace) {
  // See [Note: hacky wrapper removal for optional tensor]
  const Tensor& save_mean_t =
      c10::value_or_else(save_mean_t_opt, [] { return Tensor(); });
  const Tensor& save_var_t =
      c10::value_or_else(save_var_t_opt, [] { return Tensor(); });

  // TODO: Is it worth making this contiguous() call, or should we just go with
  // whatever memory format grad_output is given in?

  auto grad_output_contig =
      grad_output_t.contiguous(input_t.suggest_memory_format());
  TensorArg input{input_t, "input", 1},
      grad_output{grad_output_contig, "grad_output", 2},
      weight{weight_t, "weight", 3}, save_mean{save_mean_t, "save_mean", 4},
      save_var{save_var_t, "save_var", 5},
      reserve{reserveSpace, "reserve_space", 6};
  CheckedFrom c = "cudnn_batch_norm_backward";

  checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
  checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var});
  if (input->scalar_type() == ScalarType::Half) {
    checkScalarType(c, weight, ScalarType::Float);
  } else {
    checkAllSameType(c, {input, weight});
  }
  checkAllSameType(c, {input, grad_output});
  checkAllSameType(c, {weight, save_mean, save_var});
  // TODO: is weight required to be contiguous?
  checkAllContiguous(c, {save_mean, save_var});
  // TODO: the TensorArg checks should start handling memory format
  TORCH_CHECK(input->is_contiguous(input->suggest_memory_format()));
  TORCH_CHECK(grad_output->is_contiguous(input->suggest_memory_format()));
  checkDimRange(c, input, 2, 6 /* exclusive */);
  checkSameSize(c, input, grad_output);
  auto num_features = input->size(1);
  for (auto t : {weight, save_mean, save_var}) {
    checkNumel(c, t, num_features);
  }

  cudnnBatchNormMode_t mode = getCudnnBatchNormMode(
      true, // training
      input->suggest_memory_format(),
      input->dim());

  auto grad_input_t = at::empty(
      input->sizes(), input->options(), input->suggest_memory_format());
  auto grad_weight_t = at::empty(weight->sizes(), weight->options());
  auto grad_bias_t = at::empty(weight->sizes(), weight->options());

  auto handle = getCudnnHandle();
  auto dataType = getCudnnDataType(*input);

  TensorDescriptor idesc{*input, 4}; // input, grad_output descriptor
  TensorDescriptor odesc{*grad_output, 4}; // input, grad_output descriptor
  TensorDescriptor wdesc{
      expandScale(*weight, input->dim()),
      4}; // descriptor for weight, save_mean, etc.

  Constant one(dataType, 1);
  Constant zero(dataType, 0);

  auto op = CUDNN_BATCHNORM_OPS_BN;

  size_t workspace_size;
  AT_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
      handle,
      mode,
      op,
      idesc.desc(),
      idesc.desc(),
      idesc.desc(),
      nullptr,
      odesc.desc(),
      wdesc.desc(),
      nullptr,
      &workspace_size));
  Tensor workspace = at::empty(workspace_size, input->options().dtype(kByte));

  AT_CUDNN_CHECK(cudnnBatchNormalizationBackwardEx(
      handle,
      mode,
      op,
      &one,
      &zero,
      &one,
      &zero,
      idesc.desc(),
      input->const_data_ptr(),
      nullptr,
      nullptr,
      odesc.desc(),
      grad_output->const_data_ptr(),
      nullptr,
      nullptr,
      idesc.desc(),
      grad_input_t.data_ptr(),
      wdesc.desc(),
      weight->const_data_ptr(),
      nullptr,
      grad_weight_t.data_ptr(),
      grad_bias_t.data_ptr(),
      epsilon,
      save_mean->const_data_ptr(),
      save_var->const_data_ptr(),
      nullptr,
      workspace.data_ptr(),
      workspace_size,
      reserve->data_ptr(),
      reserve->numel()));

  return std::tuple<Tensor, Tensor, Tensor>{
      grad_input_t, grad_weight_t, grad_bias_t};
}

} // namespace native
} // namespace at

#endif