// File: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mkldnn/Linear.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/Parallel.h>
#include <ATen/core/Tensor.h>
#include <torch/library.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_to_dense_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/linear.h>
#include <ATen/ops/mkldnn_linear_backward_input.h>
#include <ATen/ops/mkldnn_linear_backward_input_native.h>
#include <ATen/ops/mkldnn_linear_backward_native.h>
#include <ATen/ops/mkldnn_linear_backward_weights.h>
#include <ATen/ops/mkldnn_linear_backward_weights_native.h>
#include <ATen/ops/mkldnn_linear_native.h>
#endif

#if !AT_MKLDNN_ENABLED()

namespace at {
namespace native {

Tensor mkldnn_linear(
    const Tensor& self,
    const Tensor& weight, const std::optional<Tensor>& bias_opt) {
  TORCH_CHECK(false, "mkldnn_linear: ATen not compiled with MKLDNN support");
}
Tensor mkldnn_linear_backward_input(
    IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight) {
  TORCH_CHECK(false, "mkldnn_linear_backward_input: ATen not compiled with MKLDNN support");
}

std::tuple<Tensor, Tensor> mkldnn_linear_backward_weights(
    const Tensor& grad_output, const Tensor& input, const Tensor& weight, bool bias_defined) {
  TORCH_CHECK(false, "mkldnn_linear_backward_weights: ATen not compiled with MKLDNN support");
}

std::tuple<Tensor, Tensor, Tensor> mkldnn_linear_backward(
    const Tensor& input, const Tensor& grad_output_t,
    const Tensor& weight, std::array<bool,3> output_mask) {
  TORCH_CHECK(false, "mkldnn_linear_backward: ATen not compiled with MKLDNN support");
}

} // namespace native
} // namespace at

#else // AT_MKLDNN_ENABLED

#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>

namespace at {
namespace native {

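// Linear forward for inputs in MKLDNN (oneDNN) layout. Inputs with dim > 2
// are flattened to 2D, run through ideep::inner_product_forward, and the
// result is reshaped back to [*leading_dims, weight.size(0)].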
Tensor mkldnn_linear(
    const Tensor& self,
    const Tensor& weight_t, const std::optional<Tensor>& bias_opt) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  const int64_t dim = self.dim();
  TORCH_CHECK(
      self.dim() != 0,
      "mkldnn_linear: input needs to have dim at least 1, input dim ",
      self.dim());
  TORCH_CHECK(self.is_mkldnn(),
      "mkldnn_linear: input needs to be in mkldnn layout");
  if (self.scalar_type() == ScalarType::BFloat16) {
    TORCH_CHECK(mkldnn_bf16_device_check(),
        "mkldnn_linear: bf16 path needs the cpu to support avx_ne_convert, or avx512bw, avx512vl and avx512dq");
  } else if (self.scalar_type() == ScalarType::Half) {
    TORCH_CHECK(mkldnn_fp16_device_check(),
        "mkldnn_linear: fp16 path needs the cpu to support avx_ne_convert or avx512_fp16");
  }

  // Reshape to 2D first if the input dim != 2; the reshape costs a memory copy.
  auto self_reshaped =
      dim == 2 ? self : self.reshape({-1, self.size(self.dim() - 1)});

  const ideep::tensor x = itensor_from_mkldnn(self_reshaped);
  // weight_t can be an mkldnn tensor or a dense tensor.
  const Tensor weight = (weight_t.is_mkldnn() || weight_t.is_contiguous()) ? weight_t : weight_t.contiguous();
  const ideep::tensor w = itensor_from_tensor(weight);

  ideep::tensor y;
  if (bias.defined()) {
    const ideep::tensor b = itensor_from_tensor(bias);
    ideep::inner_product_forward::compute(x, w, b, y);
  } else {
    ideep::inner_product_forward::compute(x, w, y);
  }

  auto input_size = self.sizes();
  std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
  output_size.push_back(weight.size(0));

  if (self.dim() != 2) {
    return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()),
                                   self.options().device_opt()).reshape(output_size);
  }
  return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()),
                                 self.options().device_opt());
}

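// Backward w.r.t. the input: conceptually grad_input = grad_output @ weight,
// computed via ideep::inner_product_backward_data on the 2D-flattened
// grad_output, then reshaped back to input_size.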
Tensor mkldnn_linear_backward_input(
    IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight_t) {
  TORCH_CHECK(grad_output.is_mkldnn(),
      "mkldnn_linear_backward: grad_output needs to be in mkldnn layout");
  TORCH_CHECK(weight_t.device().is_cpu() && weight_t.scalar_type() == kFloat,
      "mkldnn_linear_backward: weight_t needs to be a dense tensor");
  auto grad_output_reshaped = grad_output.dim() > 2 ?
    grad_output.reshape({-1, grad_output.size(grad_output.dim() - 1)}) : grad_output;

  ideep::tensor& grady = itensor_from_mkldnn(grad_output_reshaped);
  // weight_t is always a dense tensor for training.
  const Tensor weight = weight_t.is_contiguous() ? weight_t : weight_t.contiguous();
  const ideep::tensor w = itensor_view_from_dense(weight);

  std::vector<int64_t> input_reshaped_size;
  input_reshaped_size.push_back(grad_output_reshaped.size(0));
  input_reshaped_size.push_back(weight.size(1));

  ideep::tensor gradx;
  ideep::inner_product_backward_data::compute(
    grady, w, {input_reshaped_size.begin(), input_reshaped_size.end()}, gradx);

  if (input_size.size() > 2) {
    return new_with_itensor_mkldnn(std::move(gradx), optTypeMetaToScalarType(grad_output.options().dtype_opt()),
                                   grad_output.options().device_opt()).reshape(input_size);
  }
  return new_with_itensor_mkldnn(std::move(gradx), optTypeMetaToScalarType(grad_output.options().dtype_opt()),
                                 grad_output.options().device_opt());
}

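// Backward w.r.t. the weight (and, when bias_defined, the bias). The MKLDNN
// results are converted back to dense tensors to match the dense weight.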
std::tuple<Tensor, Tensor> mkldnn_linear_backward_weights(
    const Tensor& grad_output, const Tensor& input, const Tensor& weight, bool bias_defined) {
  TORCH_CHECK(grad_output.is_mkldnn() && input.is_mkldnn(),
      "mkldnn_linear_backward: grad_output and input need to be in mkldnn layout");
  TORCH_CHECK(weight.device().is_cpu() && weight.scalar_type() == kFloat,
      "mkldnn_linear_backward: weight needs to be a dense tensor");

  auto grad_output_reshaped = grad_output.dim() > 2 ?
    grad_output.reshape({-1, grad_output.size(grad_output.dim() - 1)}) : grad_output;
  auto input_reshaped = input.dim() > 2 ? input.reshape({-1, input.size(input.dim() - 1)}) : input;

  ideep::tensor& grady = itensor_from_mkldnn(grad_output_reshaped);
  ideep::tensor& x = itensor_from_mkldnn(input_reshaped);
  ideep::tensor gradw, gradb;
  if (bias_defined) {
    ideep::inner_product_backward_weights::compute(x, grady, gradw, gradb);
  } else {
    ideep::inner_product_backward_weights::compute(x, grady, gradw);
  }

  return std::tuple<Tensor, Tensor>{
    mkldnn_to_dense(new_with_itensor_mkldnn(std::move(gradw),
                    optTypeMetaToScalarType(weight.options().dtype_opt()),
                    weight.options().device_opt())),
    mkldnn_to_dense(new_with_itensor_mkldnn(std::move(gradb),
                    optTypeMetaToScalarType(weight.options().dtype_opt()),
                    weight.options().device_opt()))};
}

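// Combined backward: computes only the gradients requested by output_mask,
// using the input/weight backward kernels above.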
std::tuple<Tensor, Tensor, Tensor> mkldnn_linear_backward(
    const Tensor& input, const Tensor& grad_output,
    const Tensor& weight, std::array<bool,3> output_mask) {
  Tensor grad_input, grad_weight, grad_bias;
  if (output_mask[0]) {
    grad_input = at::mkldnn_linear_backward_input(input.sizes(), grad_output, weight);
  }
  if (output_mask[1] || output_mask[2]) {
    std::tie(grad_weight, grad_bias) = at::mkldnn_linear_backward_weights(grad_output, input, weight, output_mask[2]);
  }
  return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}

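// Fused linear + unary post-op (e.g. "relu") for dense CPU tensors, exposed as
// mkldnn::_linear_pointwise. The post-op named by `attr` is looked up in
// fusion_unary_attr_map() and applied by oneDNN as part of the inner product.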
static Tensor mkldnn_linear_pointwise(
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    c10::string_view attr,
    torch::List<std::optional<at::Scalar>> scalars,
    std::optional<c10::string_view> algorithm) {
  auto input = input_t.contiguous();
  auto input_size = input.sizes();

  // Make sure a contiguous input has default contiguous strides, for better performance.
  input = may_convert_to_default_contiguous_strides(input);

  const int64_t dim = input.dim();
  auto input_reshaped =
      dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)});

  std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
  output_size.push_back(weight_t.size(0));
  auto output = at::empty(output_size, input.options());
  if (output.sym_numel() == 0) {
    return output;
  }
  if (dim != 2) {
    std::vector<int64_t> output_size_reshaped = {input_reshaped.size(0),
                                                 weight_t.size(0)};
    output = output.reshape(output_size_reshaped);
  }

  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  ideep::tensor mkldnn_output = itensor_from_tensor(output);

  c10::MaybeOwned<Tensor> bias_maybe_owned =
      at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;

  const ideep::tensor mkldnn_input = itensor_view_from_dense(input_reshaped);

  std::optional<ideep::tensor> mkldnn_bias{std::nullopt};
  if (bias.defined()) {
    mkldnn_bias = itensor_from_tensor(bias);
  }
  const ideep::tensor w = itensor_from_tensor(weight_t);

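  // Map the requested post-op name to a oneDNN post-op attribute; unrecognized
  // names are rejected by the TORCH_CHECK below.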
  ideep::attr_t op_attr = ideep::attr_t();
  if (attr != "none") {
    auto it = fusion_unary_attr_map().find(attr);
    TORCH_CHECK(
        it != fusion_unary_attr_map().end(), "Fusion behavior undefined.");
    op_attr = it->second(scalars, algorithm);
  }

  if (mkldnn_bias.has_value()) {
    ideep::inner_product_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
        mkldnn_input,
        w,
        mkldnn_bias.value(),
        mkldnn_output,
        op_attr);
  } else {
    ideep::inner_product_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
        mkldnn_input,
        w,
        mkldnn_output,
        op_attr);
  }

  if (dim != 2) {
    output = output.reshape(output_size);
  }

  return output;
}

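// Fused linear + binary post-op (e.g. "add") with a second tensor operand,
// exposed as mkldnn::_linear_pointwise.binary. `other_t` must reshape to the
// same 2D size as the output so oneDNN can fuse the elementwise op.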
static Tensor mkldnn_linear_pointwise_binary(
    const Tensor& input_t,
    const Tensor& other_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_opt,
    c10::string_view attr) {
  c10::MaybeOwned<Tensor> bias_maybe_owned =
      at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;
  // Make sure the inputs have the same type (device, layout, dtype), that the
  // device is cpu, and that the dtype is float or bfloat16.
  check_mkldnn_binary_fusion_inputs(input_t, other_t, weight_t, bias);

  auto input = input_t.contiguous();
  // Make sure a contiguous input has default contiguous strides, for better performance.
  input = may_convert_to_default_contiguous_strides(input);

  auto it_binary = fusion_binary_alg_map().find(attr);
  TORCH_CHECK(
      it_binary != fusion_binary_alg_map().end(), "Fusion behavior undefined.");

  auto input_size = input.sizes();

  const int64_t dim = input.dim();
  auto input_reshaped =
      dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)});

  std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
  output_size.push_back(weight_t.size(0));
  auto output = at::empty(output_size, input.options());
  if (output.sym_numel() == 0) {
    return output;
  }
  auto other_reshaped = other_t.contiguous();
  other_reshaped = may_convert_to_default_contiguous_strides(other_reshaped);

  if (dim != 2) {
    std::vector<int64_t> output_size_reshaped = {
        input_reshaped.size(0), weight_t.size(0)};
    output = output.reshape(output_size_reshaped);
    other_reshaped = other_reshaped.reshape(output_size_reshaped);
  }

  TORCH_CHECK(
      output.sizes() == other_reshaped.sizes(),
      "linear_binary_run expects the size of output and other tensor to be the same");

  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  ideep::tensor mkldnn_output = itensor_from_tensor(output);
  const ideep::tensor mkldnn_other = itensor_from_tensor(other_reshaped);
  const ideep::tensor mkldnn_input = itensor_view_from_dense(input_reshaped);

  std::optional<ideep::tensor> mkldnn_bias{std::nullopt};
  if (bias.defined()) {
    mkldnn_bias = itensor_from_tensor(bias);
  }
  const ideep::tensor w = itensor_from_tensor(weight_t);

  auto other_desc = mkldnn_other.get_desc();
  auto op_attr = ideep::attr_t::fuse_binary(it_binary->second, other_desc);

  if (mkldnn_bias.has_value()) {
    ideep::inner_product_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
        mkldnn_input,
        mkldnn_other,
        w,
        mkldnn_bias.value(),
        mkldnn_output,
        op_attr);
  } else {
    ideep::inner_product_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
        mkldnn_input, mkldnn_other, w, mkldnn_output, op_attr);
  }

  if (dim != 2) {
    output = output.reshape(output_size);
  }

  return output;
}

#if AT_MKL_ENABLED()
#include <mkl.h>

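// Linear through MKL's packed GEMM (mkl::_mkl_linear). When the runtime batch
// size M matches prepack_batch_size and the weight has been prepacked into an
// MKLDNN tensor, call cblas_sgemm_compute directly on the packed weight;
// otherwise fall back to at::linear_out with the original dense weight.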
static Tensor mkl_linear(
    const Tensor& self,
    const Tensor& mkl_weight_t,
    const Tensor& origin_weight_t,
    const std::optional<Tensor>& bias_opt,
    const int64_t prepack_batch_size) {
  c10::MaybeOwned<Tensor> bias_maybe_owned =
      at::borrow_from_optional_tensor(bias_opt);
  const Tensor& bias = *bias_maybe_owned;
  TORCH_CHECK(
      self.options().type_equal(origin_weight_t.options()),
      "Input type (",
      self.toString(),
      ") and weight type (",
      origin_weight_t.toString(),
      ") should be the same");
  TORCH_CHECK(
      !bias.defined() || (self.options().type_equal(bias.options())),
      "Input type (",
      self.toString(),
      ") and bias type (",
      bias.toString(),
      ") should be the same");
  TORCH_CHECK(
      mkl_weight_t.scalar_type() == origin_weight_t.scalar_type() &&
          origin_weight_t.scalar_type() == kFloat,
      "mkl_linear: weight dtype should be float");

  c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset);
  auto input_size = self.sizes();
  std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
  output_size.push_back(origin_weight_t.size(0));
  auto output = at::empty(output_size, self.options());
  if (self.sym_numel() == 0) {
    // Avoid computing self.numel() / 0 below when self.size(self.dim() - 1) == 0.
    return output.fill_(0);
  }
  if (output.sym_numel() == 0) {
    return output;
  }
  int64_t M = self.numel() / self.size(self.dim() - 1);
  if (M == prepack_batch_size && mkl_weight_t.is_mkldnn()) {
    auto self_ = self.is_contiguous() ? self : self.contiguous();
    auto K = origin_weight_t.size(1);
    auto N = origin_weight_t.size(0);
    const ideep::tensor& w = itensor_from_mkldnn(mkl_weight_t);
    auto in_ptr = self_.data_ptr<float>();
    auto weight_ptr = (float*)(w.get_data_handle());
    auto out_ptr = output.data_ptr<float>();
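    // If a bias is given, broadcast-copy it into every output row first; the
    // packed GEMM below then accumulates onto it (beta = 1.f).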
    if (bias.defined()) {
      auto bias_ = bias.is_contiguous() ? bias : bias.contiguous();
      auto bias_ptr = bias_.data_ptr<float>();
      at::parallel_for(0, M, 1, [&](int64_t begin, int64_t end) {
        for (const auto d : c10::irange(begin, end)) {
          memcpy(out_ptr + d * N, bias_ptr, sizeof(float) * N);
        }
      });
    }
    cblas_sgemm_compute(
        CblasRowMajor,
        CblasNoTrans,
        CblasPacked,
        M,
        N,
        K,
        in_ptr,
        K,
        weight_ptr,
        K,
        bias.defined() ? 1.f : 0.f,
        out_ptr,
        N);
  } else {
    output = at::linear_out(output, self, origin_weight_t, bias_opt);
  }
  return output;
}

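// Register mkl::_mkl_linear for both dense CPU and MKLDNN-layout CPU inputs.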
TORCH_LIBRARY_IMPL(mkl, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("mkl::_mkl_linear"), TORCH_FN(mkl_linear));
}

TORCH_LIBRARY_IMPL(mkl, MkldnnCPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("mkl::_mkl_linear"), TORCH_FN(mkl_linear));
}

#endif // AT_MKL_ENABLED

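// Register the pointwise-fusion variants for both dense CPU and MKLDNN-layout
// CPU inputs.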
TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_linear_pointwise"),
      TORCH_FN(mkldnn_linear_pointwise));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_linear_pointwise.binary"),
      TORCH_FN(mkldnn_linear_pointwise_binary));
}

TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_linear_pointwise"),
      TORCH_FN(mkldnn_linear_pointwise));
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn::_linear_pointwise.binary"),
      TORCH_FN(mkldnn_linear_pointwise_binary));
}

} // namespace native
} // namespace at

#endif // AT_MKLDNN_ENABLED