xref: /aosp_15_r20/external/pytorch/aten/src/ATen/NamedTensorUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/NamedTensor.h>
3 #include <ATen/TensorNames.h>
4 #include <ATen/WrapDimUtilsMulti.h>
5 
6 #include <ATen/core/DimVector.h>
7 #include <ATen/core/Tensor.h>
8 
9 namespace at {
10 
// Small-buffer vector of dimension names, sized to match DimVector's static
// inline capacity so that name lists for typical tensor ranks avoid heap
// allocation.
using NameVector = SmallVector<Dimname, kDimVectorStaticSize>;
12 
has_names(const ITensorListRef & tensors)13 inline bool has_names(const ITensorListRef& tensors) {
14   return std::any_of(tensors.begin(), tensors.end(), [](const Tensor& t) {
15     return t.has_names();
16   });
17 }
18 
// Converts dim to a positional index. Errors if `dim` cannot be used to
// refer to any dimension of tensor.
TORCH_API int64_t dimname_to_position(const Tensor& tensor, Dimname dim);
// Vectorized form of dimname_to_position: resolves each name in `dims`
// against `tensor`, returning the positional indices in the same order.
TORCH_API std::vector<int64_t> dimnames_to_positions(
    const Tensor& tensor,
    DimnameList dims);
25 
// Unifies two DimnameList to produce a third. This is useful for implementing
// the named inference rule for binary broadcasting operations like add.
//
// There are three main constraints:
// 1) Check matching: Names must match positionally from the right.
// 2) Check misaligned: If a name `n` is in `names`, then it must appear at
//    the same index from the right in other.
// 3) The output names are obtained by unifying the names individually from the
//    right.
//
// `action` describes the operation being performed (presumably used in the
// error message when unification fails — see the definition to confirm).
TORCH_API std::vector<Dimname> unify_from_right(
    DimnameList names,
    DimnameList other,
    const char* action = "broadcast");
39 
// Error-reporting helper for ops that accept an integer dim but have not yet
// implemented the Dimname (string) overload. Unconditionally throws via
// TORCH_CHECK(false, ...); marked [[noreturn]] so callers can use it as the
// entire body of the not-yet-implemented overload.
[[noreturn]] inline void reportNYIDimnameOverload(const char* op_name) {
  TORCH_CHECK(
      false,
      op_name,
      ": You passed a dimname (string) to this op in place of a dimension "
      "index but it does not yet support this behavior. Please pass a dimension "
      "index to work around this.");
}
48 
49 // [NOTE] Writing name inference rules
50 //
51 // Operators that support named tensors are either composed of operations that
52 // support named tensors or implement some name inference rule. An op that
53 // implements its own name inference rule generally looks like the following:
54 //
55 // Tensor op(...) {
56 //   perform_shape_checks(...);
57 //   # (1)
58 //   auto maybe_outnames = compute_outnames(...);
59 //   auto result = [&]() {
60 //     NoNamesGuard guard;
61 //     return op_impl(...);
62 //   }();
63 //   # (2)
64 //   propagate_names_if_nonempty(result, maybe_outnames);
65 //
66 // Each op has (1) a compute outnames step and (2) a propagate names step.
67 //
68 // compute_outnames is responsible for checking that input names match and
69 // determining what the output names should be. It returns either:
// - {} (if the input tensors are all unnamed)
71 // - non-empty outnames.
72 //
73 // propagate_names_if_nonempty propagates the outnames if they exist to the
74 // result tensors.
75 //
76 // The {} case is an optimization; if the user does not use named tensors they
77 // pay no perf cost for it.
78 
79 namespace namedinference {
80 
// Propagates `maybe_names` to `result` only when the optional is engaged and
// the contained list is nonempty. Returns `result` for chaining.
const Tensor& propagate_names_if_present_and_nonempty(
    const Tensor& result,
    std::optional<DimnameList> maybe_names,
    bool validate_names = false);
// Propagates `names` to `result` if `names` is not empty.
// `names` can be empty; see [NOTE] Writing name inference rules
// If `names` is not empty, `names.size()` should equal `result.dim()`.
// When in doubt, use this overload instead of the others.
TORCH_API const Tensor& propagate_names_if_nonempty(
    const Tensor& result,
    DimnameList maybe_names,
    bool validate_names = false);

// Propagates `names` to `result`. Only use this if we are certain that there
// are names to propagate (that names is not empty).
TORCH_API const Tensor& propagate_names(
    const Tensor& result,
    DimnameList names,
    bool validate_names = false);
100 
// Propagates all names from src to result.
TORCH_API void propagate_names(const Tensor& result, const Tensor& src);

// Propagates all names except for those at the excluded_idxs.
TORCH_API void propagate_names_except(
    const Tensor& result,
    const Tensor& src,
    IntArrayRef excluded_idxs);

// Used for reduction ops that have a `keepdim` arg. `excluded_idxs` are the
// reduced-over dimensions of `src`.
TORCH_API void propagate_names_for_reduction(
    const Tensor& result,
    const Tensor& src,
    IntArrayRef excluded_idxs,
    bool keepdim);

// Name propagation for expand: `result` may have more dims than `self`.
TORCH_API void propagate_names_for_expand(
    const Tensor& result,
    const Tensor& self);
120 
// Computes the output names for concatenating `tensors`.
TORCH_API std::vector<Dimname> compute_cat_outnames(
    const MaterializedITensorListRef& tensors);

// Output names for binary broadcasting ops between `self` and `other`.
TORCH_API std::vector<Dimname> compute_broadcast_outnames(
    const Tensor& self,
    const Tensor& other);

// Output names for broadcasting `tensor` to the shape of `reference_tensor`;
// `op_name` identifies the calling op (for error reporting).
TORCH_API std::vector<Dimname> broadcast_to_outnames(
    const Tensor& tensor,
    const Tensor& reference_tensor,
    const char* op_name);

// Output names for matmul(self, other).
TORCH_API std::vector<Dimname> compute_matmul_outnames(
    const Tensor& self,
    const Tensor& other);

// Output names for cdist(self, other).
TORCH_API std::vector<Dimname> compute_cdist_outnames(
    const Tensor& self,
    const Tensor& other);

// Output names for bmm(self, other) written into `result`.
TORCH_API std::vector<Dimname> compute_bmm_outnames(
    const Tensor& result,
    const Tensor& self,
    const Tensor& other);

// Output names for squeeze: all dims (first overload) or only the dims in
// the `dims` bitset (second overload).
TORCH_API std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor);
TORCH_API std::vector<Dimname> compute_squeeze_outnames(
    const Tensor& tensor,
    std::bitset<dim_bitset_size> dims);

// Output names for diagonal(tensor, dim1, dim2).
std::vector<Dimname> compute_diagonal_outnames(
    const Tensor& tensor,
    int64_t dim1,
    int64_t dim2);
155 
// TensorImpl* overloads for Legacy TH/THC code. Use these sparingly.

// Same contract as the Tensor& overload above; returns `result` for chaining.
TORCH_API TensorImpl* propagate_names_if_nonempty(
    TensorImpl* result,
    DimnameList maybe_names,
    bool validate_names = false);

// Same contract as the Tensor& overload above; `names` must be nonempty.
TORCH_API TensorImpl* propagate_names(
    TensorImpl* result,
    DimnameList names,
    bool validate_names = false);

// Propagates all names from `src` to `result` (src is logically const).
TORCH_API void propagate_names(TensorImpl* result, /*const */ TensorImpl* src);
169 
170 TORCH_API inline void propagate_names(
171     const TensorBase& result,
172     DimnameList names,
173     bool validate_names = false) {
174   propagate_names(result.unsafeGetTensorImpl(), names, validate_names);
175 }
176 
177 TORCH_API inline void propagate_names_if_nonempty(
178     const TensorBase& result,
179     DimnameList names,
180     bool validate_names = false) {
181   propagate_names_if_nonempty(
182       result.unsafeGetTensorImpl(), names, validate_names);
183 }
184 
propagate_names(const TensorBase & result,const TensorBase & src)185 TORCH_API inline void propagate_names(
186     const TensorBase& result,
187     const TensorBase& src) {
188   propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl());
189 }
190 
// Name inference for addmm: result = m1 @ m2 + bias.
TORCH_API std::vector<Dimname> propagate_names_for_addmm(
    const Tensor& m1,
    const Tensor& m2,
    const Tensor& bias);

// Name inference for addmv: result = mat @ vec + bias.
TORCH_API std::vector<Dimname> propagate_names_for_addmv(
    const Tensor& mat,
    const Tensor& vec,
    const Tensor& bias);

// Checks that the inputs to a dot product have compatible names; errors
// otherwise.
TORCH_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2);

// Name inference for baddbmm: result = bias + self @ other (batched).
TORCH_API std::vector<Dimname> compute_baddbmm_outnames(
    const Tensor& result,
    const Tensor& self,
    const Tensor& other,
    const Tensor& bias);

// Returns true if `self` and `other` have identical dimension names.
TORCH_API bool are_names_equal(TensorImpl* self, TensorImpl* other);
211 
212 } // namespace namedinference
213 
214 } // namespace at
215