// aten/src/ATen/native/mkldnn/TensorFactories.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_native.h>
#endif

namespace at { namespace native {

#if AT_MKLDNN_ENABLED()

Tensor empty_mkldnn(IntArrayRef sizes, std::optional<ScalarType> dtype, std::optional<Layout> layout, std::optional<Device> device, std::optional<bool> pin_memory, std::optional<c10::MemoryFormat> optional_memory_format) {
  TORCH_CHECK(
     !optional_memory_format.has_value(),
     "'memory_format' argument is incompatible with mkldnn tensor");
  // NOTE: int32_t dims from ideep::tensor but sizes needs int64_t
  // TODO: support int64_t dims in ideep::tensor to avoid extra conversion
  ideep::tensor::dims dst_dims (sizes.begin(), sizes.end());
  auto data_type = dtype.has_value() ? get_mkldnn_dtype(dtype.value()) : ideep::tensor::data_type::f32;
  ideep::tensor it {dst_dims, data_type};
  return new_with_itensor_mkldnn(std::move(it), dtype, device);
}
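
// A minimal usage sketch (assuming the standard dispatcher wiring for the
// mkldnn layout): calling at::empty with Layout::Mkldnn routes to this
// function, and the resulting opaque tensor can be converted back to a
// strided CPU tensor with Tensor::to_dense().
//
//   at::Tensor t = at::empty({2, 3},
//       at::TensorOptions().layout(at::kMkldnn).dtype(at::kFloat));
//   at::Tensor dense = t.to_dense();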

#else

Tensor empty_mkldnn(IntArrayRef sizes, std::optional<ScalarType> dtype, std::optional<Layout> layout, std::optional<Device> device, std::optional<bool> pin_memory, std::optional<c10::MemoryFormat> optional_memory_format) {
  TORCH_CHECK(false, "empty_mkldnn: MKL-DNN build is disabled");
}

#endif // AT_MKLDNN_ENABLED()

}} // namespace at::native