// aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp
#include <ATen/Config.h>

#if AT_MKLDNN_ENABLED()

#include <ATen/Tensor.h>
#include <ATen/native/mkldnn/ConvPrepack.h>
#include <ATen/native/mkldnn/OpContext.h>
#include <ATen/native/mkldnn/Utils.h>
#include <torch/custom_class.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace mkldnn {

using namespace internal::convolution;

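// Runtime check for bfloat16 support: aarch64 uses the ARM-specific device
// check, every other architecture uses the generic oneDNN check.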
static bool is_mkldnn_bf16_supported() {
#if defined(__aarch64__)
  return mkldnn_bf16_device_check_arm();
#else
  return mkldnn_bf16_device_check();
#endif
}

static bool is_mkldnn_fp16_supported() {
  return mkldnn_fp16_device_check();
}

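// ACL (Arm Compute Library) availability is fixed at build time, hence
// this predicate can be constexpr.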
constexpr bool is_mkldnn_acl_supported() {
  return AT_MKLDNN_ACL_ENABLED();
}

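// Registers the mkldnn namespace: the ConvOpContext custom class (with
// pickle support so prepacked contexts survive serialization; the pickled
// state is the same 8-tuple of arguments that conv2d_prepack accepts),
// fused-op schemas, and a few capability predicates.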
TORCH_LIBRARY(mkldnn, m) {
  m.class_<ConvOpContext>(TORCH_SELECTIVE_CLASS("ConvOpContext"))
      .def_pickle(
          [](const c10::intrusive_ptr<ConvOpContext>& op_context)
              -> SerializationTypeConvPrePack { // __getstate__
            return op_context->unpack();
          },
          [](SerializationTypeConvPrePack state)
              -> c10::intrusive_ptr<ConvOpContext> { // __setstate__
            return createConvPrePackOpContext(
                std::move(std::get<0>(state)),
                std::move(std::get<1>(state)),
                std::move(std::get<2>(state)),
                std::move(std::get<3>(state)),
                std::move(std::get<4>(state)),
                std::move(std::get<5>(state)),
                std::move(std::get<6>(state)),
                std::move(std::get<7>(state)));
          });
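  // Schemas for the fused pointwise linear/conv ops and the weight-reorder
  // helpers; their CPU implementations are registered elsewhere, alongside
  // the kernels.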
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_linear_pointwise(Tensor X, Tensor W, Tensor? B, str attr, Scalar?[] scalars, str? algorithm) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_linear_pointwise.binary(Tensor X, Tensor other, Tensor W, Tensor? B, str attr) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_convolution_pointwise(Tensor X, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str attr, Scalar?[] scalars, str? algorithm) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_convolution_pointwise.binary(Tensor X, Tensor other, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_convolution_pointwise_.binary(Tensor(a!) other, Tensor X, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor(a!) Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_convolution_transpose_pointwise(Tensor X, Tensor W, Tensor? B, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, str attr, Scalar?[] scalars, str? algorithm) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_reorder_convolution_transpose_weight(Tensor self, int[2] padding=0, int[2] output_padding=0, int[2] stride=1, int[2] dilation=1, int groups=1, int[]? input_size=None) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_reorder_linear_weight(Tensor self, int? batch_size=None) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_reorder_convolution_weight(Tensor self, int[2] padding=0, int[2] stride=1, int[2] dilation=1, int groups=1, int[]? input_size=None) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_reorder_mkldnn_rnn_layer_weight(Tensor weight0, Tensor weight1, int hidden_size, bool reverse, bool has_biases, bool batch_first, int[]? input_size=None) -> Tensor[] Y"));
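  // Capability predicates, exposed to Python as
  // torch.ops.mkldnn._is_mkldnn_bf16_supported() and friends.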
  m.def("_is_mkldnn_bf16_supported", &is_mkldnn_bf16_supported);
  m.def("_is_mkldnn_fp16_supported", &is_mkldnn_fp16_supported);
  m.def("_is_mkldnn_acl_supported", &is_mkldnn_acl_supported);
  m.def("mkldnn::data_ptr(Tensor mkldnn_tensor) -> int");
  m.def("mkldnn::_get_mkldnn_serialized_md (Tensor mkldnn_tensor) -> Tensor");
  m.def("mkldnn::_nbytes(Tensor mkldnn_tensor) -> int");
}
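// Prepacked conv2d: conv2d_prepack packs the weight into a ConvOpContext,
// and conv2d_run executes a convolution against it. A minimal usage sketch
// from Python (argument values, including the "none" attr, are illustrative
// assumptions):
//   ctx = torch.ops.mkldnn_prepacked.conv2d_prepack(
//       W, B, [1, 1], [1, 1], [1, 1], 1, list(x.shape), "none")
//   y = torch.ops.mkldnn_prepacked.conv2d_run(x, ctx)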
TORCH_LIBRARY(mkldnn_prepacked, m) {
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn_prepacked::conv2d_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] dilation, int groups, int[4] input_size, str attr) -> __torch__.torch.classes.mkldnn.ConvOpContext"));

  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn_prepacked::conv2d_run(Tensor X, __torch__.torch.classes.mkldnn.ConvOpContext W_prepack) -> Tensor Y"));
}
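// Bind the prepacked-conv schemas above to their CPU kernels.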
TORCH_LIBRARY_IMPL(mkldnn_prepacked, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn_prepacked::conv2d_prepack"),
      TORCH_FN(createConvPrePackOpContext));

  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn_prepacked::conv2d_run"), TORCH_FN(conv_run));
}

} // namespace mkldnn
} // namespace native
} // namespace at

#endif // AT_MKLDNN_ENABLED()

#if AT_MKL_ENABLED() && AT_MKLDNN_ENABLED()

namespace at {
namespace native {
namespace mkl {

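// MKL-backed linear: reorder the weight into an MKL-friendly packed layout,
// then run the packed linear op against it. Only registered when both MKL
// and oneDNN are enabled.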
TORCH_LIBRARY(mkl, m) {
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkl::_mkl_reorder_linear_weight(Tensor X, int batch_size) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkl::_mkl_linear(Tensor X, Tensor MKL_W, Tensor ORI_W, Tensor? B, int batch_size) -> Tensor"));
}

} // namespace mkl
} // namespace native
} // namespace at

#endif // AT_MKL_ENABLED && AT_MKLDNN_ENABLED