xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mkldnn/MKLDNNCommon.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <ATen/Config.h>
5 
6 #if AT_MKLDNN_ENABLED()
7 #include <ideep.hpp>
8 
#ifndef IDEEP_PREREQ
// IDEEP_PREREQ(major, minor, patch, revision) evaluates to nonzero when the
// ideep headers in use satisfy the requested version. Please find the
// definitions of the IDEEP_VERSION_* numbers in ideep.hpp.
//
// The (major, minor, patch) triple is packed as
// (major << 16) + (minor << 8) + patch and compared as a single integer, so
// minor and patch are assumed to stay below 256. The revision is compared on
// its own, and BOTH comparisons must hold.
// NOTE(review): with this scheme a newer major/minor/patch combined with a
// smaller IDEEP_VERSION_REVISION fails the check. Confirm that this is the
// intended meaning of the revision number.
//
// Every macro parameter is parenthesized so that expression arguments (e.g. a
// conditional expression) are not split apart by operator precedence.
#if defined(IDEEP_VERSION_MAJOR) && defined(IDEEP_VERSION_MINOR) && \
    defined(IDEEP_VERSION_PATCH) && defined(IDEEP_VERSION_REVISION)
#define IDEEP_PREREQ(major, minor, patch, revision)                \
  ((((IDEEP_VERSION_MAJOR) << 16) + ((IDEEP_VERSION_MINOR) << 8) + \
    ((IDEEP_VERSION_PATCH) << 0)) >=                               \
       (((major) << 16) + ((minor) << 8) + ((patch) << 0)) &&      \
   ((IDEEP_VERSION_REVISION) >= (revision)))
#else
// ideep does not publish version numbers: every prerequisite check evaluates
// to 0, i.e. it is treated as not satisfied.
#define IDEEP_PREREQ(major, minor, patch, revision) 0
#endif
#endif
22 
23 namespace at { namespace native {
24 
25 // Mapping ScalarType to ideep tensor data_type
26 TORCH_API ideep::tensor::data_type get_mkldnn_dtype(ScalarType type);
get_mkldnn_dtype(const Tensor & t)27 static inline ideep::tensor::data_type get_mkldnn_dtype(const Tensor& t) {
28   return get_mkldnn_dtype(t.scalar_type());
29 }
30 
31 TORCH_API int64_t data_ptr_from_mkldnn(const Tensor& mkldnn_tensor);
32 
33 TORCH_API at::Tensor mkldnn_tensor_from_data_ptr(
34     void* data_ptr,
35     at::IntArrayRef dims,
36     at::ScalarType dtype,
37     at::Device device,
38     const uint8_t* opaque_metadata,
39     int64_t opaque_metadata_size);
40 
41 // Construct aten MKL-DNN tensor given an ideep tensor
42 TORCH_API Tensor new_with_itensor_mkldnn(ideep::tensor&& it, std::optional<ScalarType> dtype, std::optional<Device> device);
43 
44 // Retrieve `ideep::tensor` from MKL-DNN tensor
45 TORCH_API ideep::tensor& itensor_from_mkldnn(const Tensor& mkldnn_tensor);
46 
47 TORCH_API int64_t nbytes_from_mkldnn(const Tensor& mkldnn_tensor);
48 
49 // Construct an `ideep::tensor` "view" from dense tensor, note the
50 // ideep::tensor will share the underlying buffer
51 TORCH_API ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data_ptr=false);
52 
53 // Construct an `ideep::tensor` "view" from dense tensor using given desc, note
54 // the ideep::tensor will share the underlying buffer
55 TORCH_API ideep::tensor itensor_view_from_dense(
56     const at::Tensor& tensor,
57     const ideep::tensor::desc& desc);
58 
59 // Helper function for getting an ideep tensor out of an aten Tensor or MKL-DNN tensor.
60 TORCH_API ideep::tensor itensor_from_tensor(const Tensor& tensor, bool from_const_data_ptr=false);
61 
62 // Set MKLDNN verbose level
63 TORCH_API int set_verbose(int level);
64 
65 }}
66 
67 #endif // AT_MKLDNN_ENABLED
68