xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/AsmUtils.cuh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <cstdint>
3 
4 // Collection of direct PTX functions
5 
6 namespace at::cuda {
7 
8 template <typename T>
9 struct Bitfield {};
10 
11 template <>
12 struct Bitfield<unsigned int> {
13   static __device__ __host__ __forceinline__
getBitfieldat::cuda::Bitfield14   unsigned int getBitfield(unsigned int val, int pos, int len) {
15 #if !defined(__CUDA_ARCH__)
16     pos &= 0xff;
17     len &= 0xff;
18 
19     unsigned int m = (1u << len) - 1u;
20     return (val >> pos) & m;
21 #else
22     unsigned int ret;
23     asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
24     return ret;
25 #endif
26   }
27 
28   static __device__ __host__ __forceinline__
setBitfieldat::cuda::Bitfield29   unsigned int setBitfield(unsigned int val, unsigned int toInsert, int pos, int len) {
30 #if !defined(__CUDA_ARCH__)
31     pos &= 0xff;
32     len &= 0xff;
33 
34     unsigned int m = (1u << len) - 1u;
35     toInsert &= m;
36     toInsert <<= pos;
37     m <<= pos;
38 
39     return (val & ~m) | toInsert;
40 #else
41     unsigned int ret;
42     asm("bfi.b32 %0, %1, %2, %3, %4;" :
43         "=r"(ret) : "r"(toInsert), "r"(val), "r"(pos), "r"(len));
44     return ret;
45 #endif
46   }
47 };
48 
49 template <>
50 struct Bitfield<uint64_t> {
51   static __device__ __host__ __forceinline__
getBitfieldat::cuda::Bitfield52   uint64_t getBitfield(uint64_t val, int pos, int len) {
53 #if !defined(__CUDA_ARCH__)
54     pos &= 0xff;
55     len &= 0xff;
56 
57     uint64_t m = (1u << len) - 1u;
58     return (val >> pos) & m;
59 #else
60     uint64_t ret;
61     asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
62     return ret;
63 #endif
64   }
65 
66   static __device__ __host__ __forceinline__
setBitfieldat::cuda::Bitfield67   uint64_t setBitfield(uint64_t val, uint64_t toInsert, int pos, int len) {
68 #if !defined(__CUDA_ARCH__)
69     pos &= 0xff;
70     len &= 0xff;
71 
72     uint64_t m = (1u << len) - 1u;
73     toInsert &= m;
74     toInsert <<= pos;
75     m <<= pos;
76 
77     return (val & ~m) | toInsert;
78 #else
79     uint64_t ret;
80     asm("bfi.b64 %0, %1, %2, %3, %4;" :
81         "=l"(ret) : "l"(toInsert), "l"(val), "r"(pos), "r"(len));
82     return ret;
83 #endif
84   }
85 };
86 
getLaneId()87 __device__ __forceinline__ int getLaneId() {
88 #if defined(USE_ROCM)
89   return __lane_id();
90 #else
91   int laneId;
92   asm("mov.s32 %0, %%laneid;" : "=r"(laneId) );
93   return laneId;
94 #endif
95 }
96 
97 #if defined(USE_ROCM)
getLaneMaskLt()98 __device__ __forceinline__ unsigned long long int getLaneMaskLt() {
99   const std::uint64_t m = (1ull << getLaneId()) - 1ull;
100   return m;
101 }
102 #else
getLaneMaskLt()103 __device__ __forceinline__ unsigned getLaneMaskLt() {
104   unsigned mask;
105   asm("mov.u32 %0, %%lanemask_lt;" : "=r"(mask));
106   return mask;
107 }
108 #endif
109 
110 #if defined (USE_ROCM)
getLaneMaskLe()111 __device__ __forceinline__ unsigned long long int getLaneMaskLe() {
112   std::uint64_t m = UINT64_MAX >> (sizeof(std::uint64_t) * CHAR_BIT - (getLaneId() + 1));
113   return m;
114 }
115 #else
getLaneMaskLe()116 __device__ __forceinline__ unsigned getLaneMaskLe() {
117   unsigned mask;
118   asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask));
119   return mask;
120 }
121 #endif
122 
123 #if defined(USE_ROCM)
getLaneMaskGt()124 __device__ __forceinline__ unsigned long long int getLaneMaskGt() {
125   const std::uint64_t m = getLaneMaskLe();
126   return m ? ~m : m;
127 }
128 #else
getLaneMaskGt()129 __device__ __forceinline__ unsigned getLaneMaskGt() {
130   unsigned mask;
131   asm("mov.u32 %0, %%lanemask_gt;" : "=r"(mask));
132   return mask;
133 }
134 #endif
135 
136 #if defined(USE_ROCM)
getLaneMaskGe()137 __device__ __forceinline__ unsigned long long int getLaneMaskGe() {
138   const std::uint64_t m = getLaneMaskLt();
139   return ~m;
140 }
141 #else
getLaneMaskGe()142 __device__ __forceinline__ unsigned getLaneMaskGe() {
143   unsigned mask;
144   asm("mov.u32 %0, %%lanemask_ge;" : "=r"(mask));
145   return mask;
146 }
147 #endif
148 
149 } // namespace at::cuda
150