xref: /aosp_15_r20/external/pytorch/c10/core/SymbolicShapeMeta.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/Contiguity.h>
2 #include <c10/core/MemoryFormat.h>
3 #include <c10/core/SymInt.h>
4 #include <c10/core/SymIntArrayRef.h>
5 #include <c10/core/SymbolicShapeMeta.h>
6 
7 namespace c10 {
8 
9 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
SymbolicShapeMeta(const SymbolicShapeMeta & other)10 SymbolicShapeMeta::SymbolicShapeMeta(const SymbolicShapeMeta& other)
11     // Non-mutables can be accessed outside the mutex
12     : sizes_(other.sizes_),
13       strides_(other.strides_),
14       storage_offset_(other.storage_offset_),
15       strides_valid_(other.strides_valid_) {
16   std::scoped_lock lock(other.mutables_);
17   // These must be copied under lock, so ignore clang-tidy here!
18   // NOLINTBEGIN(cppcoreguidelines-prefer-member-initializer)
19   numel_ = other.numel_;
20   is_contiguous_ = other.is_contiguous_;
21   is_channels_last_contiguous_ = other.is_channels_last_contiguous_;
22   is_channels_last_3d_contiguous_ = other.is_channels_last_3d_contiguous_;
23   is_channels_last_ = other.is_channels_last_;
24   is_channels_last_3d_ = other.is_channels_last_3d_;
25   is_non_overlapping_and_dense_ = other.is_non_overlapping_and_dense_;
26   available_.store(other.available_.load());
27   // NOLINTEND(cppcoreguidelines-prefer-member-initializer)
28 }
29 
30 // base, sizes, strides
31 static std::optional<
32     std::tuple<SymNode, std::vector<SymNode>, std::vector<SymNode>>>
normalize_sym_sizes_strides(SymIntArrayRef sizes,SymIntArrayRef strides)33 normalize_sym_sizes_strides(SymIntArrayRef sizes, SymIntArrayRef strides) {
34   // Look for a SymNode to dispatch on
35   SymNode base;
36   bool all_hinted = true;
37   // NB: sizes/strides guaranteed to be positive, so only need
38   // is_heap_allocated
39   for (const auto& s : sizes) {
40     if (all_hinted && !s.has_hint()) {
41       all_hinted = false;
42     }
43     if (!base && s.is_heap_allocated()) {
44       base = s.toSymNode();
45     }
46   }
47   for (const auto& s : strides) {
48     if (all_hinted && !s.has_hint()) {
49       all_hinted = false;
50     }
51     if (!base && s.is_heap_allocated()) {
52       base = s.toSymNode();
53     }
54   }
55   if (!base || all_hinted) {
56     // Couldn't find.  Tell the caller to do the normal computation
57     // Alternately, if everything is hinted, we want the normal computation
58     // too
59     return std::nullopt;
60   }
61   // Populate the SymNode array
62   std::vector<SymNode> size_nodes;
63   std::vector<SymNode> stride_nodes;
64   size_nodes.reserve(sizes.size());
65   stride_nodes.reserve(strides.size());
66   for (const auto& s : sizes) {
67     size_nodes.emplace_back(s.wrap_node(base));
68   }
69   for (const auto& s : strides) {
70     stride_nodes.emplace_back(s.wrap_node(base));
71   }
72   return std::make_optional(
73       std::tuple<SymNode, std::vector<SymNode>, std::vector<SymNode>>(
74           std::move(base), std::move(size_nodes), std::move(stride_nodes)));
75 }
76 
77 // Special treatment because of numel
compute_contiguous() const78 SymBool SymbolicShapeMeta::compute_contiguous() const {
79   if (!strides_valid_) {
80     return false;
81   }
82   c10::SymIntArrayRef sizes(sizes_);
83   c10::SymIntArrayRef strides(strides_);
84   return _compute_contiguous(sizes, strides, numel());
85 }
86 
// The remaining predicates are generated by macros.
//
// DEFINE_EAGER_SYMBOOL_COMPUTE defines `name()`, which reports false when
// strides_valid_ is not set. Otherwise it evaluates `fallback` directly on
// the SymInt sizes/strides. The `nodeimpl` argument is not used here;
// presumably it is kept so this macro takes the same parameter list as
// DEFINE_SYMBOOL_COMPUTE.
#define DEFINE_EAGER_SYMBOOL_COMPUTE(name, nodeimpl, fallback) \
  SymBool SymbolicShapeMeta::name() const {                    \
    if (strides_valid_) {                                      \
      c10::SymIntArrayRef size_ref(sizes_);                    \
      c10::SymIntArrayRef stride_ref(strides_);                \
      return fallback(size_ref, stride_ref);                   \
    }                                                          \
    return false;                                              \
  }
97 
// DEFINE_SYMBOOL_COMPUTE defines `name()`, which reports false when
// strides_valid_ is not set. Otherwise:
// - If normalize_sym_sizes_strides produces SymNode-wrapped sizes/strides,
//   it dispatches to `base->nodeimpl(size_nodes, stride_nodes)`.
// - If normalization declines, it evaluates `fallback` on the SymInt
//   sizes/strides directly.
// NB: comments inside the macro body must be block comments; a `//` comment
// would swallow the trailing line-continuation backslash.
#define DEFINE_SYMBOOL_COMPUTE(name, nodeimpl, fallback)                  \
  SymBool SymbolicShapeMeta::name() const {                               \
    if (!strides_valid_) {                                                \
      return false;                                                       \
    }                                                                     \
    auto n = normalize_sym_sizes_strides(sizes_, strides_);               \
    if (n.has_value()) {                                                  \
      /* Bind by reference: a by-value structured binding would copy */   \
      /* the whole tuple, including both std::vector<SymNode> members. */ \
      auto& [base, size_nodes, stride_nodes] = *n;                        \
      return SymBool(base->nodeimpl(size_nodes, stride_nodes));           \
    }                                                                     \
    c10::SymIntArrayRef sizes(sizes_);                                    \
    c10::SymIntArrayRef strides(strides_);                                \
    return fallback(sizes, strides);                                      \
  }
113 
114 // clang-format off
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d,is_channels_last_contiguous_2d,_compute_channels_last_contiguous_2d)115 DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d, is_channels_last_contiguous_2d, _compute_channels_last_contiguous_2d)
116 DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_3d, is_channels_last_contiguous_3d, _compute_channels_last_contiguous_3d)
117 DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d, is_channels_last_strides_2d)
118 DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d, is_channels_last_strides_3d)
119 DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and_dense, _compute_non_overlapping_and_dense)
120 // clang-format on
121 
122 #undef DEFINE_SYMBOOL_COMPUTE
123 
// Glue compute
// NB: the functions below intentionally short-circuit whenever possible.
// Without short-circuiting, the test
//   python test/functorch/test_aotdispatch.py -k
//   test_aot_autograd_symbolic_exhaustive_nn_functional_unfold_cpu_float32
// runs very slowly.
130 
131 SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const {
132   init_is_contiguous();
133   if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
134     return true;
135   }
136   init_is_channels_last_contiguous();
137   if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
138     return true;
139   }
140   return is_contiguous() | is_channels_last_contiguous() |
141       compute_non_overlapping_and_dense();
142 }
143 
compute_channels_last_contiguous_3d_dim5() const144 SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const {
145   init_is_channels_last_contiguous();
146   if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
147     return false;
148   }
149   return ~is_channels_last_contiguous() & compute_channels_last_contiguous_3d();
150 }
151 
compute_channels_last_2d_dim5() const152 SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const {
153   init_is_channels_last_3d_contiguous();
154   if (definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
155     return false;
156   }
157   return ~is_channels_last_3d_contiguous() &
158       compute_strides_like_channels_last_2d();
159 }
160 
compute_channels_last_3d_dim5() const161 SymBool SymbolicShapeMeta::compute_channels_last_3d_dim5() const {
162   if (definitely_true(is_channels_last(), __FILE__, __LINE__)) {
163     return false;
164   }
165   return ~is_channels_last() & compute_strides_like_channels_last_3d();
166 }
167 
compute_is_non_overlapping_and_dense_dim5() const168 SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const {
169   if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
170     return true;
171   }
172   if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
173     return true;
174   }
175   if (definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
176     return true;
177   }
178   return is_contiguous() | is_channels_last_contiguous() |
179       is_channels_last_3d_contiguous() | compute_non_overlapping_and_dense();
180 }
181 
compute_is_non_overlapping_and_dense_anydim() const182 SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_anydim() const {
183   if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
184     return true;
185   }
186   return is_contiguous() | compute_non_overlapping_and_dense();
187 }
188 
189 // NOLINTNEXTLINE(performance-unnecessary-value-param)
set_numel(SymInt val) const190 void SymbolicShapeMeta::set_numel(SymInt val) const {
191   std::scoped_lock lock(mutables_);
192   if (has_numel()) {
193     return;
194   }
195   numel_ = std::move(val);
196   available_.fetch_or(numel_avail);
197 }
set_is_contiguous(SymBool val) const198 void SymbolicShapeMeta::set_is_contiguous(SymBool val) const {
199   std::scoped_lock lock(mutables_);
200   if (has_is_contiguous()) {
201     return;
202   }
203   is_contiguous_ = std::move(val);
204   available_.fetch_or(is_contiguous_avail);
205 }
set_is_channels_last_contiguous(SymBool val) const206 void SymbolicShapeMeta::set_is_channels_last_contiguous(SymBool val) const {
207   std::scoped_lock lock(mutables_);
208   if (has_is_channels_last_contiguous()) {
209     return;
210   }
211   is_channels_last_contiguous_ = std::move(val);
212   available_.fetch_or(is_channels_last_contiguous_avail);
213 }
set_is_channels_last_3d_contiguous(SymBool val) const214 void SymbolicShapeMeta::set_is_channels_last_3d_contiguous(SymBool val) const {
215   std::scoped_lock lock(mutables_);
216   if (has_is_channels_last_3d_contiguous()) {
217     return;
218   }
219   is_channels_last_3d_contiguous_ = std::move(val);
220   available_.fetch_or(is_channels_last_3d_contiguous_avail);
221 }
set_is_channels_last(SymBool val) const222 void SymbolicShapeMeta::set_is_channels_last(SymBool val) const {
223   std::scoped_lock lock(mutables_);
224   if (has_is_channels_last()) {
225     return;
226   }
227   is_channels_last_ = std::move(val);
228   available_.fetch_or(is_channels_last_avail);
229 }
set_is_channels_last_3d(SymBool val) const230 void SymbolicShapeMeta::set_is_channels_last_3d(SymBool val) const {
231   std::scoped_lock lock(mutables_);
232   if (has_is_channels_last_3d()) {
233     return;
234   }
235   is_channels_last_3d_ = std::move(val);
236   available_.fetch_or(is_channels_last_3d_avail);
237 }
238 
set_is_non_overlapping_and_dense(SymBool val) const239 void SymbolicShapeMeta::set_is_non_overlapping_and_dense(SymBool val) const {
240   std::scoped_lock lock(mutables_);
241   if (has_is_non_overlapping_and_dense()) {
242     return;
243   }
244   is_non_overlapping_and_dense_ = std::move(val);
245   available_.fetch_or(is_non_overlapping_and_dense_avail);
246 }
247 
init_numel() const248 void SymbolicShapeMeta::init_numel() const {
249   set_numel(multiply_integers(sizes_));
250 }
251 
init_is_contiguous() const252 void SymbolicShapeMeta::init_is_contiguous() const {
253   set_is_contiguous(compute_contiguous());
254 }
255 
init_is_channels_last_contiguous() const256 void SymbolicShapeMeta::init_is_channels_last_contiguous() const {
257   set_is_channels_last_contiguous([&] {
258     switch (dim()) {
259       case 5:
260       case 4: {
261         return compute_channels_last_contiguous_2d();
262       }
263       default:
264         return SymBool{false};
265     }
266   }());
267 }
268 
init_is_channels_last_3d_contiguous() const269 void SymbolicShapeMeta::init_is_channels_last_3d_contiguous() const {
270   set_is_channels_last_3d_contiguous([&] {
271     switch (dim()) {
272       case 5:
273         return compute_channels_last_contiguous_3d_dim5();
274       default:
275         return SymBool{false};
276     }
277   }());
278 }
279 
init_is_channels_last() const280 void SymbolicShapeMeta::init_is_channels_last() const {
281   set_is_channels_last([&] {
282     switch (dim()) {
283       case 5:
284         return compute_channels_last_2d_dim5();
285       case 4:
286         return compute_strides_like_channels_last_2d();
287       default:
288         return SymBool{false};
289     }
290   }());
291 }
292 
init_is_channels_last_3d() const293 void SymbolicShapeMeta::init_is_channels_last_3d() const {
294   set_is_channels_last_3d([&] {
295     switch (dim()) {
296       case 5:
297         return compute_channels_last_3d_dim5();
298       default:
299         return SymBool{false};
300     }
301   }());
302 }
303 
init_is_non_overlapping_and_dense() const304 void SymbolicShapeMeta::init_is_non_overlapping_and_dense() const {
305   set_is_non_overlapping_and_dense([&] {
306     switch (dim()) {
307       case 5:
308         return compute_is_non_overlapping_and_dense_dim5();
309       case 4:
310         return compute_is_non_overlapping_and_dense_dim4();
311       default:
312         return compute_is_non_overlapping_and_dense_anydim();
313     }
314   }());
315 }
316 
317 } // namespace c10
318