xref: /aosp_15_r20/external/pytorch/aten/src/ATen/dlpack.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker /*!
2*da0073e9SAndroid Build Coastguard Worker  *  Copyright (c) 2017 by Contributors
3*da0073e9SAndroid Build Coastguard Worker  * \file dlpack.h
4*da0073e9SAndroid Build Coastguard Worker  * \brief The common header of DLPack.
5*da0073e9SAndroid Build Coastguard Worker  */
6*da0073e9SAndroid Build Coastguard Worker #ifndef DLPACK_DLPACK_H_
7*da0073e9SAndroid Build Coastguard Worker #define DLPACK_DLPACK_H_
8*da0073e9SAndroid Build Coastguard Worker 
/**
 * \brief Compatibility with C++.
 *
 * Expands to `extern "C"` when compiled as C++ (so declarations get C
 * linkage) and to nothing when compiled as plain C.
 */
#ifdef __cplusplus
#define DLPACK_EXTERN_C extern "C"
#else
#define DLPACK_EXTERN_C
#endif
17*da0073e9SAndroid Build Coastguard Worker 
/*!
 * \brief The current version of dlpack.
 *
 * A plain integer literal, so it can be compared in `#if` directives.
 */
#define DLPACK_VERSION 80

/*! \brief The current ABI version of dlpack (plain integer literal). */
#define DLPACK_ABI_VERSION 1
23*da0073e9SAndroid Build Coastguard Worker 
/*!
 * \brief DLPACK_DLL prefix for Windows.
 *
 * On _WIN32 this expands to __declspec(dllexport) when DLPACK_EXPORTS is
 * defined (building the DLL) and to __declspec(dllimport) otherwise
 * (consuming the DLL). On every other platform it expands to nothing.
 */
#ifdef _WIN32
#ifdef DLPACK_EXPORTS
#define DLPACK_DLL __declspec(dllexport)
#else
#define DLPACK_DLL __declspec(dllimport)
#endif
#else
#define DLPACK_DLL
#endif
34*da0073e9SAndroid Build Coastguard Worker 
35*da0073e9SAndroid Build Coastguard Worker #include <stdint.h>
36*da0073e9SAndroid Build Coastguard Worker #include <stddef.h>
37*da0073e9SAndroid Build Coastguard Worker 
38*da0073e9SAndroid Build Coastguard Worker #ifdef __cplusplus
39*da0073e9SAndroid Build Coastguard Worker extern "C" {
40*da0073e9SAndroid Build Coastguard Worker #endif
/*!
 * \brief The device type in DLDevice.
 *
 * The numeric values are exchanged between frameworks, so existing codes must
 * never be renumbered. Values 5 and 6 are not assigned in this list.
 * In C++ the underlying type is pinned to int32_t; in C the underlying integer
 * type of the enum is chosen by the implementation.
 */
#ifdef __cplusplus
typedef enum : int32_t {
#else
typedef enum {
#endif
  /*! \brief CPU device */
  kDLCPU = 1,
  /*! \brief CUDA GPU device */
  kDLCUDA = 2,
  /*!
   * \brief Pinned CUDA CPU memory allocated by cudaMallocHost
   */
  kDLCUDAHost = 3,
  /*! \brief OpenCL devices. */
  kDLOpenCL = 4,
  /*! \brief Vulkan buffer for next generation graphics. */
  kDLVulkan = 7,
  /*! \brief Metal for Apple GPU. */
  kDLMetal = 8,
  /*! \brief Verilog simulator buffer */
  kDLVPI = 9,
  /*! \brief ROCm GPUs for AMD GPUs */
  kDLROCM = 10,
  /*!
   * \brief Pinned ROCm CPU memory allocated by hipMallocHost
   */
  kDLROCMHost = 11,
  /*!
   * \brief Reserved extension device type, used for quickly testing
   * extension devices. The semantics can differ depending on the
   * implementation.
   */
  kDLExtDev = 12,
  /*!
   * \brief CUDA managed/unified memory allocated by cudaMallocManaged
   */
  kDLCUDAManaged = 13,
  /*!
   * \brief Unified shared memory allocated on a oneAPI non-partitioned
   * device. A call to the oneAPI runtime is required to determine the device
   * type, the USM allocation type and the SYCL context it is bound to.
   */
  kDLOneAPI = 14,
  /*! \brief GPU support for next generation WebGPU standard. */
  kDLWebGPU = 15,
  /*! \brief Qualcomm Hexagon DSP */
  kDLHexagon = 16,
  /*! \brief Microsoft AI Accelerator */
  kDLMAIA = 17,
} DLDeviceType;
95*da0073e9SAndroid Build Coastguard Worker 
96*da0073e9SAndroid Build Coastguard Worker /*!
97*da0073e9SAndroid Build Coastguard Worker  * \brief A Device for Tensor and operator.
98*da0073e9SAndroid Build Coastguard Worker  */
99*da0073e9SAndroid Build Coastguard Worker typedef struct {
100*da0073e9SAndroid Build Coastguard Worker   /*! \brief The device type used in the device. */
101*da0073e9SAndroid Build Coastguard Worker   DLDeviceType device_type;
102*da0073e9SAndroid Build Coastguard Worker   /*!
103*da0073e9SAndroid Build Coastguard Worker    * \brief The device index.
104*da0073e9SAndroid Build Coastguard Worker    * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0.
105*da0073e9SAndroid Build Coastguard Worker    */
106*da0073e9SAndroid Build Coastguard Worker   int32_t device_id;
107*da0073e9SAndroid Build Coastguard Worker } DLDevice;
108*da0073e9SAndroid Build Coastguard Worker 
/*!
 * \brief The type code options for DLDataType.
 *
 * Stored in DLDataType::code as a uint8_t, so every value must fit in 8 bits.
 */
typedef enum {
  /*! \brief signed integer */
  kDLInt = 0U,
  /*! \brief unsigned integer */
  kDLUInt = 1U,
  /*! \brief IEEE floating point */
  kDLFloat = 2U,
  /*!
   * \brief Opaque handle type, reserved for testing purposes.
   * Frameworks need to agree on the handle data type for the exchange to be well-defined.
   */
  kDLOpaqueHandle = 3U,
  /*! \brief bfloat16 */
  kDLBfloat = 4U,
  /*!
   * \brief complex number
   * (C/C++/Python layout: compact struct per complex number)
   */
  kDLComplex = 5U,
  /*! \brief boolean */
  kDLBool = 6U,
} DLDataTypeCode;
134*da0073e9SAndroid Build Coastguard Worker 
/*!
 * \brief The data type the tensor can hold. The data type is assumed to follow the
 * native endian-ness. An explicit error message should be raised when attempting to
 * export an array with non-native endianness.
 *
 *  Examples
 *   - float: type_code = 2, bits = 32, lanes = 1
 *   - float4(vectorized 4 float): type_code = 2, bits = 32, lanes = 4
 *   - int8: type_code = 0, bits = 8, lanes = 1
 *   - std::complex<float>: type_code = 5, bits = 64, lanes = 1
 *   - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of bool is 8 bits)
 */
typedef struct {
  /*!
   * \brief Type code of base types.
   * We keep it uint8_t instead of DLDataTypeCode for minimal memory
   * footprint, but the value should be one of DLDataTypeCode enum values.
   */
  uint8_t code;
  /*!
   * \brief Number of bits per lane, common choices are 8, 16, 32.
   */
  uint8_t bits;
  /*! \brief Number of lanes in the type, used for vector types. */
  uint16_t lanes;
} DLDataType;
161*da0073e9SAndroid Build Coastguard Worker 
162*da0073e9SAndroid Build Coastguard Worker /*!
163*da0073e9SAndroid Build Coastguard Worker  * \brief Plain C Tensor object, does not manage memory.
164*da0073e9SAndroid Build Coastguard Worker  */
165*da0073e9SAndroid Build Coastguard Worker typedef struct {
166*da0073e9SAndroid Build Coastguard Worker   /*!
167*da0073e9SAndroid Build Coastguard Worker    * \brief The data pointer points to the allocated data. This will be CUDA
168*da0073e9SAndroid Build Coastguard Worker    * device pointer or cl_mem handle in OpenCL. It may be opaque on some device
169*da0073e9SAndroid Build Coastguard Worker    * types. This pointer is always aligned to 256 bytes as in CUDA. The
170*da0073e9SAndroid Build Coastguard Worker    * `byte_offset` field should be used to point to the beginning of the data.
171*da0073e9SAndroid Build Coastguard Worker    *
172*da0073e9SAndroid Build Coastguard Worker    * Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow,
173*da0073e9SAndroid Build Coastguard Worker    * TVM, perhaps others) do not adhere to this 256 byte aligment requirement
174*da0073e9SAndroid Build Coastguard Worker    * on CPU/CUDA/ROCm, and always use `byte_offset=0`.  This must be fixed
175*da0073e9SAndroid Build Coastguard Worker    * (after which this note will be updated); at the moment it is recommended
176*da0073e9SAndroid Build Coastguard Worker    * to not rely on the data pointer being correctly aligned.
177*da0073e9SAndroid Build Coastguard Worker    *
178*da0073e9SAndroid Build Coastguard Worker    * For given DLTensor, the size of memory required to store the contents of
179*da0073e9SAndroid Build Coastguard Worker    * data is calculated as follows:
180*da0073e9SAndroid Build Coastguard Worker    *
181*da0073e9SAndroid Build Coastguard Worker    * \code{.c}
182*da0073e9SAndroid Build Coastguard Worker    * static inline size_t GetDataSize(const DLTensor* t) {
183*da0073e9SAndroid Build Coastguard Worker    *   size_t size = 1;
184*da0073e9SAndroid Build Coastguard Worker    *   for (tvm_index_t i = 0; i < t->ndim; ++i) {
185*da0073e9SAndroid Build Coastguard Worker    *     size *= t->shape[i];
186*da0073e9SAndroid Build Coastguard Worker    *   }
187*da0073e9SAndroid Build Coastguard Worker    *   size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;
188*da0073e9SAndroid Build Coastguard Worker    *   return size;
189*da0073e9SAndroid Build Coastguard Worker    * }
190*da0073e9SAndroid Build Coastguard Worker    * \endcode
191*da0073e9SAndroid Build Coastguard Worker    */
192*da0073e9SAndroid Build Coastguard Worker   void* data;
193*da0073e9SAndroid Build Coastguard Worker   /*! \brief The device of the tensor */
194*da0073e9SAndroid Build Coastguard Worker   DLDevice device;
195*da0073e9SAndroid Build Coastguard Worker   /*! \brief Number of dimensions */
196*da0073e9SAndroid Build Coastguard Worker   int32_t ndim;
197*da0073e9SAndroid Build Coastguard Worker   /*! \brief The data type of the pointer*/
198*da0073e9SAndroid Build Coastguard Worker   DLDataType dtype;
199*da0073e9SAndroid Build Coastguard Worker   /*! \brief The shape of the tensor */
200*da0073e9SAndroid Build Coastguard Worker   const int64_t* shape;
201*da0073e9SAndroid Build Coastguard Worker   /*!
202*da0073e9SAndroid Build Coastguard Worker    * \brief strides of the tensor (in number of elements, not bytes)
203*da0073e9SAndroid Build Coastguard Worker    *  can be NULL, indicating tensor is compact and row-majored.
204*da0073e9SAndroid Build Coastguard Worker    */
205*da0073e9SAndroid Build Coastguard Worker   const int64_t* strides;
206*da0073e9SAndroid Build Coastguard Worker   /*! \brief The offset in bytes to the beginning pointer to data */
207*da0073e9SAndroid Build Coastguard Worker   uint64_t byte_offset;
208*da0073e9SAndroid Build Coastguard Worker } DLTensor;
209*da0073e9SAndroid Build Coastguard Worker 
210*da0073e9SAndroid Build Coastguard Worker /*!
211*da0073e9SAndroid Build Coastguard Worker  * \brief C Tensor object, manage memory of DLTensor. This data structure is
212*da0073e9SAndroid Build Coastguard Worker  *  intended to facilitate the borrowing of DLTensor by another framework. It is
213*da0073e9SAndroid Build Coastguard Worker  *  not meant to transfer the tensor. When the borrowing framework doesn't need
214*da0073e9SAndroid Build Coastguard Worker  *  the tensor, it should call the deleter to notify the host that the resource
215*da0073e9SAndroid Build Coastguard Worker  *  is no longer needed.
216*da0073e9SAndroid Build Coastguard Worker  */
217*da0073e9SAndroid Build Coastguard Worker typedef struct DLManagedTensor {
218*da0073e9SAndroid Build Coastguard Worker   /*! \brief DLTensor which is being memory managed */
219*da0073e9SAndroid Build Coastguard Worker   DLTensor dl_tensor;
220*da0073e9SAndroid Build Coastguard Worker   /*! \brief the context of the original host framework of DLManagedTensor in
221*da0073e9SAndroid Build Coastguard Worker    *   which DLManagedTensor is used in the framework. It can also be NULL.
222*da0073e9SAndroid Build Coastguard Worker    */
223*da0073e9SAndroid Build Coastguard Worker   void * manager_ctx;
224*da0073e9SAndroid Build Coastguard Worker   /*! \brief Destructor signature void (*)(void*) - this should be called
225*da0073e9SAndroid Build Coastguard Worker    *   to destruct manager_ctx which holds the DLManagedTensor. It can be NULL
226*da0073e9SAndroid Build Coastguard Worker    *   if there is no way for the caller to provide a reasonable destructor.
227*da0073e9SAndroid Build Coastguard Worker    *   The destructors deletes the argument self as well.
228*da0073e9SAndroid Build Coastguard Worker    */
229*da0073e9SAndroid Build Coastguard Worker   void (*deleter)(struct DLManagedTensor * self);
230*da0073e9SAndroid Build Coastguard Worker } DLManagedTensor;
231*da0073e9SAndroid Build Coastguard Worker #ifdef __cplusplus
232*da0073e9SAndroid Build Coastguard Worker }  // DLPACK_EXTERN_C
233*da0073e9SAndroid Build Coastguard Worker #endif
234*da0073e9SAndroid Build Coastguard Worker #endif  // DLPACK_DLPACK_H_
235