xref: /aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_eager/kernel_meta_info.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #if !defined(C10_MOBILE) && !defined(ANDROID)
2 #pragma once
3 
#include <ATen/ATen.h>
#include <c10/core/SymIntArrayRef.h>
#include <torch/csrc/dynamo/guards.h>

#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>
9 
10 namespace torch::inductor {
11 
// For an ATen operation implemented by AOTI, the metadata of the input tensors
// is cached on disk to accelerate subsequent runs. The TensorMetadata structure
// represents the metadata of each input tensor: whether the tensor is
// symbolic, its dtype, device, dispatch key set, sizes, strides, and whether it
// requires grad. When the metadata of the input tensors matches the cached
// metadata, the cached kernel library is loaded and executed. Otherwise, AOT
// Inductor is invoked again to generate the kernel library.
// Beyond TensorMetadata, we also build a guard (TensorCheck) for each input
// tensor to support symbolic shapes. The intent is to use TensorCheck, rather
// than TensorMetadata comparison, to find the proper kernel. Suppose an
// operation has a single input tensor and two kernels:
//   kernel1: TensorMetadata(is_symbolic=false, dtype=Float, device=CPU,
//                           sizes=[s0, s1, s2], strides=[s1 * s2, s2, 1])
//   kernel2: TensorMetadata(is_symbolic=false, dtype=Float, device=CPU,
//                           sizes=[3, s1, s2], strides=[s1 * s2, s2, 1])
// If a tensor with sizes=[3, 4, 5] is passed to the operation, both kernel1 and
// kernel2 support its shape. In this case, TensorCheck plus some heuristic
// rules are needed to pick the proper kernel.
31 struct TensorMetadata {
32   // Indicate whether the tensor is symbolic and it may be concluded by sizes_
33   // and strides_ in the future.
34   bool is_symbolic_;
35   // Dtype of a tensor(For scalar, we will wrap it as a scalar tensor)
36   c10::ScalarType dtype_ = c10::ScalarType::Undefined;
37   // Device of a tensor.
38   c10::Device device_;
39   // Dispatch key set of a tensor
40   c10::DispatchKeySet dispatch_key_set_;
41   // Sizes of a tensor. Currently, we only support static shape and use int64_t
42   // to represent the sizes. In the future, we will create symbolic size and use
43   // SymInt to represent it to support symbolic shape.
44   std::vector<int64_t> sizes_;
45   // Strides of a tensor. For symbolic shape support, it is the same as sizes_
46   std::vector<int64_t> strides_;
47   // requires grad
48   bool requires_grad_ = false;
49   // TensorCheck for the tensor
50   std::optional<dynamo::TensorCheck> tensor_check_;
51 
TensorMetadataTensorMetadata52   TensorMetadata()
53       : is_symbolic_(false),
54         device_(c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES),
55         sizes_({}),
56         strides_({}) {}
57   TensorMetadata(const at::Tensor& src_tensor);
58   TensorMetadata(
59       bool is_symbolic,
60       c10::ScalarType dtype,
61       c10::Device device,
62       c10::DispatchKeySet dispatch_key_set,
63       std::vector<int64_t> sizes,
64       std::vector<int64_t> strides,
65       bool requires_grad = false);
66 
67   // Build TensorCheck for the tensor by using the data fields in TensorMetadata
68   void build_guard(const dynamo::LocalState& local_state);
69 
70   // Compare two TensorMetadata objects
71   bool operator==(const TensorMetadata& other) const;
72 };
73 
// ParameterTag identifies the kind of an input parameter of an ATen operation.
// The following kinds are currently defined:
//   1. TENSOR: a single tensor
//   2. TENSOR_OPTIONAL: a single optional tensor
//   3. TENSOR_LIST: a list of tensors
//   4. TENSOR_LIST_OPTIONAL: a list of optional tensors
//   5. SCALAR: a scalar value
//   6. STRING: a string value
//   7. DEVICE: a device
// INVALID is a sentinel; a default-constructed ParameterMetadata carries it.
// To support more input parameter types of ATen operations (e.g. Dimname),
// extend this enum.
enum ParameterTag {
  TENSOR,
  TENSOR_OPTIONAL,
  TENSOR_LIST,
  TENSOR_LIST_OPTIONAL,
  SCALAR,
  STRING,
  DEVICE,
  INVALID,
};
95 
96 // ParameterMetadataValue is to represent the value of the input parameters of a
97 // aten operation.
98 using ParameterMetadataValue = std::variant<
99     TensorMetadata,
100     std::vector<TensorMetadata>,
101     c10::Scalar,
102     std::string,
103     c10::Device>;
104 
105 // ParameterMetadata is to represent the metadata of the input parameters of a
106 // aten operation. It includes the tag of the parameter, the value of the
107 // parameter and the order of the parameter.
108 struct ParameterMetadata {
109   // The tag of the parameter. It indicates the type of the parameter.
110   ParameterTag tag_;
111   // The value of the parameter. It can be a tensor, a list of tensors or a
112   // scalar.
113   ParameterMetadataValue value_;
114   // The order of the parameter is used to distinguish the parameters with the
115   // same tag. For example, an operation with two input tensors, the first
116   // tensor is a optional tensor and the second tensor is a tensor. The first
117   // tensor will have the order 0 and the second tensor will have the order 1.
118   uint64_t order_{};
119 
ParameterMetadataParameterMetadata120   ParameterMetadata() : tag_(INVALID) {}
121   ParameterMetadata(TensorMetadata tensor_metadata, uint64_t input_order);
122   ParameterMetadata(const at::Tensor& tensor, uint64_t input_order);
123   ParameterMetadata(
124       const std::vector<at::Tensor>& tensor_list,
125       uint64_t input_order);
126   ParameterMetadata(
127       const std::vector<TensorMetadata>& tensor_metadata_list,
128       uint64_t input_order);
129   ParameterMetadata(const c10::Scalar& scalar, uint64_t input_order);
130   ParameterMetadata(const std::string& string_value, uint64_t input_order);
131   ParameterMetadata(const c10::Device& device, uint64_t input_order);
132 
133   bool operator==(const ParameterMetadata& other) const;
134 
135  private:
136   // Helper function to compare two ParameterMetadata objects with the same
137   // SCALAR tag.
138   bool equal_to(const c10::Scalar& scalar) const;
139 };
140 
141 } // namespace torch::inductor
142 #endif
143