xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/MetaTensor.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/EmptyTensor.h>
3 #include <ATen/core/Tensor.h>
4 
5 #ifndef AT_PER_OPERATOR_HEADERS
6 #include <ATen/NativeFunctions.h>
7 #else
8 #include <ATen/ops/empty_native.h>
9 #include <ATen/ops/empty_strided_native.h>
10 #endif
11 
12 namespace at::native {
13 
empty_meta_symint(SymIntArrayRef size,std::optional<ScalarType> dtype_opt,std::optional<Layout> layout_opt,std::optional<Device> device_opt,std::optional<bool> pin_memory_opt,std::optional<c10::MemoryFormat> memory_format_opt)14 Tensor empty_meta_symint(
15   SymIntArrayRef size,
16   std::optional<ScalarType> dtype_opt,
17   std::optional<Layout> layout_opt,
18   std::optional<Device> device_opt,
19   std::optional<bool> pin_memory_opt,
20   std::optional<c10::MemoryFormat> memory_format_opt
21 ) {
22 
23   auto opt_size = asIntArrayRefSlowOpt(size);
24   if (opt_size.has_value()) {
25     return at::detail::empty_meta(*opt_size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
26   }
27   return at::detail::empty_symint_meta(
28       size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
29 }
30 
empty_strided_meta_symint(SymIntArrayRef size,SymIntArrayRef stride,std::optional<ScalarType> dtype_opt,std::optional<Layout> layout_opt,std::optional<Device> device_opt,std::optional<bool> pin_memory_opt)31 Tensor empty_strided_meta_symint(
32   SymIntArrayRef size,
33   SymIntArrayRef stride,
34   std::optional<ScalarType> dtype_opt,
35   std::optional<Layout> layout_opt,
36   std::optional<Device> device_opt,
37   std::optional<bool> pin_memory_opt
38 ) {
39   return at::detail::empty_strided_symint_meta(
40       size, stride, dtype_opt, layout_opt, device_opt);
41 }
42 
43 } // namespace at::native
44