xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/CheckMemoryFormat.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <optional>

#include <c10/core/TensorOptions.h>
3 namespace c10::impl {
4 
5 inline std::optional<MemoryFormat>
check_tensor_options_and_extract_memory_format(const TensorOptions & options,std::optional<MemoryFormat> memory_format)6 check_tensor_options_and_extract_memory_format(
7     const TensorOptions& options,
8     std::optional<MemoryFormat> memory_format) {
9   TORCH_CHECK(
10       options.requires_grad_opt() == std::nullopt ||
11       options.requires_grad_opt().value() == false,
12       "Operators taking TensorOptions cannot take a TensorOptions with "
13       "options.requires_grad set as true. This isn't implemented yet.");
14   TORCH_CHECK(
15       !(options.has_memory_format() && memory_format.has_value()),
16       "Cannot set memory_format both in TensorOptions and explicit argument; please delete "
17       "the redundant setter.");
18   if (memory_format.has_value()) {
19     return memory_format;
20   } else {
21     return options.memory_format_opt();
22   }
23 }
24 
25 } // namespace impl namespace c10
26