1 #include <ATen/Config.h>
2
3 #if AT_MKLDNN_ENABLED()
4
5 #include <ATen/Tensor.h>
6 #include <ATen/native/mkldnn/ConvPrepack.h>
7 #include <ATen/native/mkldnn/OpContext.h>
8 #include <ATen/native/mkldnn/Utils.h>
9 #include <torch/custom_class.h>
10 #include <torch/library.h>
11
12 namespace at {
13 namespace native {
14 namespace mkldnn {
15
16 using namespace internal::convolution;
17
is_mkldnn_bf16_supported()18 static bool is_mkldnn_bf16_supported() {
19 #if defined(__aarch64__)
20 return mkldnn_bf16_device_check_arm();
21 #else
22 return mkldnn_bf16_device_check();
23 #endif
24 }
25
is_mkldnn_fp16_supported()26 static bool is_mkldnn_fp16_supported() {
27 return mkldnn_fp16_device_check();
28 }
29
is_mkldnn_acl_supported()30 constexpr bool is_mkldnn_acl_supported() {
31 return AT_MKLDNN_ACL_ENABLED();
32 }
33
// Defines the `mkldnn` operator library. This block:
//   * registers the `ConvOpContext` TorchBind class with pickling hooks so a
//     packed convolution context can be serialized and restored;
//   * declares schemas (no kernels are attached in this block) for the fused
//     pointwise, weight-reorder and tensor-introspection operators;
//   * registers three C++ capability-query functions, whose schemas are
//     inferred from their C++ signatures.
// NOTE(review): kernels for the schema-only ops are presumably registered via
// TORCH_LIBRARY_IMPL in other translation units — confirm.
TORCH_LIBRARY(mkldnn, m) {
  m.class_<ConvOpContext>(TORCH_SELECTIVE_CLASS("ConvOpContext"))
      .def_pickle(
          // Serialize the context as the tuple produced by
          // ConvOpContext::unpack().
          [](const c10::intrusive_ptr<ConvOpContext>& op_context)
              -> SerializationTypeConvPrePack { // __getstate__
            return op_context->unpack();
          },
          // Rebuild a context by moving tuple elements 0..7, in order, into
          // createConvPrePackOpContext.
          [](SerializationTypeConvPrePack state)
              -> c10::intrusive_ptr<ConvOpContext> { // __setstate__
            return createConvPrePackOpContext(
                std::move(std::get<0>(state)),
                std::move(std::get<1>(state)),
                std::move(std::get<2>(state)),
                std::move(std::get<3>(state)),
                std::move(std::get<4>(state)),
                std::move(std::get<5>(state)),
                std::move(std::get<6>(state)),
                std::move(std::get<7>(state)));
          });

  // Linear / convolution ops fused with a pointwise post-op selected by the
  // `attr` string (post-op semantics live in the kernels, not shown here).
  // The `.binary` overloads take an extra `other` tensor.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_linear_pointwise(Tensor X, Tensor W, Tensor? B, str attr, Scalar?[] scalars, str? algorithm) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_linear_pointwise.binary(Tensor X, Tensor other, Tensor W, Tensor? B, str attr) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_convolution_pointwise(Tensor X, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str attr, Scalar?[] scalars, str? algorithm) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_convolution_pointwise.binary(Tensor X, Tensor other, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor Y"));
  // In-place variant: the `Tensor(a!)` alias annotation marks `other` as
  // mutated and returned as the result.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_convolution_pointwise_.binary(Tensor(a!) other, Tensor X, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor(a!) Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_convolution_transpose_pointwise(Tensor X, Tensor W, Tensor? B, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, str attr, Scalar?[] scalars, str? algorithm) -> Tensor Y"));

  // Weight-reorder helpers. Presumably each converts weights into a
  // oneDNN-preferred layout ahead of time — confirm against the kernels.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_reorder_convolution_transpose_weight(Tensor self, int[2] padding=0, int[2] output_padding=0, int[2] stride=1, int[2] dilation=1, int groups=1, int[]? input_size=None) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_reorder_linear_weight(Tensor self, int? batch_size=None) -> Tensor Y"));
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_reorder_convolution_weight(Tensor self, int[2] padding=0, int[2] stride=1, int[2] dilation=1, int groups=1, int[]? input_size=None) -> Tensor Y"));
  // Returns a list of tensors (`Tensor[]`) rather than a single tensor.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn::_reorder_mkldnn_rnn_layer_weight(Tensor weight0, Tensor weight1, int hidden_size, bool reverse, bool has_biases, bool batch_first, int[]? input_size=None) -> Tensor[] Y"));

  // Capability queries backed by the helper functions defined above.
  m.def("_is_mkldnn_bf16_supported", &is_mkldnn_bf16_supported);
  m.def("_is_mkldnn_fp16_supported", &is_mkldnn_fp16_supported);
  m.def("_is_mkldnn_acl_supported", &is_mkldnn_acl_supported);

  // Introspection ops taking an mkldnn tensor.
  // NOTE(review): unlike the schema defs above, these three are not wrapped in
  // TORCH_SELECTIVE_SCHEMA, so selective builds presumably cannot filter them
  // out — confirm this is intentional.
  // NOTE(review): the space before '(' in `_get_mkldnn_serialized_md (` is
  // part of the registered schema string and is left unchanged.
  m.def("mkldnn::data_ptr(Tensor mkldnn_tensor) -> int");
  m.def("mkldnn::_get_mkldnn_serialized_md (Tensor mkldnn_tensor) -> Tensor");
  m.def("mkldnn::_nbytes(Tensor mkldnn_tensor) -> int");
}
81
// Declares the `mkldnn_prepacked` operator library (schemas only; the CPU
// kernels are bound in the TORCH_LIBRARY_IMPL block for this namespace).
TORCH_LIBRARY(mkldnn_prepacked, m) {
  // Builds a `mkldnn.ConvOpContext` TorchBind object from W, an optional B,
  // 2-element stride/padding/dilation, groups, a 4-element input_size and an
  // `attr` string.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn_prepacked::conv2d_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] dilation, int groups, int[4] input_size, str attr) -> __torch__.torch.classes.mkldnn.ConvOpContext"));

  // Applies a previously built ConvOpContext to the input X.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkldnn_prepacked::conv2d_run(Tensor X, __torch__.torch.classes.mkldnn.ConvOpContext W_prepack) -> Tensor Y"));
}
89
// Binds CPU kernels for the `mkldnn_prepacked` ops:
//   conv2d_prepack -> createConvPrePackOpContext
//   conv2d_run     -> conv_run
// (both functions presumably declared in ATen/native/mkldnn/ConvPrepack.h).
TORCH_LIBRARY_IMPL(mkldnn_prepacked, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn_prepacked::conv2d_prepack"),
      TORCH_FN(createConvPrePackOpContext));

  m.impl(
      TORCH_SELECTIVE_NAME("mkldnn_prepacked::conv2d_run"), TORCH_FN(conv_run));
}
98
99 } // namespace mkldnn
100 } // namespace native
101 } // namespace at
102
103 #endif // AT_MKLDNN_ENABLED()
104
105 #if AT_MKL_ENABLED() && AT_MKLDNN_ENABLED()
106
107 namespace at {
108 namespace native {
109 namespace mkl {
110
// Declares schemas for the `mkl` linear operators. This code is compiled only
// when both MKL and MKLDNN are enabled (see the enclosing #if). No kernels are
// attached in this block.
TORCH_LIBRARY(mkl, m) {
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkl::_mkl_reorder_linear_weight(Tensor X, int batch_size) -> Tensor"));
  // NOTE(review): presumably MKL_W is the weight produced by
  // `_mkl_reorder_linear_weight` and ORI_W the original weight — confirm
  // against the kernel implementation.
  m.def(TORCH_SELECTIVE_SCHEMA(
      "mkl::_mkl_linear(Tensor X, Tensor MKL_W, Tensor ORI_W, Tensor? B, int batch_size) -> Tensor"));
}
117
118 } // namespace mkl
119 } // namespace native
120 } // namespace at
121
122 #endif // AT_MKL_ENABLED && AT_MKLDNN_ENABLED
123