xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/variable_info.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/autograd/variable.h>
4 
5 namespace torch::autograd {
6 
7 struct TORCH_API VariableInfo {
8   explicit VariableInfo();
9   explicit VariableInfo(const Variable& var);
10 
11   Variable zeros(at::OptionalDeviceGuard& device_guard) const;
12 
13   at::Layout layout = at::Layout::Strided;
14   at::Device device = at::kCPU;
15   at::ScalarType scalar_type = at::kFloat;
16   std::vector<c10::SymInt> size;
17   bool requires_grad;
18   bool is_empty;
19 };
20 
21 } // namespace torch::autograd
22