xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cudnn/Utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <ATen/cuda/Exceptions.h>
5 #include <ATen/cudnn/Handle.h>
6 #include <ATen/cudnn/cudnn-wrapper.h>
7 
8 namespace at::native {
9 
10 // cuDNN has a buggy check for tensor being contiguous (that is, it does
11 // not ignore stride for dimension that is equal to 0).  This function
12 // makes tensors which have zero stride contiguous, by setting the
13 // strides to 1 as cuDNN likes.
contiguousIfZeroInStrides(const Tensor & t)14 inline Tensor contiguousIfZeroInStrides(const Tensor& t) {
15   for (auto s : t.strides()) {
16     if (s == 0)
17       return t.contiguous();
18   }
19   return t;
20 }
21 
22 } // namespace at::native
23