#pragma once

#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
#include <c10/util/Load.h>
#include <c10/util/TypeCast.h>

namespace c10 {

// Dynamic type casting utils:
// - fetch_and_cast
// - cast_and_store
//
// fetch_and_cast fetches a value whose dynamic type is specified by a
// ScalarType from a void pointer and casts it to a static type.
//
// cast_and_store casts a statically typed value into the dynamic type
// specified by a ScalarType, and stores it through a void pointer.
//
// NOTE:
//
// Dynamic casting allows us to support type promotion without blowing up
// the combination space: for example, without dynamic cast, in order to
// implement `add_` with type promotion, we would need something like
//
// AT_DISPATCH_ALL_TYPES(output.dtype(),
//    AT_DISPATCH_ALL_TYPES(input1.dtype(),
//       AT_DISPATCH_ALL_TYPES(input2.dtype(),
//           [](arg0_t a, arg1_t b) -> out_t { return a + b; }
//       )
//    )
// )
//
// If we support N dtypes, the above code would generate the a+b kernel for
// all N * N * N combinations of supported types, and the compilation time
// and binary size would become horrible.
//
// Dynamic casting might sound like a bad idea for performance, especially
// when it happens inside a loop, where the type check runs on every
// iteration. But in practice it is not as bad as it might look:
//
// - on CPU, this is a branch that always has the same outcome, so the
//   branch predictor can do its job pretty well
// - on GPU, these branches do not diverge, so the whole warp still executes
//   the same line of code
// - most kernels, like `add`, are bandwidth bound, so the few clock cycles
//   spent checking an integer barely hurt performance: the ALUs would be
//   waiting on load instructions anyway
//
// For the discussion and benchmarks, refer to:
// - https://github.com/pytorch/pytorch/pull/28343
// - https://github.com/pytorch/pytorch/pull/28344
// - https://github.com/pytorch/pytorch/pull/28345
//

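// As an illustration (a minimal sketch, not the actual TensorIterator loop),
// a type-promoting elementwise kernel could use the two helpers like this,
// computing in a fixed static type (float here) while each operand keeps its
// own dynamic dtype; the *_dtype, *_ptr, and *_stride names below are
// hypothetical stand-ins for whatever the calling kernel provides:
//
//   for (int64_t i = 0; i < n; i++) {
//     float a = fetch_and_cast<float>(in1_dtype, in1_ptr + i * in1_stride);
//     float b = fetch_and_cast<float>(in2_dtype, in2_ptr + i * in2_stride);
//     cast_and_store<float>(out_dtype, out_ptr + i * out_stride, a + b);
//   }
//
// (the pointers would be char* so that the byte strides apply correctly).
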
#ifdef C10_HOST_DEVICE
#define ERROR_UNSUPPORTED_CAST CUDA_KERNEL_ASSERT(false);
#else
#define ERROR_UNSUPPORTED_CAST TORCH_CHECK(false, "Unexpected scalar type");
#endif

// Fetch a value with dynamic type src_type from ptr, and cast it to static
// type dest_t.
#define FETCH_AND_CAST_CASE(type, scalartype) \
  case ScalarType::scalartype:                \
    return c10::convert<dest_t>(c10::load<type>(ptr));

template <typename dest_t>
C10_HOST_DEVICE inline dest_t fetch_and_cast(
    const ScalarType src_type,
    const void* ptr) {
  switch (src_type) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(FETCH_AND_CAST_CASE)
    FETCH_AND_CAST_CASE(uint16_t, UInt16)
    FETCH_AND_CAST_CASE(uint32_t, UInt32)
    FETCH_AND_CAST_CASE(uint64_t, UInt64)
    default:
      ERROR_UNSUPPORTED_CAST
  }
  return dest_t(0); // just to avoid compiler warning
}

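// For illustration (hypothetical values): fetch_and_cast<float>(
// ScalarType::Long, p) loads an int64_t through p via c10::load and returns
// it converted to float; an unsupported src_type hits ERROR_UNSUPPORTED_CAST.
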
// Cast a value with static type src_t into dynamic dest_type, and store it
// to ptr.
#define CAST_AND_STORE_CASE(type, scalartype) \
  case ScalarType::scalartype:                \
    *(type*)ptr = c10::convert<type>(value);  \
    return;
template <typename src_t>
C10_HOST_DEVICE inline void cast_and_store(
    const ScalarType dest_type,
    void* ptr,
    src_t value) {
  switch (dest_type) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CAST_AND_STORE_CASE)
    CAST_AND_STORE_CASE(uint16_t, UInt16)
    CAST_AND_STORE_CASE(uint32_t, UInt32)
    CAST_AND_STORE_CASE(uint64_t, UInt64)
    default:;
  }
  ERROR_UNSUPPORTED_CAST
}
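
// For illustration (hypothetical values): cast_and_store<float>(
// ScalarType::Int, p, 3.5f) converts 3.5f to int32_t (3) and writes it
// through p; an unsupported dest_type hits ERROR_UNSUPPORTED_CAST.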

// Quantized dtypes cannot be converted to or from other dtypes: for each of
// them, fetch_and_cast / cast_and_store are specialized to assert that the
// dynamic type matches the static type and then load / store the value
// directly.
#define DEFINE_UNCASTABLE(T, scalartype_)                     \
  template <>                                                 \
  C10_HOST_DEVICE inline T fetch_and_cast<T>(                 \
      const ScalarType src_type, const void* ptr) {           \
    CUDA_KERNEL_ASSERT(ScalarType::scalartype_ == src_type);  \
    return c10::load<T>(ptr);                                 \
  }                                                           \
  template <>                                                 \
  C10_HOST_DEVICE inline void cast_and_store<T>(              \
      const ScalarType dest_type, void* ptr, T value) {       \
    CUDA_KERNEL_ASSERT(ScalarType::scalartype_ == dest_type); \
    *(T*)ptr = value;                                         \
  }

AT_FORALL_QINT_TYPES(DEFINE_UNCASTABLE)

#undef FETCH_AND_CAST_CASE
#undef CAST_AND_STORE_CASE
#undef DEFINE_UNCASTABLE
#undef ERROR_UNSUPPORTED_CAST

} // namespace c10