#pragma once
#include <c10/core/DispatchKey.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
#include <c10/util/llvmMathExtras.h>
#include <array>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iterator>
#include <ostream>
#include <string>
#include <type_traits>

namespace c10 {

struct FunctionalityOffsetAndMask {
  // empty constructor shouldn't be used; only needed to initialize
  // the array before populating it.
  FunctionalityOffsetAndMask() = default;
  FunctionalityOffsetAndMask(uint16_t offset, uint16_t mask)
      : offset(offset), mask(mask) {}
  // This needs to be big enough to cover the size of the operator table.
  uint16_t offset{};
  // See Note [No More Than 16 Backends]
  // This mask needs to be big enough to mask all of the backend bits.
  // We probably don't ever want to have more than 16 backend bits, so uint16_t
  // should be enough.
  uint16_t mask{};
};
static_assert(
    c10::num_runtime_entries < 65536,
    "The dispatcher currently only supports up to 2^16 runtime entries");

C10_API std::array<FunctionalityOffsetAndMask, num_functionality_keys>
initializeFunctionalityOffsetsAndMasks();

C10_ALWAYS_INLINE static const std::
    array<FunctionalityOffsetAndMask, num_functionality_keys>&
    offsetsAndMasks() {
  static auto offsets_and_masks_ = initializeFunctionalityOffsetsAndMasks();
  return offsets_and_masks_;
}

// A representation of a set of DispatchKeys. A DispatchKeySet contains both
// "functionality" bits and "backend" bits, and every tensor holds its own
// DispatchKeySet. The Dispatcher implements multiple dispatch by grabbing the
// keyset on every input tensor, or'ing them together, and dispatching to a
// specific piece of functionality. The functionality bits are *ordered*. When
// multiple functionality bits are set, we use the highest priority
// functionality. Similarly, multiple backend bits can theoretically be set if
// you call an operator with multiple tensors from different devices (e.g. CPU
// and CUDA), although support for mixed device dispatch is limited (the only
// kernels that gracefully handle mixed device inputs for now are CUDA kernels
// that take in a scalar CPU tensor).
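//
// An illustrative sketch of that flow (real key names; the real computation
// also folds in the TLS include/exclude sets):
//
//   DispatchKeySet a(DispatchKey::CPU);         // Dense bit + CPU backend bit
//   DispatchKeySet b(DispatchKey::AutogradCPU); // AutogradFunctionality bit
//                                               // + CPU backend bit
//   DispatchKeySet merged = a | b;
//   // Autograd is a higher-priority functionality than Dense, so dispatch
//   // goes to the autograd kernel first:
//   merged.highestPriorityTypeId(); // DispatchKey::AutogradCPU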

// A representation of a set of DispatchKeys. A tensor may have multiple
// tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the
// DispatchKeySet specifies what type ids apply. The internal representation is
// as a 64-bit bit set (this means only 64 tensor type ids are supported).
//
// As mentioned above, DispatchKeys are ordered; thus, we can ask questions like
// "what is the highest priority DispatchKey in the set"? (The set itself is
// not ordered; two sets with the same ids will always have the ids ordered in
// the same way.)
//
// Note [DispatchKeySet Internal Representation]
// Internally, dispatch keys are packed into 64-bit DispatchKeySet objects
// that get passed around at runtime.
// However, there isn't necessarily a 1-to-1 mapping between bits in the keyset
// and individual dispatch keys.
//
// First: why do we have this distinction, and why not map every dispatch key
// directly to a bit? This is mostly because we have several types of
// functionalities that different backends would like to customize. For example,
// we have:
// - "Dense": CPU, CUDA, XLA, ... (~12 keys)
// - "Sparse": SparseCPU, SparseCUDA, ...
// - "SparseCsr": SparseCsrCPU, SparseCsrCUDA, ...
// - "Quantized": QuantizedCPU, QuantizedCUDA, QuantizedXLA, ...
// - "Autograd": AutogradCPU, AutogradCUDA, AutogradXLA, ...
// The problem is that the total number of keys grows quadratically with
// [# backends] x [# functionalities], making it very difficult to map each key
// directly to a bit in a bitset without dramatically increasing the size of
// the bitset over time.
//
// The two enums (BackendComponent and DispatchKey) can be divided roughly into
// 5 categories.
//
// (1) "Building block" keys
//    (a) backends: Everything in the BackendComponent enum (e.g. CPUBit,
//        CUDABit)
//    (b) functionalities: (per-backend) functionality-bit DispatchKeys
//        (e.g. AutogradFunctionality, SparseCsr, Sparse, Dense)
// (2) "Runtime" keys
//    (a) "non-customizable backends" (e.g. FPGA)
//    (b) "non-customizable functionalities" (e.g. Functionalize)
//    (c) "per-backend instances of customizable functionalities" (e.g. CPU,
//        SparseCPU, AutogradCPU)
// (3) "Alias" DispatchKeys (see Note [Alias Dispatch Keys])
//
// (1) Building block keys always correspond to individual bits in a
// DispatchKeySet. They can also be combined in a DispatchKeySet to form actual
// runtime keys. e.g.
//     auto dense_cpu_ks = DispatchKeySet({DispatchKey::CPUBit,
//     DispatchKey::Dense});
//     // The keyset has the runtime dense-cpu key.
//     dense_cpu_ks.has(DispatchKey::CPU);
//     // And it contains the building block keys too.
//     dense_cpu_ks.has(DispatchKey::CPUBit);
//     dense_cpu_ks.has(DispatchKey::Dense);
//
// Not every backend and not every functionality counts as a "building block
// key". This is mostly to give us more levers to pull in the design space.
// Backend keys and functionality keys that count as "building blocks" will
// contribute to a full cross product of functionality that can be overridden.
//
// For example, right now we have at least 12 "backend" building
// blocks (CPU, CUDA, XLA, ...) and at least 5 "functionality"
// building blocks (Dense, Sparse, SparseCsr, Quantized,
// AutogradFunctionality, ...). These keys together allow every
// dispatcher operator to be customized in up to 12*5 different
// ways. Each of those requires a slot in the operator table of every
// dispatcher operator. Not every piece of functionality necessarily
// needs to be customizable per-backend, and not every backend
// necessarily needs to be able to customize every type of
// functionality.
//
//
// (2) Every runtime key corresponds directly to a slot in an operator's runtime
// dispatch table, and you can directly register kernels to a runtime dispatch
// key.
//
// For per-backend functionalities like "Dense" or "AutogradFunctionality",
// you can think of the corresponding runtime dispatch keys as "instances" of
// that functionality, per backend. E.g. "CPU", "CUDA", "XLA", etc. are all
// runtime instances of the "Dense" building block key.
//
// (2a) and (2b) are represented identically in the DispatchKeySet logic:
// - backend-agnostic functionalities (e.g. FuncTorchBatched) are NOT
//   customizable per backend.
//   In order to do so, we'd need to promote it to a per-backend functionality
//   "building block" key.
// - non-customizable backends (e.g. FPGA) can NOT customize existing
//   functionality like Sparse, Autograd, etc.
//   In order to do so, we'd need to promote it to a backend "building block"
//   key.
//
// In both cases, these keys directly correspond to runtime slots in the
// operator table.
//
//
// (3) "Alias" keys
// See Note [Alias Dispatch Keys]
//
// Final note: for anyone making future changes to the Dispatcher +
// DispatchKeySet internals, there's a closed PR with a basic
// python-implementation of the Dispatcher that might be useful in quickly
// testing out and validating changes. See it at
// https://github.com/pytorch/pytorch/pull/68743

// An undefined tensor is one with an empty tensor type set.
class DispatchKeySet final {
 public:
  enum Full { FULL };
  enum FullAfter { FULL_AFTER };
  enum Raw { RAW };

  // NB: default constructor representation as zero is MANDATORY as
  // use of DispatchKeySet in TLS requires this.
  constexpr DispatchKeySet() = default;

  constexpr DispatchKeySet(Full)
      : repr_((1ULL << (num_backends + num_functionality_keys - 1)) - 1) {}

  constexpr DispatchKeySet(FullAfter, DispatchKey t)
      // LSBs after t are OK, but not t itself.
      // "functionalities" have a notion of ordering (e.g. Autograd > Sparse >
      // Quantized > Dense). But backends don't really have an ordering.
      // Therefore, we're enforcing that FullAfter can only be used on
      // "functionality" keys.
      : repr_(
            (1ULL
             << (num_backends + static_cast<uint8_t>(toFunctionalityKey(t)) -
                 1)) -
            1) {
    *this = add(DispatchKey::PythonDispatcher);
  }
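
  // Illustrative usage (mirrors the after_* keysets defined near the bottom
  // of this file): the resulting set holds every backend bit, every
  // functionality bit strictly below t, and (per the constructor above)
  // PythonDispatcher, so it can mask out t and everything higher-priority
  // when redispatching.
  //
  //   constexpr DispatchKeySet after_autograd(
  //       DispatchKeySet::FULL_AFTER, DispatchKey::AutogradOther);
  //   // after_autograd.has(DispatchKey::AutogradOther) -> false
  //   // after_autograd.has(DispatchKey::CPU)           -> true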

  // Public version of DispatchKeySet(uint64_t) API; external users
  // must be explicit when they do this!
  constexpr DispatchKeySet(Raw, uint64_t x) : repr_(x) {}

  constexpr explicit DispatchKeySet(BackendComponent k) {
    if (k == BackendComponent::InvalidBit) {
      repr_ = 0;
    } else {
      repr_ = 1ULL << (static_cast<uint8_t>(k) - 1);
    }
  }

  constexpr explicit DispatchKeySet(DispatchKey k) {
    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (k == DispatchKey::Undefined) {
      // Case 1: handle Undefined specifically
      repr_ = 0;
    } else if (k <= DispatchKey::EndOfFunctionalityKeys) {
      // Case 2: handle "functionality-only" keys
      // These keys have a functionality bit set, but no backend bits.
      // These can technically be either:
      // - valid runtime keys (e.g. DispatchKey::AutogradOther,
      //   DispatchKey::FuncTorchBatched, etc)
      // - "building block" keys that aren't actual runtime keys (e.g.
      //   DispatchKey::Dense or Sparse)
      uint64_t functionality_val = 1ULL
          << (num_backends + static_cast<uint8_t>(k) - 1);
      repr_ = functionality_val;
    } else if (k <= DispatchKey::EndOfRuntimeBackendKeys) {
      // Case 3: "runtime" keys that have a functionality bit AND a backend bit.
      // First compute which bit to flip for the functionality.
      auto functionality_k = toFunctionalityKey(k);
      // The - 1 is because Undefined is technically a "functionality" that
      // doesn't show up in the bitset. So e.g. Dense is technically the second
      // functionality, but the lowest functionality bit.
      uint64_t functionality_val = 1ULL
          << (num_backends + static_cast<uint8_t>(functionality_k) - 1);

      // Then compute which bit to flip for the backend. For example, given a
      // runtime instance of a "per-backend functionality" key like
      // DispatchKey::CPU, we should set:
      // - the Dense functionality bit
      // - the CPUBit backend bit
      auto backend_k = toBackendComponent(k);
      uint64_t backend_val = backend_k == BackendComponent::InvalidBit
          ? 0
          : 1ULL << (static_cast<uint8_t>(backend_k) - 1);
      repr_ = functionality_val + backend_val;
    } else {
      // At this point, we should have covered every case except for alias keys.
      // Technically it would be possible to add alias dispatch keys to a
      // DispatchKeySet, but the semantics are a little confusing and this
      // currently isn't needed anywhere.
      repr_ = 0;
    }
  }
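
  // A few illustrative values (exact bit positions depend on the enum order
  // in DispatchKey.h, so treat these as a sketch):
  //   DispatchKeySet(DispatchKey::Dense)     -> Dense functionality bit only
  //   DispatchKeySet(DispatchKey::CPU)       -> Dense bit + CPUBit
  //   DispatchKeySet(DispatchKey::SparseCPU) -> Sparse bit + CPUBit
  //   DispatchKeySet(DispatchKey::Autograd)  -> empty set (alias key)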

  constexpr uint64_t keys_to_repr(std::initializer_list<DispatchKey> ks) {
    uint64_t repr = 0;
    for (auto k : ks) {
      repr |= DispatchKeySet(k).repr_;
    }
    return repr;
  }

  constexpr uint64_t backend_bits_to_repr(
      std::initializer_list<BackendComponent> ks) {
    uint64_t repr = 0;
    for (auto k : ks) {
      repr |= DispatchKeySet(k).repr_;
    }
    return repr;
  }

  explicit constexpr DispatchKeySet(std::initializer_list<DispatchKey> ks)
      : repr_(keys_to_repr(ks)) {}

  explicit constexpr DispatchKeySet(std::initializer_list<BackendComponent> ks)
      // Note: for some reason, putting this logic directly in the constructor
      // appears to fail to compile on CUDA 10.1.
      // See an example internal failure at
      // https://www.internalfb.com/intern/skycastle/run/76561193669136035/artifact/actionlog.76561193742069401.stderr
      : repr_(backend_bits_to_repr(ks)) {}

  // Test if a DispatchKey is in the set
  inline bool has(DispatchKey t) const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined);
    return has_all(DispatchKeySet(t));
  }
  constexpr bool has_backend(BackendComponent t) const {
    return has_all(DispatchKeySet(t));
  }
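
  // Illustrative behavior (assuming the standard key definitions):
  //   DispatchKeySet ks(DispatchKey::CPU);
  //   ks.has(DispatchKey::CPU);                 // true
  //   ks.has(DispatchKey::Dense);               // true (building-block bit)
  //   ks.has_backend(BackendComponent::CPUBit); // true
  //   ks.has(DispatchKey::CUDA);                // false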

  // Given a DispatchKeySet of functionality keys and (potentially) backend
  // keys, tests if all of them are in the current set.
  constexpr bool has_all(DispatchKeySet ks) const {
    return static_cast<bool>((repr_ & ks.repr_) == ks.repr_);
  }

  // Given a DispatchKeySet of functionality keys and (potentially) backend
  // keys, tests if any of them are in the current set. This could technically
  // be pretty easily implemented using has(). It is strictly a perf
  // optimization though. There are many places in the code base where we want
  // to test for multiple functionality keys together. HOWEVER, runtime
  // per-backend functionality keys aren't allowed to be used with this
  // function, because you can end up with weird results. e.g.
  // DispatchKeySet(DispatchKey::AutogradCPU).has_any(DispatchKeySet(DispatchKey::CPU))
  // would return true.
  inline bool has_any(DispatchKeySet ks) const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        // Either there are no backend bits in the input keyset
        ((ks.repr_ & full_backend_mask) == 0) ||
        // or there are no per-backend-functionality bits
        // See [Note: Per-Backend Functionality Dispatch Keys]
        ((ks &
          DispatchKeySet({
              DispatchKey::Dense,
              DispatchKey::Quantized,
              DispatchKey::Sparse,
              DispatchKey::SparseCsr,
              DispatchKey::AutogradFunctionality,
          })
              .repr_) == 0));
    return static_cast<bool>((repr_ & ks.repr_) != 0);
  }
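
  // Intended usage is with functionality-only keysets, e.g. (illustrative):
  //   ks.has_any(DispatchKeySet(
  //       {DispatchKey::FuncTorchBatched, DispatchKey::FuncTorchVmapMode}));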
  // Test if DispatchKeySet is a superset of ks.
  bool isSupersetOf(DispatchKeySet ks) const {
    return (repr_ & ks.repr_) == ks.repr_;
  }
  // Perform set union
  constexpr DispatchKeySet operator|(DispatchKeySet other) const {
    return DispatchKeySet(repr_ | other.repr_);
  }
  // Perform set intersection
  constexpr DispatchKeySet operator&(DispatchKeySet other) const {
    return DispatchKeySet(repr_ & other.repr_);
  }
  // Compute the set difference self - other,
  // but ONLY for the functionality keys.
  // Any backend bits set on self will remain unchanged.
  // See Note [Removing keys from DispatchKeySet Only Affects Functionality
  // Keys]
  constexpr DispatchKeySet operator-(DispatchKeySet other) const {
    return DispatchKeySet(repr_ & (full_backend_mask | ~other.repr_));
  }

  // Compute self ^ other
  constexpr DispatchKeySet operator^(DispatchKeySet other) const {
    return DispatchKeySet(repr_ ^ other.repr_);
  }
  bool operator==(DispatchKeySet other) const {
    return repr_ == other.repr_;
  }
  bool operator!=(DispatchKeySet other) const {
    return repr_ != other.repr_;
  }
  // Add a DispatchKey to the DispatchKey set. Does NOT mutate;
  // returns the extended DispatchKeySet!
  C10_NODISCARD constexpr DispatchKeySet add(DispatchKey t) const {
    return *this | DispatchKeySet(t);
  }
  C10_NODISCARD constexpr DispatchKeySet add(DispatchKeySet ks) const {
    return *this | ks;
  }

  // Remove a DispatchKey from the DispatchKey set.
  // This is generally not an operation you should be doing
  // (it's used to implement the printing overload, operator<<).
  //
  // Note [Removing keys from DispatchKeySet Only Affects Functionality Keys]
  // For now, we're only allowing removal of "functionality bits" from the
  // keyset, which is specifically needed by the fallthrough key calculation
  // logic. Why is removing backend bits problematic? Consider this example:
  //
  //   DispatchKeySet([DispatchKey.CPU, DispatchKey.AutogradCUDA,
  //   DispatchKey.CUDA]).remove(DispatchKey.AutogradCUDA)
  //   DispatchKeySet([DispatchKey.CPU,
  //   DispatchKey.AutogradCUDA]).remove(DispatchKey.AutogradCUDA)
  //
  // What do we want to happen?
  // Technically, we'd like it to be true that after removal,
  // the first keyset still has the CUDA dispatch key while the second doesn't.
  // Unfortunately there's no way to represent that, because the two keysets are
  // represented the same way internally:
  //   functionality bits: Autograd, Dense
  //   backend bits: CPU, CUDA
  //
  // Instead, remove(DispatchKey.AutogradCPU) will only remove the "Autograd"
  // bit from the bitset.
  C10_NODISCARD constexpr DispatchKeySet remove(DispatchKey t) const {
    return DispatchKeySet(
        repr_ & ~(DispatchKeySet(t).repr_ & ~full_backend_mask));
  }
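
  // Illustrative consequence of the note above:
  //   auto ks = DispatchKeySet(DispatchKey::AutogradCPU); // Autograd + CPU bits
  //   auto removed = ks.remove(DispatchKey::AutogradCPU);
  //   removed.has(DispatchKey::AutogradCPU);         // false
  //   removed.has_backend(BackendComponent::CPUBit); // true: backend bit kept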
  // You're allowed to remove a backend bit from a DispatchKeySet,
  // but you have to be explicit about it (remove_backend() instead of
  // remove()).
  constexpr DispatchKeySet remove_backend(BackendComponent b) const {
    return DispatchKeySet(repr_ & ~(DispatchKeySet(b).repr_));
  }
  // Is the set empty? (AKA undefined tensor)
  bool empty() const {
    return repr_ == 0;
  }
  uint64_t raw_repr() {
    return repr_;
  }

  DispatchKey highestFunctionalityKey() const {
    auto functionality_idx = indexOfHighestBit();
    // This means that none of the functionality bits were set.
    if (functionality_idx < num_backends)
      return DispatchKey::Undefined;
    // The first num_backends bits in the keyset don't correspond to real
    // dispatch keys.
    return static_cast<DispatchKey>(functionality_idx - num_backends);
  }

  // This is similar to toBackendComponent(DispatchKey), but less restrictive.
  // toBackendComponent() errors out if the key that it was passed has no
  // backend bits, which is useful for error checking. We need a version of that
  // here that can also handle "fake" backends like FPGA, because they need to
  // map to the AutogradOther key. For those backends, we return
  // BackendComponent::InvalidBit.
  BackendComponent highestBackendKey() const {
    // mask to mask out functionality bits
    auto backend_idx =
        DispatchKeySet(repr_ & full_backend_mask).indexOfHighestBit();
    // all zeros across the backend bits means that no backend bits are set.
    if (backend_idx == 0)
      return BackendComponent::InvalidBit;
    return static_cast<BackendComponent>(backend_idx);
  }

  // returns the DispatchKey of highest priority in the set.
  DispatchKey highestPriorityTypeId() const {
    auto functionality_k = highestFunctionalityKey();
    if (isPerBackendFunctionalityKey(functionality_k)) {
      return toRuntimePerBackendFunctionalityKey(
          functionality_k, highestBackendKey());
    }
    return functionality_k;
  }

  // Returns the index of the most-significant bit in the keyset.
  // This is used as part of the calculation into the operator table to get:
  // - the highest "functionality" bit in the keyset.
  // - the highest "backend" bit in the keyset.
  uint8_t indexOfHighestBit() const {
    return 64 - llvm::countLeadingZeros(repr_);
  }
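
  // Worked example (illustrative): if the only set bit of repr_ is bit 5
  // (0-indexed), countLeadingZeros returns 58 and this returns 6; an empty
  // keyset returns 0. The returned index is therefore 1-based from the LSB.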

#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
  // [Note: Trimmed Mobile Dispatch Keys]
  /**
   * The method below maps the dispatch key in the enum DispatchKey to an
   * integer index in the dispatchTable_ array in OperatorEntry. The array
   * is trimmed for mobile to reduce peak memory usage since it's
   * unnecessary to reserve additional space for dispatch keys that will
   * never be used on mobile.
   */
  int getDispatchTableIndexForDispatchKeySet() const {
    auto dk = highestPriorityTypeId();
    switch (dk) {
      case DispatchKey::Undefined:
        return 0;
      case DispatchKey::CPU:
        return 1;
      case DispatchKey::QuantizedCPU:
        return 2;
      case DispatchKey::SparseCPU:
        return 3;
      case DispatchKey::BackendSelect:
        return 4;
      case DispatchKey::ADInplaceOrView:
        return 5;
      case DispatchKey::AutogradOther:
        return 6;
      case DispatchKey::AutogradCPU:
        return 7;
      default:
        return -1;
    }
  }
#else
  // Returns the index in the operator table of the highest priority key in
  // the keyset. Note that we could in theory implement this using
  // highestPriorityTypeId(), but this code is very hotpath and we can do it
  // faster without it.
  int getDispatchTableIndexForDispatchKeySet() const {
    auto functionality_idx =
        DispatchKeySet(repr_ >> num_backends).indexOfHighestBit();
    auto offset_and_mask = offsetsAndMasks()[functionality_idx];
    // Mask the functionality bits out first, then right-shift by 1.
    // right-shifting by 1 because everything is zero-indexed.
    // E.g. 000001 (CPU) should give us an offset of 0, 000010 (CUDA) should
    // give us an offset of 1, etc.
    auto backend_idx =
        DispatchKeySet((repr_ & offset_and_mask.mask) >> 1).indexOfHighestBit();
    return offset_and_mask.offset + backend_idx;
  }
#endif
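
  // Sketch of the non-mobile computation above (the actual offsets come from
  // initializeFunctionalityOffsetsAndMasks()): for
  // DispatchKeySet(DispatchKey::CUDA), the Dense bit is the highest
  // functionality bit, so we look up Dense's offset and backend mask;
  // CUDABit is the second backend bit, so backend_idx == 1 after masking and
  // shifting, and the final index is Dense's offset + 1 (the slot right
  // after CPU's).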

  // returns the "index" of the highest priority backend in the keyset.
  // This is pretty similar to highestBackendKey(), but:
  // - It's hotpath code (part of the runtime bitset calculation)
  // - It returns an integer index, not an enum value
  // - Everything is shifted to the right by 1.
  //   BackendComponent::InvalidBit is technically the lowest enum value,
  //   but it isn't included in the runtime table. So CPUBit = 1, CUDABit = 2,
  //   etc.
  uint64_t getBackendIndex() const {
    return DispatchKeySet((repr_ & full_backend_mask) >> 1).indexOfHighestBit();
  }
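
  // e.g. (illustrative): a keyset containing CPUBit yields 0, one whose
  // highest backend bit is CUDABit yields 1, and a keyset with no backend
  // bits at all also yields 0.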

 private:
  constexpr DispatchKeySet(uint64_t repr) : repr_(repr) {}
  uint64_t repr_ = 0;

 public:
  // STL iterator for DispatchKeySet. Iterates through all runtime DispatchKeys
  // in the set. The iterator is only invalidated by the destruction of the
  // underlying DispatchKeySet as the iterator stores a pointer to the raw
  // representation of the DispatchKeySet. Note: When we encounter a per-backend
  // functionality (e.g. Dense or Sparse), we will iterate through EVERY backend
  // in the keyset, for that functionality. For example, if the next
  // functionality key to iterate over is Autograd, and the backend bits in the
  // keyset correspond to [BackendComponent::CPUBit, BackendComponent::CUDABit],
  // then the next two keys we return will be DispatchKey::AutogradCPU,
  // DispatchKey::AutogradCUDA (CPU first because it has lower precedence than
  // CUDA in DispatchKey.h).
  class iterator {
   public:
    using self_type = iterator;
    using iterator_category = std::input_iterator_tag;
    using value_type = DispatchKey;
    using difference_type = ptrdiff_t;
    using reference = value_type&;
    using pointer = value_type*;
    // final mask value should mask out the entire keyset
    static const uint8_t end_iter_mask_val =
        num_backends + num_functionality_keys;
    // final key value should be the last DispatchKey
    static const uint8_t end_iter_key_val = num_functionality_keys;

    // current_dispatchkey_idx_ will iterate through all functionality bits.
    // current_backendcomponent_idx_ will iterate through all backend bits.
    explicit iterator(
        const uint64_t* data_ptr,
        uint8_t next_functionality = num_backends,
        uint8_t next_backend = 0)
        : data_ptr_(data_ptr),
          next_functionality_(next_functionality),
          next_backend_(next_backend),
          // These are in an invalid state at construction time, and set by the
          // first increment call
          current_dispatchkey_idx_(end_iter_key_val),
          current_backendcomponent_idx_(end_iter_key_val) {
      // Go to the first key in the set
      TORCH_INTERNAL_ASSERT(
          next_functionality_ >= num_backends,
          "num_backends=",
          static_cast<uint32_t>(num_backends),
          "next_functionality_=",
          static_cast<uint32_t>(next_functionality_));
      ++(*this);
    }

    C10_API self_type& operator++();

    self_type operator++(int) {
      self_type previous_iterator = *this;
      ++(*this);
      return previous_iterator;
    }

    bool operator==(const self_type& rhs) const {
      return next_functionality_ == rhs.next_functionality_ &&
          current_dispatchkey_idx_ == rhs.current_dispatchkey_idx_ &&
          next_backend_ == rhs.next_backend_ &&
          current_backendcomponent_idx_ == rhs.current_backendcomponent_idx_;
    }
    bool operator!=(const self_type& rhs) const {
      return next_functionality_ != rhs.next_functionality_ ||
          current_dispatchkey_idx_ != rhs.current_dispatchkey_idx_ ||
          next_backend_ != rhs.next_backend_ ||
          current_backendcomponent_idx_ != rhs.current_backendcomponent_idx_;
    }
    DispatchKey operator*() const {
      auto functionality_key =
          static_cast<DispatchKey>(current_dispatchkey_idx_);
      if (isPerBackendFunctionalityKey(functionality_key)) {
        auto next_key = toRuntimePerBackendFunctionalityKey(
            functionality_key,
            static_cast<BackendComponent>(current_backendcomponent_idx_));
        // We expect all of the Dense, Sparse, Quantized, and Autograd keys to
        // be ordered the same way with respect to their backends
        TORCH_INTERNAL_ASSERT(
            toBackendComponent(next_key) ==
                static_cast<BackendComponent>(current_backendcomponent_idx_),
            "Tried to map functionality key ",
            toString(functionality_key),
            " and backend bit ",
            toString(
                static_cast<BackendComponent>(current_backendcomponent_idx_)),
            " to a runtime key, but ended up with ",
            toString(next_key),
            ". This can happen if the order of the backend dispatch keys in DispatchKey.h isn't consistent.",
            " Please double check that enum for inconsistencies.");
        return next_key;
      } else {
        return functionality_key;
      }
    }

   private:
    const uint64_t* data_ptr_;
    uint8_t next_functionality_;
    uint8_t next_backend_;
    uint8_t current_dispatchkey_idx_;
    uint8_t current_backendcomponent_idx_;
  };

 public:
  // Returns iterator to the first key in the set. If no keys are in the
  // set, then will return the end iterator.
  iterator begin() const {
    return iterator(&repr_);
  }

  // We do not need to iterate beyond EndOfFunctionalityKeys, so we will treat
  // this as the end iterator.
  iterator end() const {
    return iterator(&repr_, iterator::end_iter_mask_val);
  }
};
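
// Illustrative iteration (runtime keys only; per-backend functionalities are
// expanded across every backend bit in the set, assuming the standard key
// ordering in DispatchKey.h):
//   DispatchKeySet ks({DispatchKey::CPU, DispatchKey::AutogradCUDA});
//   // ks holds {Dense, Autograd} x {CPUBit, CUDABit}, so this visits:
//   // CPU, CUDA, AutogradCPU, AutogradCUDA.
//   for (DispatchKey k : ks) {
//     // ...
//   }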

C10_API std::string toString(DispatchKeySet);
C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet);

C10_API inline int getDispatchTableIndexForDispatchKey(DispatchKey k) {
  return DispatchKeySet(k).getDispatchTableIndexForDispatchKeySet();
}

// Alias key DispatchKey::Autograd maps to
// (autograd_dispatch_keyset x full_backend_mask)
// NB: keys in this set also get associated with CompositeImplicitAutograd
//
// Note [autograd_dispatch_keyset Does Not Include Backend Bits]
// We don't want to include any backend bits (BackendComponent::CPUBit, etc)
// directly in autograd_dispatch_keyset.
// Why? keysets like autograd_dispatch_keyset are commonly used to remove
// autograd keys from a DispatchKeySet throughout the code base. However, you
// are only allowed to remove functionality bits from a keyset, not backend
// bits. See Note [Removing keys from DispatchKeySet Only Affects Functionality
// Keys] for details. To be consistent and avoid confusion, we're explicitly
// setting up autograd_dispatch_keyset to not have any backend bits.
constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
    DispatchKey::AutogradFunctionality,
    DispatchKey::AutogradOther,
    DispatchKey::AutogradNestedTensor,
});

constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
    DispatchKey::AutocastCPU,
    DispatchKey::AutocastMPS,
    DispatchKey::AutocastCUDA,
    DispatchKey::AutocastXPU,
    DispatchKey::AutocastIPU,
    DispatchKey::AutocastHPU,
    DispatchKey::AutocastXLA,
    DispatchKey::AutocastPrivateUse1,
});

// See Note [TLS Initialization]
constexpr DispatchKeySet default_included_set = DispatchKeySet({
    DispatchKey::BackendSelect,
    DispatchKey::ADInplaceOrView,
});

constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
    DispatchKey::AutocastCPU,
    DispatchKey::AutocastMPS,
    DispatchKey::AutocastCUDA,
    DispatchKey::AutocastXPU,
    DispatchKey::AutocastIPU,
    DispatchKey::AutocastHPU,
    DispatchKey::AutocastXLA,
    DispatchKey::AutocastPrivateUse1,
});

constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
    autograd_dispatch_keyset | DispatchKeySet(DispatchKey::ADInplaceOrView);

constexpr DispatchKeySet python_ks = DispatchKeySet({
    DispatchKey::Python,
    DispatchKey::PythonTLSSnapshot,
});

constexpr DispatchKeySet sparse_ks = DispatchKeySet(DispatchKey::Sparse);

constexpr DispatchKeySet sparse_csr_ks = DispatchKeySet(DispatchKey::SparseCsr);

constexpr DispatchKeySet mkldnn_ks = DispatchKeySet(DispatchKey::MkldnnCPU);

// backend dispatch keys that map to DispatchKey::AutogradOther
// NB: keys in this set also get associated with CompositeImplicitAutograd
constexpr DispatchKeySet autogradother_backends =
    DispatchKeySet(
        // HIP and VE aren't in this list: they now have their own backend bits
        // which means that they can now have their own Autograd keys.
        // Technically, HIP will now redispatch to its own custom AutogradHIP
        // slot in the runtime table.
        {DispatchKey::FPGA,
         DispatchKey::MAIA,
         DispatchKey::Vulkan,
         DispatchKey::Metal,
         DispatchKey::CustomRNGKeyId,
         DispatchKey::MkldnnCPU,
         // Sparse and Quantized backends also live here.
         DispatchKey::Sparse,
         DispatchKey::SparseCsr,
         DispatchKey::Quantized})
    // Including the backend bits because this keyset is used during op
    // registration, which requires looping over all runtime autogradother
    // backend keys.
    | DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);

// The set of dispatch keys that come after autograd
// n.b. this relies on the fact that AutogradOther is currently the lowest
// Autograd key
constexpr DispatchKeySet after_autograd_keyset =
    DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradOther);

// The set of dispatch keys that come after ADInplaceOrView
constexpr DispatchKeySet after_ADInplaceOrView_keyset = DispatchKeySet(
    DispatchKeySet::FULL_AFTER,
    c10::DispatchKey::ADInplaceOrView);

// The set of dispatch keys that come after Functionalize
constexpr DispatchKeySet after_func_keyset =
    DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::Functionalize)
        .remove(
            // NOTE: we also need to remove ADInplaceOrView from the keyset when
            // redispatching after the func kernels. This is because we're not
            // calling the same op; we originally called an inplace op, and now
            // we aren't. The original key calculation figured out which keys
            // were Fallthrough based on the inplace op. That means that it did
            // not include the ADInplaceOrView kernel as a fallthrough key.
            // However, we WANT the ADInplaceOrView kernel to be ignored now
            // that we're calling an out-of-place op. Re-invoking
            // Dispatcher::call would re-run the Fallthrough key calculation and
            // get us that, but at::redispatch is more performant. We can get
            // away with it by explicitly removing the key here.
            c10::DispatchKey::ADInplaceOrView);

constexpr DispatchKeySet backend_bitset_mask =
    DispatchKeySet(DispatchKeySet::RAW, (1ULL << num_backends) - 1);

constexpr auto inplace_or_view_ks =
    DispatchKeySet(DispatchKey::ADInplaceOrView);
constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU);
constexpr auto autograd_ipu_ks = DispatchKeySet(DispatchKey::AutogradIPU);
constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU);
constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA);
constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA);
constexpr auto autograd_lazy_ks = DispatchKeySet(DispatchKey::AutogradLazy);
constexpr auto autograd_meta_ks = DispatchKeySet(DispatchKey::AutogradMeta);
constexpr auto autograd_mps_ks = DispatchKeySet(DispatchKey::AutogradMPS);
constexpr auto autograd_hpu_ks = DispatchKeySet(DispatchKey::AutogradHPU);
constexpr auto autograd_privateuse1_ks =
    DispatchKeySet(DispatchKey::AutogradPrivateUse1);
constexpr auto autograd_privateuse2_ks =
    DispatchKeySet(DispatchKey::AutogradPrivateUse2);
constexpr auto autograd_privateuse3_ks =
    DispatchKeySet(DispatchKey::AutogradPrivateUse3);
constexpr auto autograd_other_ks = DispatchKeySet(DispatchKey::AutogradOther);
constexpr auto autograd_nested =
    DispatchKeySet(DispatchKey::AutogradNestedTensor);
// keyset corresponding to functorch keys that have their own dedicated
// TensorImpl subclass.
constexpr auto functorch_transforms_ks = DispatchKeySet(
    {DispatchKey::FuncTorchBatched,
     DispatchKey::FuncTorchVmapMode,
     DispatchKey::Batched,
     DispatchKey::VmapMode,
     DispatchKey::FuncTorchGradWrapper});

constexpr auto functorch_batched_ks =
    DispatchKeySet({DispatchKey::FuncTorchBatched});

// This keyset has:
// (1) the functionality bits corresponding to backends (dense, sparse,
//     quantized)
// (2) all of the backend bits set
constexpr DispatchKeySet backend_functionality_keys =
    DispatchKeySet({
        DispatchKey::Dense,
        DispatchKey::Quantized,
        DispatchKey::Sparse,
        DispatchKey::SparseCsr,
    }) |
    DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);

struct OpTableOffsetAndMask {
  uint16_t offset;
  uint16_t backend_mask;
};

static_assert(
    num_backends <= 16,
    "Right now we expect the number of backends not to exceed 16. In the (unlikely) event"
    " that this changes, the size of OpTableOffsetAndMask::backend_mask needs to be increased too.");

// true if t is a backend dispatch key
C10_API bool isBackendDispatchKey(DispatchKey t);

// Resolve alias dispatch key to DispatchKeySet if applicable
C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t);

// Resolve alias dispatch key to DispatchKeySet if applicable,
// and check if k is a part of that set
C10_API bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k);

// Returns a DispatchKeySet of all backend keys mapped to Autograd dispatch key
// t; the DispatchKeySet is empty if t is not an alias of DispatchKey::Autograd.
C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t);

// Returns a DispatchKeySet of autograd related keys mapped to backend.
// For a given backend key, use the associated autograd key.
// For non-backend keys, use AutogradOther as a default.
// Note: it's convenient and fast to return a default here rather than (say)
// returning an std::optional<DispatchKey>, or throwing. But it makes callers
// responsible for either a) enforcing the invariant that only backend keys
// be passed as arguments, or b) interpreting our return value carefully.
inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
  switch (t) {
    case BackendComponent::CPUBit:
      return inplace_or_view_ks | autograd_cpu_ks;
    case BackendComponent::IPUBit:
      return inplace_or_view_ks | autograd_ipu_ks;
    case BackendComponent::XPUBit:
      return inplace_or_view_ks | autograd_xpu_ks;
    case BackendComponent::CUDABit:
      return inplace_or_view_ks | autograd_cuda_ks;
    case BackendComponent::XLABit:
      return inplace_or_view_ks | autograd_xla_ks;
    case BackendComponent::LazyBit:
      return inplace_or_view_ks | autograd_lazy_ks;
    case BackendComponent::MetaBit:
      return inplace_or_view_ks | autograd_meta_ks;
    case BackendComponent::MPSBit:
      return inplace_or_view_ks | autograd_mps_ks;
    case BackendComponent::HPUBit:
      return inplace_or_view_ks | autograd_hpu_ks;
    case BackendComponent::PrivateUse1Bit:
      return inplace_or_view_ks | autograd_privateuse1_ks;
    case BackendComponent::PrivateUse2Bit:
      return inplace_or_view_ks | autograd_privateuse2_ks;
    case BackendComponent::PrivateUse3Bit:
      return inplace_or_view_ks | autograd_privateuse3_ks;
    default:
      return inplace_or_view_ks | autograd_other_ks;
  }
}

// Returns a DispatchKeySet of autocast related keys mapped to backend.
inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
  constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU);
  constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU);
  constexpr auto autocast_ipu_ks = DispatchKeySet(DispatchKey::AutocastIPU);
  constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU);
  constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA);
  constexpr auto autocast_xla_ks = DispatchKeySet(DispatchKey::AutocastXLA);
  constexpr auto autocast_privateuse1_ks =
      DispatchKeySet(DispatchKey::AutocastPrivateUse1);
  constexpr auto autocast_mps_ks = DispatchKeySet(DispatchKey::AutocastMPS);
  switch (t) {
    case BackendComponent::CPUBit:
      return autocast_cpu_ks;
    case BackendComponent::XPUBit:
      return autocast_xpu_ks;
    case BackendComponent::IPUBit:
      return autocast_ipu_ks;
    case BackendComponent::HPUBit:
      return autocast_hpu_ks;
    case BackendComponent::CUDABit:
      return autocast_cuda_ks;
    case BackendComponent::XLABit:
      return autocast_xla_ks;
    case BackendComponent::PrivateUse1Bit:
      return autocast_privateuse1_ks;
    case BackendComponent::MPSBit:
      return autocast_mps_ks;
    default:
      return DispatchKeySet();
  }
}

// returns the "backend" DispatchKey of highest priority in the set.
// This is basically like highestBackendKey(), except that we have some
// "functionality" bits that correspond to backends (Sparse, Quantized)
inline DispatchKey highestPriorityBackendTypeId(DispatchKeySet ks) {
  return (ks & backend_functionality_keys).highestPriorityTypeId();
}

// This API exists because we have a use case for checking
// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined)
// in OperatorEntry.cpp but we disallow it in has() API.
C10_API bool isIncludedInAlias(DispatchKey k, DispatchKey alias);

// Historically, every tensor only had a single DispatchKey, and it was always
// something like CPU, and there wasn't any of this business where TLS
// could cause the DispatchKey of a tensor to change. But we still have some
// legacy code that uses DispatchKey for things like instanceof
// checks; if at all possible, refactor the code to stop using DispatchKey in
// those cases.
inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
  // NB: If you add any extra keys that can be stored in TensorImpl on
  // top of existing "backend" keys like CPU/CUDA, you need to add it
  // here. At the moment, autograd keys and the ADInplaceOrView key need this
  // treatment.
  return (s - autograd_dispatch_keyset_with_ADInplaceOrView -
          autocast_dispatch_keyset -
          DispatchKeySet(
              {DispatchKey::Functionalize,
               DispatchKey::PythonTLSSnapshot,
               DispatchKey::FuncTorchGradWrapper,
               DispatchKey::FuncTorchVmapMode,
               DispatchKey::FuncTorchBatched,
               DispatchKey::Python}))
      .highestPriorityTypeId();
}
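
// Illustrative behavior (assuming the standard key definitions): a CPU tensor
// that requires grad typically carries
//   {Dense, ADInplaceOrView, AutogradFunctionality, CPUBit},
// and legacyExtractDispatchKey on that keyset strips the autograd-related
// functionality bits and returns DispatchKey::CPU.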

template <class T>
using is_not_DispatchKeySet = std::negation<std::is_same<DispatchKeySet, T>>;

// Given a function type, constructs a function_traits type that drops the first
// parameter type if the first parameter is of type DispatchKeySet. NB:
// DispatchKeySet is currently explicitly hidden from JIT (mainly to avoid
// pushing unnecessary arguments on the stack - see Note [ Plumbing Keys Through
// the Dispatcher] for details). If at any point in the future we need to expose
// this type to JIT, revisit the usage of this type alias.
template <class FuncType>
using remove_DispatchKeySet_arg_from_func = guts::make_function_traits_t<
    typename guts::infer_function_traits_t<FuncType>::return_type,
    typename std::conditional_t<
        std::is_same_v<
            DispatchKeySet,
            typename guts::typelist::head_with_default_t<
                void,
                typename guts::infer_function_traits_t<
                    FuncType>::parameter_types>>,
        guts::typelist::drop_if_nonempty_t<
            typename guts::infer_function_traits_t<FuncType>::parameter_types,
            1>,
        typename guts::infer_function_traits_t<FuncType>::parameter_types>>;
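
// For instance (illustrative):
//   remove_DispatchKeySet_arg_from_func<int(DispatchKeySet, double)>
// describes a function of type int(double), while
//   remove_DispatchKeySet_arg_from_func<int(double)>
// leaves the signature unchanged.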
} // namespace c10