1 #pragma once 2 3 #include <ATen/core/Tensor.h> 4 #include <ATen/cuda/Exceptions.h> 5 #include <ATen/cudnn/Handle.h> 6 #include <ATen/cudnn/cudnn-wrapper.h> 7 8 namespace at::native { 9 10 // cuDNN has a buggy check for tensor being contiguous (that is, it does 11 // not ignore stride for dimension that is equal to 0). This function 12 // makes tensors which have zero stride contiguous, by setting the 13 // strides to 1 as cuDNN likes. contiguousIfZeroInStrides(const Tensor & t)14inline Tensor contiguousIfZeroInStrides(const Tensor& t) { 15 for (auto s : t.strides()) { 16 if (s == 0) 17 return t.contiguous(); 18 } 19 return t; 20 } 21 22 } // namespace at::native 23