1 #include <c10/core/Contiguity.h>
2 #include <c10/core/MemoryFormat.h>
3 #include <c10/core/SymInt.h>
4 #include <c10/core/SymIntArrayRef.h>
5 #include <c10/core/SymbolicShapeMeta.h>
6
7 namespace c10 {
8
9 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
SymbolicShapeMeta(const SymbolicShapeMeta & other)10 SymbolicShapeMeta::SymbolicShapeMeta(const SymbolicShapeMeta& other)
11 // Non-mutables can be accessed outside the mutex
12 : sizes_(other.sizes_),
13 strides_(other.strides_),
14 storage_offset_(other.storage_offset_),
15 strides_valid_(other.strides_valid_) {
16 std::scoped_lock lock(other.mutables_);
17 // These must be copied under lock, so ignore clang-tidy here!
18 // NOLINTBEGIN(cppcoreguidelines-prefer-member-initializer)
19 numel_ = other.numel_;
20 is_contiguous_ = other.is_contiguous_;
21 is_channels_last_contiguous_ = other.is_channels_last_contiguous_;
22 is_channels_last_3d_contiguous_ = other.is_channels_last_3d_contiguous_;
23 is_channels_last_ = other.is_channels_last_;
24 is_channels_last_3d_ = other.is_channels_last_3d_;
25 is_non_overlapping_and_dense_ = other.is_non_overlapping_and_dense_;
26 available_.store(other.available_.load());
27 // NOLINTEND(cppcoreguidelines-prefer-member-initializer)
28 }
29
30 // base, sizes, strides
31 static std::optional<
32 std::tuple<SymNode, std::vector<SymNode>, std::vector<SymNode>>>
normalize_sym_sizes_strides(SymIntArrayRef sizes,SymIntArrayRef strides)33 normalize_sym_sizes_strides(SymIntArrayRef sizes, SymIntArrayRef strides) {
34 // Look for a SymNode to dispatch on
35 SymNode base;
36 bool all_hinted = true;
37 // NB: sizes/strides guaranteed to be positive, so only need
38 // is_heap_allocated
39 for (const auto& s : sizes) {
40 if (all_hinted && !s.has_hint()) {
41 all_hinted = false;
42 }
43 if (!base && s.is_heap_allocated()) {
44 base = s.toSymNode();
45 }
46 }
47 for (const auto& s : strides) {
48 if (all_hinted && !s.has_hint()) {
49 all_hinted = false;
50 }
51 if (!base && s.is_heap_allocated()) {
52 base = s.toSymNode();
53 }
54 }
55 if (!base || all_hinted) {
56 // Couldn't find. Tell the caller to do the normal computation
57 // Alternately, if everything is hinted, we want the normal computation
58 // too
59 return std::nullopt;
60 }
61 // Populate the SymNode array
62 std::vector<SymNode> size_nodes;
63 std::vector<SymNode> stride_nodes;
64 size_nodes.reserve(sizes.size());
65 stride_nodes.reserve(strides.size());
66 for (const auto& s : sizes) {
67 size_nodes.emplace_back(s.wrap_node(base));
68 }
69 for (const auto& s : strides) {
70 stride_nodes.emplace_back(s.wrap_node(base));
71 }
72 return std::make_optional(
73 std::tuple<SymNode, std::vector<SymNode>, std::vector<SymNode>>(
74 std::move(base), std::move(size_nodes), std::move(stride_nodes)));
75 }
76
77 // Special treatment because of numel
compute_contiguous() const78 SymBool SymbolicShapeMeta::compute_contiguous() const {
79 if (!strides_valid_) {
80 return false;
81 }
82 c10::SymIntArrayRef sizes(sizes_);
83 c10::SymIntArrayRef strides(strides_);
84 return _compute_contiguous(sizes, strides, numel());
85 }
86
87 // The rest of them
88 #define DEFINE_EAGER_SYMBOOL_COMPUTE(name, nodeimpl, fallback) \
89 SymBool SymbolicShapeMeta::name() const { \
90 if (!strides_valid_) { \
91 return false; \
92 } \
93 c10::SymIntArrayRef sizes(sizes_); \
94 c10::SymIntArrayRef strides(strides_); \
95 return fallback(sizes, strides); \
96 }
97
98 #define DEFINE_SYMBOOL_COMPUTE(name, nodeimpl, fallback) \
99 SymBool SymbolicShapeMeta::name() const { \
100 if (!strides_valid_) { \
101 return false; \
102 } \
103 auto n = normalize_sym_sizes_strides(sizes_, strides_); \
104 if (n.has_value()) { \
105 auto [base, size_nodes, stride_nodes] = *n; \
106 return SymBool(base->nodeimpl(size_nodes, stride_nodes)); \
107 } else { \
108 c10::SymIntArrayRef sizes(sizes_); \
109 c10::SymIntArrayRef strides(strides_); \
110 return fallback(sizes, strides); \
111 } \
112 }
113
114 // clang-format off
DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d,is_channels_last_contiguous_2d,_compute_channels_last_contiguous_2d)115 DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_2d, is_channels_last_contiguous_2d, _compute_channels_last_contiguous_2d)
116 DEFINE_EAGER_SYMBOOL_COMPUTE(compute_channels_last_contiguous_3d, is_channels_last_contiguous_3d, _compute_channels_last_contiguous_3d)
117 DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_2d, is_channels_last_strides_2d, is_channels_last_strides_2d)
118 DEFINE_EAGER_SYMBOOL_COMPUTE(compute_strides_like_channels_last_3d, is_channels_last_strides_3d, is_channels_last_strides_3d)
119 DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and_dense, _compute_non_overlapping_and_dense)
120 // clang-format on
121
122 #undef DEFINE_SYMBOOL_COMPUTE
123
// Glue compute
// NB: this logic very intentionally short-circuits whenever possible. Without
// short-circuiting,
//   python test/functorch/test_aotdispatch.py -k
//   test_aot_autograd_symbolic_exhaustive_nn_functional_unfold_cpu_float32
// runs very slowly.
130
131 SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const {
132 init_is_contiguous();
133 if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
134 return true;
135 }
136 init_is_channels_last_contiguous();
137 if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
138 return true;
139 }
140 return is_contiguous() | is_channels_last_contiguous() |
141 compute_non_overlapping_and_dense();
142 }
143
compute_channels_last_contiguous_3d_dim5() const144 SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const {
145 init_is_channels_last_contiguous();
146 if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
147 return false;
148 }
149 return ~is_channels_last_contiguous() & compute_channels_last_contiguous_3d();
150 }
151
compute_channels_last_2d_dim5() const152 SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const {
153 init_is_channels_last_3d_contiguous();
154 if (definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
155 return false;
156 }
157 return ~is_channels_last_3d_contiguous() &
158 compute_strides_like_channels_last_2d();
159 }
160
compute_channels_last_3d_dim5() const161 SymBool SymbolicShapeMeta::compute_channels_last_3d_dim5() const {
162 if (definitely_true(is_channels_last(), __FILE__, __LINE__)) {
163 return false;
164 }
165 return ~is_channels_last() & compute_strides_like_channels_last_3d();
166 }
167
compute_is_non_overlapping_and_dense_dim5() const168 SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const {
169 if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
170 return true;
171 }
172 if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
173 return true;
174 }
175 if (definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
176 return true;
177 }
178 return is_contiguous() | is_channels_last_contiguous() |
179 is_channels_last_3d_contiguous() | compute_non_overlapping_and_dense();
180 }
181
compute_is_non_overlapping_and_dense_anydim() const182 SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_anydim() const {
183 if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
184 return true;
185 }
186 return is_contiguous() | compute_non_overlapping_and_dense();
187 }
188
189 // NOLINTNEXTLINE(performance-unnecessary-value-param)
set_numel(SymInt val) const190 void SymbolicShapeMeta::set_numel(SymInt val) const {
191 std::scoped_lock lock(mutables_);
192 if (has_numel()) {
193 return;
194 }
195 numel_ = std::move(val);
196 available_.fetch_or(numel_avail);
197 }
set_is_contiguous(SymBool val) const198 void SymbolicShapeMeta::set_is_contiguous(SymBool val) const {
199 std::scoped_lock lock(mutables_);
200 if (has_is_contiguous()) {
201 return;
202 }
203 is_contiguous_ = std::move(val);
204 available_.fetch_or(is_contiguous_avail);
205 }
set_is_channels_last_contiguous(SymBool val) const206 void SymbolicShapeMeta::set_is_channels_last_contiguous(SymBool val) const {
207 std::scoped_lock lock(mutables_);
208 if (has_is_channels_last_contiguous()) {
209 return;
210 }
211 is_channels_last_contiguous_ = std::move(val);
212 available_.fetch_or(is_channels_last_contiguous_avail);
213 }
set_is_channels_last_3d_contiguous(SymBool val) const214 void SymbolicShapeMeta::set_is_channels_last_3d_contiguous(SymBool val) const {
215 std::scoped_lock lock(mutables_);
216 if (has_is_channels_last_3d_contiguous()) {
217 return;
218 }
219 is_channels_last_3d_contiguous_ = std::move(val);
220 available_.fetch_or(is_channels_last_3d_contiguous_avail);
221 }
set_is_channels_last(SymBool val) const222 void SymbolicShapeMeta::set_is_channels_last(SymBool val) const {
223 std::scoped_lock lock(mutables_);
224 if (has_is_channels_last()) {
225 return;
226 }
227 is_channels_last_ = std::move(val);
228 available_.fetch_or(is_channels_last_avail);
229 }
set_is_channels_last_3d(SymBool val) const230 void SymbolicShapeMeta::set_is_channels_last_3d(SymBool val) const {
231 std::scoped_lock lock(mutables_);
232 if (has_is_channels_last_3d()) {
233 return;
234 }
235 is_channels_last_3d_ = std::move(val);
236 available_.fetch_or(is_channels_last_3d_avail);
237 }
238
set_is_non_overlapping_and_dense(SymBool val) const239 void SymbolicShapeMeta::set_is_non_overlapping_and_dense(SymBool val) const {
240 std::scoped_lock lock(mutables_);
241 if (has_is_non_overlapping_and_dense()) {
242 return;
243 }
244 is_non_overlapping_and_dense_ = std::move(val);
245 available_.fetch_or(is_non_overlapping_and_dense_avail);
246 }
247
init_numel() const248 void SymbolicShapeMeta::init_numel() const {
249 set_numel(multiply_integers(sizes_));
250 }
251
init_is_contiguous() const252 void SymbolicShapeMeta::init_is_contiguous() const {
253 set_is_contiguous(compute_contiguous());
254 }
255
init_is_channels_last_contiguous() const256 void SymbolicShapeMeta::init_is_channels_last_contiguous() const {
257 set_is_channels_last_contiguous([&] {
258 switch (dim()) {
259 case 5:
260 case 4: {
261 return compute_channels_last_contiguous_2d();
262 }
263 default:
264 return SymBool{false};
265 }
266 }());
267 }
268
init_is_channels_last_3d_contiguous() const269 void SymbolicShapeMeta::init_is_channels_last_3d_contiguous() const {
270 set_is_channels_last_3d_contiguous([&] {
271 switch (dim()) {
272 case 5:
273 return compute_channels_last_contiguous_3d_dim5();
274 default:
275 return SymBool{false};
276 }
277 }());
278 }
279
init_is_channels_last() const280 void SymbolicShapeMeta::init_is_channels_last() const {
281 set_is_channels_last([&] {
282 switch (dim()) {
283 case 5:
284 return compute_channels_last_2d_dim5();
285 case 4:
286 return compute_strides_like_channels_last_2d();
287 default:
288 return SymBool{false};
289 }
290 }());
291 }
292
init_is_channels_last_3d() const293 void SymbolicShapeMeta::init_is_channels_last_3d() const {
294 set_is_channels_last_3d([&] {
295 switch (dim()) {
296 case 5:
297 return compute_channels_last_3d_dim5();
298 default:
299 return SymBool{false};
300 }
301 }());
302 }
303
init_is_non_overlapping_and_dense() const304 void SymbolicShapeMeta::init_is_non_overlapping_and_dense() const {
305 set_is_non_overlapping_and_dense([&] {
306 switch (dim()) {
307 case 5:
308 return compute_is_non_overlapping_and_dense_dim5();
309 case 4:
310 return compute_is_non_overlapping_and_dense_dim4();
311 default:
312 return compute_is_non_overlapping_and_dense_anydim();
313 }
314 }());
315 }
316
317 } // namespace c10
318