1 #pragma once 2 3 #include <ATen/DimVector.h> 4 #include <ATen/EmptyTensor.h> 5 #include <ATen/Tensor.h> 6 #include <ATen/TensorGeometry.h> 7 #include <ATen/Utils.h> 8 9 #include <utility> 10 11 // These functions are NOT in Utils.h, because this file has a dep on Tensor.h 12 13 #define TORCH_CHECK_TENSOR_ALL(cond, ...) \ 14 TORCH_CHECK((cond)._is_all_true().item<bool>(), __VA_ARGS__); 15 16 namespace at { 17 18 // The following are utility functions for checking that arguments 19 // make sense. These are particularly useful for native functions, 20 // which do NO argument checking by default. 21 22 struct TORCH_API TensorArg { 23 // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) 24 const Tensor& tensor; 25 const char* name; 26 int pos; // 1-indexed TensorArgTensorArg27 TensorArg(const Tensor& tensor, const char* name, int pos) 28 : tensor(tensor), name(name), pos(pos) {} 29 // Try to mitigate any possibility of dangling reference to temporaries. 30 // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) 31 TensorArg(Tensor&& tensor, const char* name, int pos) = delete; 32 const Tensor* operator->() const { 33 return &tensor; 34 } 35 const Tensor& operator*() const { 36 return tensor; 37 } 38 }; 39 40 struct TORCH_API TensorGeometryArg { 41 TensorGeometry tensor; 42 const char* name; 43 int pos; // 1-indexed TensorGeometryArgTensorGeometryArg44 /* implicit */ TensorGeometryArg(TensorArg arg) 45 : tensor(TensorGeometry{arg.tensor}), name(arg.name), pos(arg.pos) {} TensorGeometryArgTensorGeometryArg46 TensorGeometryArg(TensorGeometry tensor, const char* name, int pos) 47 : tensor(std::move(tensor)), name(name), pos(pos) {} 48 const TensorGeometry* operator->() const { 49 return &tensor; 50 } 51 const TensorGeometry& operator*() const { 52 return tensor; 53 } 54 }; 55 56 // A string describing which function did checks on its input 57 // arguments. 58 // TODO: Consider generalizing this into a call stack. 59 using CheckedFrom = const char*; 60 61 // The undefined convention: singular operators assume their arguments 62 // are defined, but functions which take multiple tensors will 63 // implicitly filter out undefined tensors (to make it easier to perform 64 // tests which should apply if the tensor is defined, and should not 65 // otherwise.) 66 // 67 // NB: This means that the n-ary operators take lists of TensorArg, 68 // not TensorGeometryArg, because the Tensor to TensorGeometry 69 // conversion will blow up if you have undefined tensors. 70 71 TORCH_API std::ostream& operator<<( 72 std::ostream& out, 73 const TensorGeometryArg& t); 74 TORCH_API void checkDim( 75 CheckedFrom c, 76 const Tensor& tensor, 77 const char* name, 78 int pos, // 1-indexed 79 int64_t dim); 80 TORCH_API void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim); 81 // NB: this is an inclusive-exclusive range 82 TORCH_API void checkDimRange( 83 CheckedFrom c, 84 const TensorGeometryArg& t, 85 int64_t dim_start, 86 int64_t dim_end); 87 TORCH_API void checkSameDim( 88 CheckedFrom c, 89 const TensorGeometryArg& t1, 90 const TensorGeometryArg& t2); 91 TORCH_API void checkContiguous(CheckedFrom c, const TensorGeometryArg& t); 92 TORCH_API void checkAllContiguous(CheckedFrom c, at::ArrayRef<TensorArg> ts); 93 TORCH_API void checkSize( 94 CheckedFrom c, 95 const TensorGeometryArg& t, 96 IntArrayRef sizes); 97 TORCH_API void checkSize_symint( 98 CheckedFrom c, 99 const TensorGeometryArg& t, 100 c10::SymIntArrayRef sizes); 101 TORCH_API void checkSize( 102 CheckedFrom c, 103 const TensorGeometryArg& t, 104 int64_t dim, 105 int64_t size); 106 TORCH_API void checkSize_symint( 107 CheckedFrom c, 108 const TensorGeometryArg& t, 109 int64_t dim, 110 const c10::SymInt& size); 111 TORCH_API void checkNumel( 112 CheckedFrom c, 113 const TensorGeometryArg& t, 114 int64_t numel); 115 TORCH_API void checkSameNumel( 116 CheckedFrom c, 117 const TensorArg& t1, 118 const TensorArg& t2); 119 TORCH_API void checkAllSameNumel(CheckedFrom c, ArrayRef<TensorArg> tensors); 120 TORCH_API void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType s); 121 TORCH_API void checkScalarTypes( 122 CheckedFrom c, 123 const TensorArg& t, 124 at::ArrayRef<ScalarType> l); 125 TORCH_API void checkSameGPU( 126 CheckedFrom c, 127 const TensorArg& t1, 128 const TensorArg& t2); 129 TORCH_API void checkAllSameGPU(CheckedFrom c, ArrayRef<TensorArg> tensors); 130 TORCH_API void checkSameType( 131 CheckedFrom c, 132 const TensorArg& t1, 133 const TensorArg& t2); 134 TORCH_API void checkAllSameType(CheckedFrom c, ArrayRef<TensorArg> tensors); 135 TORCH_API void checkSameSize( 136 CheckedFrom c, 137 const TensorArg& t1, 138 const TensorArg& t2); 139 TORCH_API void checkAllSameSize(CheckedFrom c, ArrayRef<TensorArg> tensors); 140 TORCH_API void checkDefined(CheckedFrom c, const TensorArg& t); 141 TORCH_API void checkAllDefined(CheckedFrom c, at::ArrayRef<TensorArg> t); 142 143 // FixMe: does TensorArg slow things down? 144 TORCH_API void checkBackend( 145 CheckedFrom c, 146 at::ArrayRef<Tensor> t, 147 at::Backend backend); 148 149 TORCH_API void checkDeviceType( 150 CheckedFrom c, 151 at::ArrayRef<Tensor> tensors, 152 at::DeviceType device_type); 153 154 TORCH_API void checkLayout(CheckedFrom c, const Tensor& t, Layout layout); 155 156 TORCH_API void checkLayout( 157 CheckedFrom c, 158 at::ArrayRef<Tensor> tensors, 159 at::Layout layout); 160 161 // Methods for getting data_ptr if tensor is defined 162 TORCH_API void* maybe_data_ptr(const Tensor& tensor); 163 TORCH_API void* maybe_data_ptr(const TensorArg& tensor); 164 165 TORCH_API void check_dim_size( 166 const Tensor& tensor, 167 int64_t dim, 168 int64_t dim_size, 169 int64_t size); 170 171 namespace detail { 172 TORCH_API std::vector<int64_t> defaultStrides(IntArrayRef sizes); 173 174 TORCH_API std::optional<std::vector<int64_t>> computeStride( 175 IntArrayRef oldshape, 176 IntArrayRef oldstride, 177 IntArrayRef newshape); 178 179 TORCH_API std::optional<SymDimVector> computeStride( 180 c10::SymIntArrayRef oldshape, 181 c10::SymIntArrayRef oldstride, 182 c10::SymIntArrayRef newshape); 183 184 TORCH_API std::optional<DimVector> computeStride( 185 IntArrayRef oldshape, 186 IntArrayRef oldstride, 187 const DimVector& newshape); 188 189 } // namespace detail 190 } // namespace at 191