1 #pragma once
2 #include <ATen/NamedTensor.h>
3 #include <ATen/TensorNames.h>
4 #include <ATen/WrapDimUtilsMulti.h>
5
6 #include <ATen/core/DimVector.h>
7 #include <ATen/core/Tensor.h>
8
9 namespace at {
10
11 using NameVector = SmallVector<Dimname, kDimVectorStaticSize>;
12
has_names(const ITensorListRef & tensors)13 inline bool has_names(const ITensorListRef& tensors) {
14 return std::any_of(tensors.begin(), tensors.end(), [](const Tensor& t) {
15 return t.has_names();
16 });
17 }
18
// Converts dim to a positional index. Errors if `dim` cannot be used to
// refer to any dimension of tensor.
TORCH_API int64_t dimname_to_position(const Tensor& tensor, Dimname dim);
// Converts each Dimname in `dims` to its positional index in `tensor`.
TORCH_API std::vector<int64_t> dimnames_to_positions(
    const Tensor& tensor,
    DimnameList dims);
25
// Unifies two DimnameList to produce a third. This is useful for implementing
// the named inference rule for binary broadcasting operations like add.
//
// There are three main constraints:
// 1) Check matching: Names must match positionally from the right.
// 2) Check misaligned: If a name `n` is in `names`, then it must appear at
//    the same index from the right in `other`.
// 3) The output names are obtained by unifying the names individually from the
//    right.
//
// `action` describes the failing operation in error messages
// (defaults to "broadcast").
TORCH_API std::vector<Dimname> unify_from_right(
    DimnameList names,
    DimnameList other,
    const char* action = "broadcast");
39
// Raises an error for an op that accepts a dimension index but was called
// with a Dimname (string) overload it does not yet implement. `op_name` is
// prepended to the error message.
[[noreturn]] inline void reportNYIDimnameOverload(const char* op_name) {
  TORCH_CHECK(
      false,
      op_name,
      ": You passed a dimname (string) to this op in place of a dimension "
      "index but it does not yet support this behavior. Please pass a dimension "
      "index to work around this.");
}
48
49 // [NOTE] Writing name inference rules
50 //
51 // Operators that support named tensors are either composed of operations that
52 // support named tensors or implement some name inference rule. An op that
53 // implements its own name inference rule generally looks like the following:
54 //
55 // Tensor op(...) {
56 // perform_shape_checks(...);
57 // # (1)
58 // auto maybe_outnames = compute_outnames(...);
59 // auto result = [&]() {
60 // NoNamesGuard guard;
61 // return op_impl(...);
62 // }();
63 // # (2)
64 // propagate_names_if_nonempty(result, maybe_outnames);
65 //
66 // Each op has (1) a compute outnames step and (2) a propagate names step.
67 //
68 // compute_outnames is responsible for checking that input names match and
69 // determining what the output names should be. It returns either:
// - {} (if the input tensors are all unnamed)
71 // - non-empty outnames.
72 //
73 // propagate_names_if_nonempty propagates the outnames if they exist to the
74 // result tensors.
75 //
76 // The {} case is an optimization; if the user does not use named tensors they
77 // pay no perf cost for it.
78
79 namespace namedinference {
80
// Propagates `maybe_names` to `result` if it is present (non-nullopt) and
// non-empty. Returns `result`.
// NOTE(review): unlike the neighboring overloads, this declaration is not
// marked TORCH_API — confirm whether it is intentionally not exported.
const Tensor& propagate_names_if_present_and_nonempty(
    const Tensor& result,
    std::optional<DimnameList> maybe_names,
    bool validate_names = false);
// Propagates `names` to `result` if `names` is not empty.
// `names` can be empty; see [NOTE] Writing name inference rules
// If `names` is not empty, `names.size()` should equal `result.dim()`.
// When in doubt, use this overload instead of the others.
TORCH_API const Tensor& propagate_names_if_nonempty(
    const Tensor& result,
    DimnameList maybe_names,
    bool validate_names = false);
93
// Propagates `names` to `result`. Only use this if we are certain that there
// are names to propagate (that names is not empty).
TORCH_API const Tensor& propagate_names(
    const Tensor& result,
    DimnameList names,
    bool validate_names = false);

// Propagates all names from src to result.
TORCH_API void propagate_names(const Tensor& result, const Tensor& src);

// Propagates all names except for those at the excluded_idxs.
TORCH_API void propagate_names_except(
    const Tensor& result,
    const Tensor& src,
    IntArrayRef excluded_idxs);

// Used for reduction ops that have a `keepdim` arg.
TORCH_API void propagate_names_for_reduction(
    const Tensor& result,
    const Tensor& src,
    IntArrayRef excluded_idxs,
    bool keepdim);

// Name inference rule for expand-like ops: propagates names from `self` to
// the expanded `result`.
TORCH_API void propagate_names_for_expand(
    const Tensor& result,
    const Tensor& self);
120
// Computes the output names for concatenating (cat) `tensors`.
TORCH_API std::vector<Dimname> compute_cat_outnames(
    const MaterializedITensorListRef& tensors);

// Computes the output names for broadcasting `self` with `other`.
TORCH_API std::vector<Dimname> compute_broadcast_outnames(
    const Tensor& self,
    const Tensor& other);

// Computes output names for broadcasting `tensor` to the shape of
// `reference_tensor`; `op_name` is used in error messages.
TORCH_API std::vector<Dimname> broadcast_to_outnames(
    const Tensor& tensor,
    const Tensor& reference_tensor,
    const char* op_name);

// Name inference rule for matmul.
TORCH_API std::vector<Dimname> compute_matmul_outnames(
    const Tensor& self,
    const Tensor& other);

// Name inference rule for cdist.
TORCH_API std::vector<Dimname> compute_cdist_outnames(
    const Tensor& self,
    const Tensor& other);

// Name inference rule for batched matrix multiply (bmm).
TORCH_API std::vector<Dimname> compute_bmm_outnames(
    const Tensor& result,
    const Tensor& self,
    const Tensor& other);

// Name inference rules for squeeze: all dims / only the dims set in `dims`.
TORCH_API std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor);
TORCH_API std::vector<Dimname> compute_squeeze_outnames(
    const Tensor& tensor,
    std::bitset<dim_bitset_size> dims);

// Name inference rule for diagonal over dims `dim1` and `dim2`.
// NOTE(review): not marked TORCH_API, unlike its neighbors — confirm whether
// this is intentionally not exported.
std::vector<Dimname> compute_diagonal_outnames(
    const Tensor& tensor,
    int64_t dim1,
    int64_t dim2);
155
// TensorImpl* overloads for Legacy TH/THC code. Use these sparingly.

// Same contract as the Tensor overload above: propagates `maybe_names` to
// `result` if non-empty. Returns `result`.
TORCH_API TensorImpl* propagate_names_if_nonempty(
    TensorImpl* result,
    DimnameList maybe_names,
    bool validate_names = false);

// Same contract as the Tensor overload above: `names` must be non-empty.
TORCH_API TensorImpl* propagate_names(
    TensorImpl* result,
    DimnameList names,
    bool validate_names = false);

// Propagates all names from `src` to `result`.
TORCH_API void propagate_names(TensorImpl* result, /*const */ TensorImpl* src);
169
170 TORCH_API inline void propagate_names(
171 const TensorBase& result,
172 DimnameList names,
173 bool validate_names = false) {
174 propagate_names(result.unsafeGetTensorImpl(), names, validate_names);
175 }
176
177 TORCH_API inline void propagate_names_if_nonempty(
178 const TensorBase& result,
179 DimnameList names,
180 bool validate_names = false) {
181 propagate_names_if_nonempty(
182 result.unsafeGetTensorImpl(), names, validate_names);
183 }
184
propagate_names(const TensorBase & result,const TensorBase & src)185 TORCH_API inline void propagate_names(
186 const TensorBase& result,
187 const TensorBase& src) {
188 propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl());
189 }
190
// Name inference rule for addmm: result = m1 @ m2 + bias
TORCH_API std::vector<Dimname> propagate_names_for_addmm(
    const Tensor& m1,
    const Tensor& m2,
    const Tensor& bias);

// Name inference rule for addmv (matrix-vector variant of addmm):
// result = mat @ vec + bias
TORCH_API std::vector<Dimname> propagate_names_for_addmv(
    const Tensor& mat,
    const Tensor& vec,
    const Tensor& bias);

// Checks that the names of `vec1` and `vec2` are compatible for dot product.
TORCH_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2);

// Name inference rule for baddbmm (batched bmm plus bias).
TORCH_API std::vector<Dimname> compute_baddbmm_outnames(
    const Tensor& result,
    const Tensor& self,
    const Tensor& other,
    const Tensor& bias);

// Returns true if `self` and `other` have equal dimension names.
TORCH_API bool are_names_equal(TensorImpl* self, TensorImpl* other);
211
212 } // namespace namedinference
213
214 } // namespace at
215