// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/TensorUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/DimVector.h>
4 #include <ATen/EmptyTensor.h>
5 #include <ATen/Tensor.h>
6 #include <ATen/TensorGeometry.h>
7 #include <ATen/Utils.h>
8 
9 #include <utility>
10 
// These functions live here rather than in Utils.h because this file depends on Tensor.h.
12 
// Checks a boolean tensor expression `cond` and raises through TORCH_CHECK,
// with the remaining arguments forming the error message, unless the check
// passes. `cond` is reduced with _is_all_true() (presumably true iff every
// element of `cond` is true) and read back as a C++ bool via item<bool>().
// NOTE(review): item() presumably forces a device synchronization for
// non-CPU tensors, so avoid this on hot paths — confirm.
//
// The expansion is wrapped in do { ... } while (false) and deliberately does
// NOT end in ';': the caller's own semicolon terminates the statement.
// Ending the definition in ';' would leave a stray empty statement after
// every use and make `if (x) TORCH_CHECK_TENSOR_ALL(...); else ...;`
// ill-formed.
#define TORCH_CHECK_TENSOR_ALL(cond, ...)                          \
  do {                                                             \
    TORCH_CHECK((cond)._is_all_true().item<bool>(), __VA_ARGS__); \
  } while (false)
15 
16 namespace at {
17 
18 // The following are utility functions for checking that arguments
19 // make sense.  These are particularly useful for native functions,
20 // which do NO argument checking by default.
21 
22 struct TORCH_API TensorArg {
23   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
24   const Tensor& tensor;
25   const char* name;
26   int pos; // 1-indexed
TensorArgTensorArg27   TensorArg(const Tensor& tensor, const char* name, int pos)
28       : tensor(tensor), name(name), pos(pos) {}
29   // Try to mitigate any possibility of dangling reference to temporaries.
30   // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
31   TensorArg(Tensor&& tensor, const char* name, int pos) = delete;
32   const Tensor* operator->() const {
33     return &tensor;
34   }
35   const Tensor& operator*() const {
36     return tensor;
37   }
38 };
39 
40 struct TORCH_API TensorGeometryArg {
41   TensorGeometry tensor;
42   const char* name;
43   int pos; // 1-indexed
TensorGeometryArgTensorGeometryArg44   /* implicit */ TensorGeometryArg(TensorArg arg)
45       : tensor(TensorGeometry{arg.tensor}), name(arg.name), pos(arg.pos) {}
TensorGeometryArgTensorGeometryArg46   TensorGeometryArg(TensorGeometry tensor, const char* name, int pos)
47       : tensor(std::move(tensor)), name(name), pos(pos) {}
48   const TensorGeometry* operator->() const {
49     return &tensor;
50   }
51   const TensorGeometry& operator*() const {
52     return tensor;
53   }
54 };
55 
56 // A string describing which function did checks on its input
57 // arguments.
58 // TODO: Consider generalizing this into a call stack.
59 using CheckedFrom = const char*;
60 
61 // The undefined convention: singular operators assume their arguments
62 // are defined, but functions which take multiple tensors will
63 // implicitly filter out undefined tensors (to make it easier to perform
64 // tests which should apply if the tensor is defined, and should not
65 // otherwise.)
66 //
67 // NB: This means that the n-ary operators take lists of TensorArg,
68 // not TensorGeometryArg, because the Tensor to TensorGeometry
69 // conversion will blow up if you have undefined tensors.
70 
71 TORCH_API std::ostream& operator<<(
72     std::ostream& out,
73     const TensorGeometryArg& t);
74 TORCH_API void checkDim(
75     CheckedFrom c,
76     const Tensor& tensor,
77     const char* name,
78     int pos, // 1-indexed
79     int64_t dim);
80 TORCH_API void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim);
81 // NB: this is an inclusive-exclusive range
82 TORCH_API void checkDimRange(
83     CheckedFrom c,
84     const TensorGeometryArg& t,
85     int64_t dim_start,
86     int64_t dim_end);
87 TORCH_API void checkSameDim(
88     CheckedFrom c,
89     const TensorGeometryArg& t1,
90     const TensorGeometryArg& t2);
91 TORCH_API void checkContiguous(CheckedFrom c, const TensorGeometryArg& t);
92 TORCH_API void checkAllContiguous(CheckedFrom c, at::ArrayRef<TensorArg> ts);
93 TORCH_API void checkSize(
94     CheckedFrom c,
95     const TensorGeometryArg& t,
96     IntArrayRef sizes);
97 TORCH_API void checkSize_symint(
98     CheckedFrom c,
99     const TensorGeometryArg& t,
100     c10::SymIntArrayRef sizes);
101 TORCH_API void checkSize(
102     CheckedFrom c,
103     const TensorGeometryArg& t,
104     int64_t dim,
105     int64_t size);
106 TORCH_API void checkSize_symint(
107     CheckedFrom c,
108     const TensorGeometryArg& t,
109     int64_t dim,
110     const c10::SymInt& size);
111 TORCH_API void checkNumel(
112     CheckedFrom c,
113     const TensorGeometryArg& t,
114     int64_t numel);
115 TORCH_API void checkSameNumel(
116     CheckedFrom c,
117     const TensorArg& t1,
118     const TensorArg& t2);
119 TORCH_API void checkAllSameNumel(CheckedFrom c, ArrayRef<TensorArg> tensors);
120 TORCH_API void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType s);
121 TORCH_API void checkScalarTypes(
122     CheckedFrom c,
123     const TensorArg& t,
124     at::ArrayRef<ScalarType> l);
125 TORCH_API void checkSameGPU(
126     CheckedFrom c,
127     const TensorArg& t1,
128     const TensorArg& t2);
129 TORCH_API void checkAllSameGPU(CheckedFrom c, ArrayRef<TensorArg> tensors);
130 TORCH_API void checkSameType(
131     CheckedFrom c,
132     const TensorArg& t1,
133     const TensorArg& t2);
134 TORCH_API void checkAllSameType(CheckedFrom c, ArrayRef<TensorArg> tensors);
135 TORCH_API void checkSameSize(
136     CheckedFrom c,
137     const TensorArg& t1,
138     const TensorArg& t2);
139 TORCH_API void checkAllSameSize(CheckedFrom c, ArrayRef<TensorArg> tensors);
140 TORCH_API void checkDefined(CheckedFrom c, const TensorArg& t);
141 TORCH_API void checkAllDefined(CheckedFrom c, at::ArrayRef<TensorArg> t);
142 
143 // FixMe: does TensorArg slow things down?
144 TORCH_API void checkBackend(
145     CheckedFrom c,
146     at::ArrayRef<Tensor> t,
147     at::Backend backend);
148 
149 TORCH_API void checkDeviceType(
150     CheckedFrom c,
151     at::ArrayRef<Tensor> tensors,
152     at::DeviceType device_type);
153 
154 TORCH_API void checkLayout(CheckedFrom c, const Tensor& t, Layout layout);
155 
156 TORCH_API void checkLayout(
157     CheckedFrom c,
158     at::ArrayRef<Tensor> tensors,
159     at::Layout layout);
160 
161 // Methods for getting data_ptr if tensor is defined
162 TORCH_API void* maybe_data_ptr(const Tensor& tensor);
163 TORCH_API void* maybe_data_ptr(const TensorArg& tensor);
164 
165 TORCH_API void check_dim_size(
166     const Tensor& tensor,
167     int64_t dim,
168     int64_t dim_size,
169     int64_t size);
170 
171 namespace detail {
172 TORCH_API std::vector<int64_t> defaultStrides(IntArrayRef sizes);
173 
174 TORCH_API std::optional<std::vector<int64_t>> computeStride(
175     IntArrayRef oldshape,
176     IntArrayRef oldstride,
177     IntArrayRef newshape);
178 
179 TORCH_API std::optional<SymDimVector> computeStride(
180     c10::SymIntArrayRef oldshape,
181     c10::SymIntArrayRef oldstride,
182     c10::SymIntArrayRef newshape);
183 
184 TORCH_API std::optional<DimVector> computeStride(
185     IntArrayRef oldshape,
186     IntArrayRef oldstride,
187     const DimVector& newshape);
188 
189 } // namespace detail
190 } // namespace at
191