#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <math.h>

#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/cuda/DeviceSqrt.cuh>
#include <ATen/native/Distance.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/sum.h>
#endif

#include <c10/macros/Macros.h>

namespace at::native {

namespace {

constexpr int kCUDANumThreads = 256;
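// Each p-norm specialization below exposes the same static interface, which the
// kernels use generically:
//   inc(agg, diff, p)             accumulate one |x_i - y_i| term into a thread-local aggregate
//   agg(update, other)            combine two partial aggregates (used by the block reduce)
//   finish(agg, p)                turn the final aggregate into the distance
//   backward(diff, grad, dist, p) per-element gradient contribution
// (`zero` has no backward; the host launchers fill the gradient with zeros when p == 0.)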
template <typename scalar_t>
struct dists {

  static __forceinline__ __device__ scalar_t sign(scalar_t val) {
    return (0 < val) - (val < 0);
  }

  // Zero norm
  struct zero {
    static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t /*p*/) { agg += diff != 0.0; }
    static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t /*p*/) { return agg; }
    static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { update += other; }
  };

  // One norm
  struct one {
    static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t /*p*/) { agg += diff; }
    static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t /*p*/) { return agg; }
    static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { update += other; }
    static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t /*dist*/, const scalar_t /*p*/) { return grad * sign(diff); }
  };

  // Special case backward when p is less than two
  struct lt_two {
    static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) {
      return (dist == 0.0 || (diff == 0.0 && p < 1)) ? 0 : (sign(diff) * std::pow(std::abs(diff), p - 1) * grad / std::pow(dist, p - 1));
    }
  };

  // Two norm
  struct two {
    static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t /*p*/) { agg += diff * diff; }
    static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t /*p*/) { return device_sqrt<scalar_t>(agg); }
    static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { update += other; }
    static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t /*p*/) { return dist == 0.0 ? 0 : grad * diff / dist; }
  };

  // General p norm
  struct p {
    static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t p) { agg += std::pow(diff, p); }
    static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t p) { return std::pow(agg, static_cast<scalar_t>(1) / p); }
    static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { update += other; }
    static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t p) { return dist == 0.0 ? 0 : diff * std::pow(std::abs(diff), p - 2) * grad / std::pow(dist, p - 1); }
  };

  // Inf norm
  struct inf {
    static __forceinline__ __device__ void inc(scalar_t& agg, const scalar_t diff, const scalar_t /*p*/) { if (diff > agg) { agg = diff; } }
    static __forceinline__ __device__ scalar_t finish(const scalar_t agg, const scalar_t /*p*/) { return agg; }
    static __forceinline__ __device__ void agg(scalar_t& update, const scalar_t other) { if (other > update) { update = other; } }
    static __forceinline__ __device__ scalar_t backward(const scalar_t diff, const scalar_t grad, const scalar_t dist, const scalar_t /*p*/) { return grad * sign(diff) * (std::abs(diff) == dist); }
  };

};
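// Adapter exposing the interface cuda_utils::BlockReduce expects (combine +
// warp_shfl_down) on top of a dists<scalar_t> functor's agg().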
template <typename scalar_t, typename F>
struct DistReduceOp {
  __forceinline__ __device__ scalar_t combine(scalar_t a, scalar_t b) const {
    F::agg(a, b);
    return a;
  }

  __forceinline__ __device__ scalar_t warp_shfl_down(scalar_t data, int offset) const {
    return WARP_SHFL_DOWN(data, offset);
  }
};
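// The pdist kernels index the condensed output: entry k corresponds to the pair
// (i, j) with i < j, stored row by row, i.e. k = i*n - i*(i+1)/2 + j - i - 1.
// Inverting that relation gives i = floor(n2 - sqrt(n2*n2 - 2*k)) with n2 = n - 0.5,
// and j then follows directly. The host precomputes n2 and
// n2_squared_minus_1 = n2*n2 - 1 in double; the extra -1 under the sqrt is the
// floating point truncation guard mentioned in the kernels below.
// For example, with n = 4 the pairs are k = 0..5 -> (0,1),(0,2),(0,3),(1,2),(1,3),(2,3).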
95 template <typename scalar_t, typename F>
pdist_kernel_cuda_impl(scalar_t * result,const scalar_t * self,const int64_t n,const int64_t m,const scalar_t p,const double n2,const double n2_squared_minus_1)96 __global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t * self, const int64_t n, const int64_t m, const scalar_t p,
97 const double n2, const double n2_squared_minus_1) {
98 const int64_t k = blockIdx.x;
99 const int stride = blockDim.x;
100
101 // The -1 accounts for floating point truncation issues
102 int64_t i = static_cast<int64_t>((n2 - device_sqrt<double>(n2_squared_minus_1 - 2 * k)));
103 int64_t j = k - n * i + i * (i + 1) / 2 + i + 1;
104
105 const scalar_t * const start = self + i * m;
106 const scalar_t * const end = start + m;
107 const scalar_t * a = start + threadIdx.x;
108 const scalar_t * b = self + j * m + threadIdx.x;
109 scalar_t agg = 0.0;
110 for (; a < end; a += stride, b += stride) {
111 F::inc(agg, std::abs(*a - *b), p);
112 }
113
114 __shared__ scalar_t agg_smem[kCUDANumThreads];
115 scalar_t agg_init{0.0};
116 agg = cuda_utils::BlockReduce(agg, DistReduceOp<scalar_t, F>{}, agg_init, agg_smem);
117 if (threadIdx.x == 0) {
118 result[k] = F::finish(agg, p);
119 }
120 }
121
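// cdist backward: y enumerates the flattened (batch, i, j) entries of dist/grad,
// while x strides over the feature dimension m. Each element's contribution is
// written to buffer, laid out as (batch, r2, r1, m), so the host can reduce it
// with a single sum over dim 1 afterwards.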
template <typename scalar_t, typename F>
__global__ static void cdist_backward_kernel_cuda_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * x1, const scalar_t * x2, const scalar_t * dist,
                                                       const scalar_t p, const int64_t r1, const int64_t r2, const int64_t m, const int64_t count, const int64_t r_size, const int64_t l1_size, const int64_t l2_size) {
  const int y = (blockIdx.y * gridDim.z + blockIdx.z) * blockDim.y + threadIdx.y;
  const int init = blockIdx.x * blockDim.x + threadIdx.x;
  if (y >= count || init >= m) {
    return;
  }
  const int l = y / r_size;
  const int k = y % r_size;
  const int stride = blockDim.x * gridDim.x;
  const int l_size = r_size * m;

  int64_t i = k / r2;
  int64_t j = k % r2;

  const scalar_t grad_k = grad[y];
  const scalar_t dist_k = dist[y];

  const scalar_t * const start = x1 + l * l1_size + i * m;
  const scalar_t * const end = start + m;
  const scalar_t * self_i = start + init;
  const scalar_t * self_j = x2 + l * l2_size + j * m + init;

  scalar_t * buff_i = buffer + l * l_size + (r1 * j + i) * m + init;

  for (; self_i < end; self_i += stride, self_j += stride, buff_i += stride) {
    const scalar_t res = F::backward(*self_i - *self_j, grad_k, dist_k, p);
    *buff_i = res;
  }
}
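// pdist backward: each pair (i, j) contributes +res to row i and -res to row j of
// the gradient. The per-pair results are staged in buffer, shaped (n - 1, n, m),
// so that every (slice, row) slot is written by at most one pair; the host then
// collapses the n - 1 slices with a sum over dim 0. ib and jb pick the slice that
// pair (i, j) owns for rows i and j respectively.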
template <typename scalar_t, typename F>
__global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * self, const scalar_t * dist, int64_t gs, const int64_t n, const int64_t m, const int64_t combs, const scalar_t p,
                                                       const double n2, const double n2_squared_minus_1) {
  const int64_t k = blockIdx.x * blockDim.x + threadIdx.x;
  const int init = blockIdx.y * blockDim.y + threadIdx.y;
  const int stride = blockDim.y * gridDim.y;

  if (k >= combs) {
    return;
  }

  // The -1 accounts for floating point truncation issues
  int64_t i = static_cast<int64_t>((n2 - device_sqrt<double>(n2_squared_minus_1 - 2 * k)));
  int64_t j = k - n * i + i * (i + 1) / 2 + i + 1;
  int64_t ib = j - i - 1;
  int64_t jb = n - 2 - i;

  const scalar_t grad_k = grad[k * gs];
  const scalar_t dist_k = dist[k];

  const scalar_t * const start = self + i * m;
  const scalar_t * const end = start + m;
  const scalar_t * self_i = start + init;
  const scalar_t * self_j = self + j * m + init;
  scalar_t * buff_i = buffer + (ib * n + i) * m + init;
  scalar_t * buff_j = buffer + (jb * n + j) * m + init;
  for (; self_i < end; self_i += stride, self_j += stride, buff_i += stride, buff_j += stride) {
    const scalar_t res = F::backward(*self_i - *self_j, grad_k, dist_k, p);
    *buff_i = res;
    *buff_j = -res;
  }
}
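// cdist forward: one block per output element. blockIdx.x encodes the flattened
// (batch, i, j) index; the block's threads stride over the m features of
// x1[l][i] and x2[l][j] and combine their partial aggregates with a block reduce.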
template <typename scalar_t, typename F>
__global__ static void cdist_kernel_cuda_impl(scalar_t * result, const scalar_t * x1, const scalar_t * x2,
                                              const scalar_t p, const int64_t r2, const int64_t m, const int64_t r_size, const int64_t l1_size, const int64_t l2_size) {
  const int64_t l = blockIdx.x / r_size;
  const int64_t k = blockIdx.x % r_size;
  const int64_t i = k / r2;
  const int64_t j = k % r2;
  const int stride = blockDim.x;

  const scalar_t * const start = x1 + l * l1_size + i * m;
  const scalar_t * const end = start + m;
  const scalar_t * a = start + threadIdx.x;
  const scalar_t * b = x2 + l * l2_size + j * m + threadIdx.x;

  scalar_t agg = 0.0;
  for (; a < end; a += stride, b += stride) {
    F::inc(agg, std::abs(*a - *b), p);
  }
  __shared__ scalar_t agg_smem[kCUDANumThreads];
  scalar_t agg_init{0.0};
  agg = cuda_utils::BlockReduce(agg, DistReduceOp<scalar_t, F>{}, agg_init, agg_smem);
  if (threadIdx.x == 0) {
    result[blockIdx.x] = F::finish(agg, p);
  }
}
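// Host-side launchers. Each dispatches on the value of p to pick a specialized
// functor where one exists (falling back to the general p-norm) and launches the
// corresponding kernel on the current CUDA stream.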
void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, double p) {
  const int64_t r1 = x1.size(-2);
  const int64_t r2 = x2.size(-2);
  const int64_t m = x1.size(-1);
  const int64_t r_size = r1 * r2;
  const int64_t l1_size = r1 * m;
  const int64_t l2_size = r2 * m;
  const dim3 grid(result.numel());
  const dim3 block(kCUDANumThreads);

  AT_DISPATCH_FLOATING_TYPES(x1.scalar_type(), "cdist_cuda", [&] {
    auto impl_fptr = cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::p>;
    if (p == 0.0) {
      impl_fptr = cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::zero>;
    } else if (p == 1.0) {
      impl_fptr = cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::one>;
    } else if (p == 2.0) {
      impl_fptr = cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::two>;
    } else if (std::isinf(p)) {
      impl_fptr = cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf>;
    }
    impl_fptr<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(result.mutable_data_ptr<scalar_t>(), x1.const_data_ptr<scalar_t>(), x2.const_data_ptr<scalar_t>(), p, r2, m, r_size, l1_size, l2_size);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}
void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) {
  const dim3 grid(result.numel());
  const dim3 block(kCUDANumThreads);
  int64_t n = self.size(0);
  int64_t m = self.size(1);
  // https://github.com/pytorch/pytorch/issues/15511 demonstrated we need to do
  // some math in fp64 -- this is just minimizing the amount of fp64 math we do on the device.
  const double n2 = n - .5;
  const double n2_squared_minus_1 = n2 * n2 - 1;

  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_cuda", [&] {
    auto impl_fptr = pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::p>;
    if (p == 0.0) {
      impl_fptr = pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::zero>;
    } else if (p == 1.0) {
      impl_fptr = pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::one>;
    } else if (p == 2.0) {
      impl_fptr = pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::two>;
    } else if (std::isinf(p)) {
      impl_fptr = pdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf>;
    }
    impl_fptr<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(), n, m, p, n2, n2_squared_minus_1);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}
void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) {
  if (p == 0.0 || grad.numel() == 0 || self.numel() == 0) {
    result.fill_(0);
    return;
  }

  const int64_t n = result.size(0);
  int64_t m = self.size(1);
  const int block_x = 16;
  // NB: be careful when changing block_y; as currently written, grid_y is limited to 2^16.
  // A block_y of 64 gives a max pdist dim1 of 2**24.
  const int block_y = 64;
  const int grid_x = (dist.numel() + block_x - 1) / block_x;
  const int grid_y = (m + block_y * 8 - 1) / (block_y * 8);
  const dim3 grid(grid_x, grid_y);
  const dim3 block(block_x, block_y);
  // https://github.com/pytorch/pytorch/issues/15511 demonstrated we need to do
  // some math in fp64 -- this is just minimizing the amount of fp64 math we do on the device.
  const double n2 = n - .5;
  const double n2_squared_minus_1 = n2 * n2 - 1;

  Tensor buffer = at::empty({n - 1, result.size(0), result.size(1)}, result.options());
  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_cuda_backward", [&] {
    auto impl_fptr = pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::p>;
    if (p == 1.0) {
      impl_fptr = pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::one>;
    } else if (p < 2.0) {
      impl_fptr = pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::lt_two>;
    } else if (p == 2.0) {
      impl_fptr = pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::two>;
    } else if (std::isinf(p)) {
      impl_fptr = pdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf>;
    }
    impl_fptr<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(buffer.mutable_data_ptr<scalar_t>(), grad.const_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(), dist.const_data_ptr<scalar_t>(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });

  at::sum_out(result, buffer, 0);
}
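// The flattened (batch, i, j) count for cdist backward can exceed CUDA's 65535
// limit on a single grid dimension, so the required number of thread rows is
// split across gridDim.y and gridDim.z; the kernel reconstructs the linear row
// index as (blockIdx.y * gridDim.z + blockIdx.z) * blockDim.y + threadIdx.y.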
void cdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& x1, const Tensor& x2, const double p, const Tensor& dist) {
  if (p == 0.0 || grad.numel() == 0 || x1.numel() == 0 || x2.numel() == 0) {
    result.fill_(0);
    return;
  }

  const int64_t r1 = x1.size(-2);
  const int64_t r2 = x2.size(-2);
  const int64_t m = x1.size(-1);
  // Just like we do in the CPU code, assume that result is always batched
  int64_t batch = result.size(0);
  const int block_x = 64;
  const int block_y = 16;
  const int grid_x = (m + block_x * 8 - 1) / (block_x * 8);

  const int64_t count = dist.numel();
  const int64_t grid_temp = (count + block_y - 1) / block_y;

  const int grid_y = (grid_temp - 1) / 65535 + 1;
  const int grid_z = (grid_temp - 1) / grid_y + 1;

  const dim3 grid(grid_x, grid_y, grid_z);
  const dim3 block(block_x, block_y);

  const int64_t r_size = r1 * r2;
  const int64_t l1_size = r1 * m;
  const int64_t l2_size = r2 * m;
  // The current implementation only supports gradients that can be collapsed to 1D.
  // However, rather than check this assumption, grad.contiguous() is called before
  // backward, so the stride is guaranteed to be 1.

  Tensor buffer = at::empty({batch, r2, r1, m}, result.options());
  AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist_cuda_backward", [&] {
    auto impl_fptr = cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::p>;
    if (p == 1.0) {
      impl_fptr = cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::one>;
    } else if (p < 2.0) {
      impl_fptr = cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::lt_two>;
    } else if (p == 2.0) {
      impl_fptr = cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::two>;
    } else if (std::isinf(p)) {
      impl_fptr = cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf>;
    }
    impl_fptr<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(buffer.mutable_data_ptr<scalar_t>(),
        grad.const_data_ptr<scalar_t>(), x1.const_data_ptr<scalar_t>(), x2.const_data_ptr<scalar_t>(), dist.const_data_ptr<scalar_t>(),
        p, r1, r2, m, count, r_size, l1_size, l2_size);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });

  at::sum_out(result, buffer, 1);
}
} // anonymous namespace

REGISTER_DISPATCH(pdist_forward_stub, &pdist_forward_kernel_impl);
REGISTER_DISPATCH(pdist_backward_stub, &pdist_backward_kernel_impl);
REGISTER_DISPATCH(cdist_stub, &cdist_kernel_impl);
REGISTER_DISPATCH(cdist_backward_stub, &cdist_backward_kernel_impl);

} // namespace at::native