#pragma once

#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
#include <c10/util/Load.h>
#include <c10/util/TypeCast.h>

namespace c10 {

// Dynamic type casting utils:
// - fetch_and_cast
// - cast_and_store
//
// fetch_and_cast fetches a value whose dynamic type is specified by a
// ScalarType from a void pointer and casts it to the static type dest_t.
//
// cast_and_store casts a statically typed value to the dynamic type specified
// by a ScalarType, and stores it through a void pointer.
//
// NOTE:
//
// Dynamic casting allows us to support type promotion without blowing up
// the combination space: for example, without dynamic casting, in order to
// implement `add_` with type promotion, we would need something like
//
//   AT_DISPATCH_ALL_TYPES(output.dtype(),
//       AT_DISPATCH_ALL_TYPES(input1.dtype(),
//           AT_DISPATCH_ALL_TYPES(input2.dtype(),
//               [](arg0_t a, arg1_t b) -> out_t { return a + b; }
//           )
//       )
//   )
//
// If we support N dtypes, the above code would generate the a+b kernel for
// all N * N * N combinations of supported types; the compilation time and
// binary size would be horrible.
//
// Dynamic casting might sound like a bad idea in terms of performance,
// especially if you do it in a loop: you are going to do a billion type
// checks. But in practice it is not as bad as it might look:
//
// - On CPU, this is a branch that always has the same outcome, so the branch
//   predictor should do its job pretty well.
// - On GPU, these branches will not diverge, so the whole warp still executes
//   the same line of code.
// - Most kernels, like `add`, are bandwidth bound; spending a few extra clock
//   cycles to check an integer does not hurt performance much, because the
//   ALUs would be waiting on load instructions anyway.
//
// For the discussion and benchmark, refer to:
// - https://github.com/pytorch/pytorch/pull/28343
// - https://github.com/pytorch/pytorch/pull/28344
// - https://github.com/pytorch/pytorch/pull/28345
//

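// As an illustration of the point above (a sketch only: dyn_add_loop and its
// parameters are hypothetical, not part of this header), a dynamically typed
// `add` loop can compute in a single static type and cast only at the memory
// boundary, so one kernel instantiation serves all dtype combinations:
//
//   void dyn_add_loop(
//       ScalarType out_t, char* out,
//       ScalarType a_t, const char* a,
//       ScalarType b_t, const char* b,
//       int64_t n) {
//     for (int64_t i = 0; i < n; i++) {
//       // Fetch both operands as float, whatever their dynamic types are.
//       float x = fetch_and_cast<float>(a_t, a + i * elementSize(a_t));
//       float y = fetch_and_cast<float>(b_t, b + i * elementSize(b_t));
//       // Cast the statically typed sum back to out's dynamic type.
//       cast_and_store<float>(out_t, out + i * elementSize(out_t), x + y);
//     }
//   }
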
#ifdef C10_HOST_DEVICE
#define ERROR_UNSUPPORTED_CAST CUDA_KERNEL_ASSERT(false);
#else
#define ERROR_UNSUPPORTED_CAST TORCH_CHECK(false, "Unexpected scalar type");
#endif

// Fetch a value with dynamic type src_type from ptr, and cast it to static
// type dest_t.
#define FETCH_AND_CAST_CASE(type, scalartype) \
  case ScalarType::scalartype:                \
    return c10::convert<dest_t>(c10::load<type>(ptr));
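
// For instance, FETCH_AND_CAST_CASE(int32_t, Int) expands to
//
//   case ScalarType::Int:
//     return c10::convert<dest_t>(c10::load<int32_t>(ptr));
//
// which is how the switches below cover every dtype with a single macro.
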
template <typename dest_t>
C10_HOST_DEVICE inline dest_t fetch_and_cast(
    const ScalarType src_type,
    const void* ptr) {
  switch (src_type) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(FETCH_AND_CAST_CASE)
    FETCH_AND_CAST_CASE(uint16_t, UInt16)
    FETCH_AND_CAST_CASE(uint32_t, UInt32)
    FETCH_AND_CAST_CASE(uint64_t, UInt64)
    default:
      ERROR_UNSUPPORTED_CAST
  }
  return dest_t(0); // just to avoid compiler warning
}

// Cast a value with static type src_t into dynamic dest_type, and store it to
// ptr.
#define CAST_AND_STORE_CASE(type, scalartype) \
  case ScalarType::scalartype:                \
    *(type*)ptr = c10::convert<type>(value);  \
    return;
template <typename src_t>
C10_HOST_DEVICE inline void cast_and_store(
    const ScalarType dest_type,
    void* ptr,
    src_t value) {
  switch (dest_type) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CAST_AND_STORE_CASE)
    CAST_AND_STORE_CASE(uint16_t, UInt16)
    CAST_AND_STORE_CASE(uint32_t, UInt32)
    CAST_AND_STORE_CASE(uint64_t, UInt64)
    default:;
  }
  ERROR_UNSUPPORTED_CAST
}

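// A round-trip sketch (buf is a hypothetical stack variable, not part of this
// header): the value is converted on the way in and again on the way out.
//
//   double buf; // storage whose dynamic type is ScalarType::Double
//   cast_and_store<int64_t>(ScalarType::Double, &buf, int64_t{42});
//   float v = fetch_and_cast<float>(ScalarType::Double, &buf); // v == 42.0f
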
// Quantized types cannot participate in dynamic casting: they may only be
// fetched or stored with an exactly matching ScalarType, which the
// specializations below assert.
#define DEFINE_UNCASTABLE(T, scalartype_)                     \
  template <>                                                 \
  C10_HOST_DEVICE inline T fetch_and_cast<T>(                 \
      const ScalarType src_type, const void* ptr) {           \
    CUDA_KERNEL_ASSERT(ScalarType::scalartype_ == src_type);  \
    return c10::load<T>(ptr);                                 \
  }                                                           \
  template <>                                                 \
  C10_HOST_DEVICE inline void cast_and_store<T>(              \
      const ScalarType dest_type, void* ptr, T value) {       \
    CUDA_KERNEL_ASSERT(ScalarType::scalartype_ == dest_type); \
    *(T*)ptr = value;                                         \
  }

AT_FORALL_QINT_TYPES(DEFINE_UNCASTABLE)

#undef FETCH_AND_CAST_CASE
#undef CAST_AND_STORE_CASE
#undef DEFINE_UNCASTABLE
#undef ERROR_UNSUPPORTED_CAST

} // namespace c10