1 #if !defined(C10_MOBILE) && !defined(ANDROID) 2 #pragma once 3 4 #include <ATen/ATen.h> 5 #include <c10/core/SymIntArrayRef.h> 6 #include <torch/csrc/dynamo/guards.h> 7 8 #include <string> 9 10 namespace torch::inductor { 11 12 // Regarding a aten operation implemented by AOTI, the metadata of the input 13 // tensors will be cached on the disk to accelerate next run. TensorMetada 14 // structure is to represent the metadata of each input tensor. It includes 15 // whether the tensor is symbolic, the dtype, the device, the sizes and the 16 // strides of the tensor. When the metadata of the input tensors is the same as 17 // the cached metadata, the cached kernel library will be loaded and executed. 18 // Otherwise, the AOT Inductor will be called again to generate the kernel 19 // library. 20 // Beyond the TensorMetadata, we build guard/TensorCheck for each input tensor 21 // as well to support symbolic shape. We intend to utilize TensorCheck to find 22 // out the proper kernel rather than TensorMetada comparison. Suppose an 23 // operation with a single input tensor and two kernels: 24 // kernel1: TensorMetadata(is_symbolic=false, dtype=Float, device=CPU, 25 // sizes=[s0, s1, s2], strides=[s1 * s2, s2, 1]) kernel2: 26 // TensorMetadata(is_symbolic=false, dtype=Float, device=CPU, sizes=[3, s1, 27 // s2], strides=[s1 * s2, s2, 1]) 28 // If a tensor with sizes=[3, 4, 5] is passed to the operation, both kernel1 and 29 // kernel2 support the tensor shape. In this case, we need to use TensorCheck 30 // plus some heruistic rules to find out the proper kernel. 31 struct TensorMetadata { 32 // Indicate whether the tensor is symbolic and it may be concluded by sizes_ 33 // and strides_ in the future. 34 bool is_symbolic_; 35 // Dtype of a tensor(For scalar, we will wrap it as a scalar tensor) 36 c10::ScalarType dtype_ = c10::ScalarType::Undefined; 37 // Device of a tensor. 38 c10::Device device_; 39 // Dispatch key set of a tensor 40 c10::DispatchKeySet dispatch_key_set_; 41 // Sizes of a tensor. Currently, we only support static shape and use int64_t 42 // to represent the sizes. In the future, we will create symbolic size and use 43 // SymInt to represent it to support symbolic shape. 44 std::vector<int64_t> sizes_; 45 // Strides of a tensor. For symbolic shape support, it is the same as sizes_ 46 std::vector<int64_t> strides_; 47 // requires grad 48 bool requires_grad_ = false; 49 // TensorCheck for the tensor 50 std::optional<dynamo::TensorCheck> tensor_check_; 51 TensorMetadataTensorMetadata52 TensorMetadata() 53 : is_symbolic_(false), 54 device_(c10::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES), 55 sizes_({}), 56 strides_({}) {} 57 TensorMetadata(const at::Tensor& src_tensor); 58 TensorMetadata( 59 bool is_symbolic, 60 c10::ScalarType dtype, 61 c10::Device device, 62 c10::DispatchKeySet dispatch_key_set, 63 std::vector<int64_t> sizes, 64 std::vector<int64_t> strides, 65 bool requires_grad = false); 66 67 // Build TensorCheck for the tensor by using the data fields in TensorMetadata 68 void build_guard(const dynamo::LocalState& local_state); 69 70 // Compare two TensorMetadata objects 71 bool operator==(const TensorMetadata& other) const; 72 }; 73 74 // ParameterTag is to represent the type of the input parameters of a aten 75 // operation. Currently, we support the following types: 76 // 1. TENSOR: a single tensor 77 // 2. TENSOR_OPTIONAL: a single optional tensor 78 // 3. TENSOR_LIST: a list of tensors 79 // 4. TENSOR_LIST_OPTIONAL: a list of optional tensors 80 // 5. SCALAR: a scalar value 81 // If we need to support more types in the future, we will add more types in the 82 // ParameterTag enum. For example, we will extend the enum to support string, 83 // Dimname and so on to support more types of input parameters of aten 84 // operations. 85 enum ParameterTag { 86 TENSOR, 87 TENSOR_OPTIONAL, 88 TENSOR_LIST, 89 TENSOR_LIST_OPTIONAL, 90 SCALAR, 91 STRING, 92 DEVICE, 93 INVALID, 94 }; 95 96 // ParameterMetadataValue is to represent the value of the input parameters of a 97 // aten operation. 98 using ParameterMetadataValue = std::variant< 99 TensorMetadata, 100 std::vector<TensorMetadata>, 101 c10::Scalar, 102 std::string, 103 c10::Device>; 104 105 // ParameterMetadata is to represent the metadata of the input parameters of a 106 // aten operation. It includes the tag of the parameter, the value of the 107 // parameter and the order of the parameter. 108 struct ParameterMetadata { 109 // The tag of the parameter. It indicates the type of the parameter. 110 ParameterTag tag_; 111 // The value of the parameter. It can be a tensor, a list of tensors or a 112 // scalar. 113 ParameterMetadataValue value_; 114 // The order of the parameter is used to distinguish the parameters with the 115 // same tag. For example, an operation with two input tensors, the first 116 // tensor is a optional tensor and the second tensor is a tensor. The first 117 // tensor will have the order 0 and the second tensor will have the order 1. 118 uint64_t order_{}; 119 ParameterMetadataParameterMetadata120 ParameterMetadata() : tag_(INVALID) {} 121 ParameterMetadata(TensorMetadata tensor_metadata, uint64_t input_order); 122 ParameterMetadata(const at::Tensor& tensor, uint64_t input_order); 123 ParameterMetadata( 124 const std::vector<at::Tensor>& tensor_list, 125 uint64_t input_order); 126 ParameterMetadata( 127 const std::vector<TensorMetadata>& tensor_metadata_list, 128 uint64_t input_order); 129 ParameterMetadata(const c10::Scalar& scalar, uint64_t input_order); 130 ParameterMetadata(const std::string& string_value, uint64_t input_order); 131 ParameterMetadata(const c10::Device& device, uint64_t input_order); 132 133 bool operator==(const ParameterMetadata& other) const; 134 135 private: 136 // Helper function to compare two ParameterMetadata objects with the same 137 // SCALAR tag. 138 bool equal_to(const c10::Scalar& scalar) const; 139 }; 140 141 } // namespace torch::inductor 142 #endif 143