1 #pragma once
2
3 #include <c10/core/Backend.h>
4 #include <c10/core/DefaultDtype.h>
5 #include <c10/core/Device.h>
6 #include <c10/core/DeviceType.h>
7 #include <c10/core/DispatchKey.h>
8 #include <c10/core/Layout.h>
9 #include <c10/core/MemoryFormat.h>
10 #include <c10/core/ScalarType.h>
11 #include <c10/core/ScalarTypeToTypeMeta.h>
12
13 #include <c10/macros/Export.h>
14 #include <c10/macros/Macros.h>
15 #include <c10/util/Exception.h>
16 #include <optional>
17
#include <cstdint>
#include <iosfwd>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>
23
24 namespace c10 {
25
// Forward declaration; the inline definition lives near the bottom of this
// header (it needs the *_or_default helpers declared below).
DispatchKey computeDispatchKey(
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device);
30
dtype_or_default(std::optional<ScalarType> dtype)31 inline ScalarType dtype_or_default(std::optional<ScalarType> dtype) {
32 return value_or_else(dtype, [] { return get_default_dtype_as_scalartype(); });
33 }
34
dtype_or_default(std::optional<caffe2::TypeMeta> dtype)35 inline caffe2::TypeMeta dtype_or_default(
36 std::optional<caffe2::TypeMeta> dtype) {
37 return value_or_else(dtype, [] { return get_default_dtype(); });
38 }
39
layout_or_default(std::optional<Layout> layout)40 inline Layout layout_or_default(std::optional<Layout> layout) {
41 return layout.value_or(kStrided);
42 }
43
device_or_default(std::optional<Device> device)44 inline Device device_or_default(std::optional<Device> device) {
45 return value_or_else(device, [] { return Device(kCPU); });
46 }
47
/// Returns `pinned_memory` if it holds a value, otherwise false.
inline bool pinned_memory_or_default(std::optional<bool> pinned_memory) {
  if (pinned_memory.has_value()) {
    return *pinned_memory;
  }
  return false;
}
51
/// A class to encapsulate construction axes of a Tensor. TensorOptions was
53 /// designed to support the Python style API for specifying construction options
54 /// on factory functions, e.g.,
55 ///
56 /// torch.zeros(2, 3, dtype=torch.int32)
57 ///
58 /// Because C++ doesn't natively support keyword arguments, there must be
59 /// another way of specifying keyword-like arguments. TensorOptions is a
60 /// builder class which can be used to construct this "dictionary" of keyword
61 /// arguments: functions which support TensorOptions conventionally take this
62 /// argument optionally as their last argument.
63 ///
64 /// WARNING: In PyTorch, there are `torch::` variants of factory functions,
65 /// e.g., torch::zeros for at::zeros. These return Variables (while the
66 /// stock ATen functions return plain Tensors). If you mix these functions
67 /// up, you WILL BE SAD.
68 ///
69 /// Rather than use the constructor of this class directly, you should prefer to
70 /// use the constructor functions, and then chain setter methods on top of them.
71 ///
72 /// at::device(at::kCUDA).dtype(kInt)
73 /// at::dtype(at::kInt)
74 ///
75 /// Additionally, anywhere a TensorOptions is expected, you can directly
76 /// pass at::kCUDA / at::kInt, and it will implicitly convert to a
77 /// TensorOptions.
78 ///
79 /// Here are some recommended ways to create a 2x2 tensor of zeros
80 /// with certain properties. These all *implicitly* make use of
81 /// TensorOptions, even if they don't mention the class explicitly:
82 ///
83 /// at::zeros({2,2}, at::kCUDA);
84 /// at::zeros({2,2}, at::kLong);
///       at::zeros({2,2}, at::device(at::kCUDA).dtype(at::kLong));
86 /// at::zeros({2,2}, at::device({at::kCUDA, 1})); // place on device 1
87 /// at::zeros({2,2}, at::requires_grad());
88 ///
89
90 /// NOTE [ TensorOptions Constructors ]
91 ///
92 /// TensorOptions is like a dictionary with entries from the set:
93 /// {requires_grad, device, dtype, layout}, where each entry may be
94 /// unspecified (i.e., is optional). It is used to specify the properties of
95 /// tensors in many places both in C++ internal and API, e.g., tensor factory
96 /// methods like `at::empty({10}, options)`, tensor conversions like
97 /// `tensor.to(...)`, etc.
98 ///
99 /// To provide a simple API that is consistent with Python, where one can do
100 /// `torch.empty(sizes, X)` with `X` being a `torch.device`, `torch.dtype`, or a
101 /// `torch.layout`, we want TensorOptions to be implicitly convertible from
102 /// `ScalarType dtype`, `Layout layout` and `Device device`. Therefore, we have
103 /// three implicit constructors from each of these three types.
104 ///
105 /// This is sufficient for `ScalarType` and `Layout` as they are simple Enum
106 /// classes. However, `Device` is an ordinary class with implicit constructors
107 /// `Device(DeviceType, DeviceIndex = -1)` and `Device(std::string)` to be
108 /// consistent with Python API, where strings are treated as equivalent with a
109 /// `torch.device` object (e.g., "cuda:1" can be passed to everywhere a
110 /// `torch.device("cuda:1")` is accepted). To support the syntax
111 /// `at::empty({10}, {kCUDA, 1})` and `tensor.to(kCUDA)`, we need to make sure
112 /// that `TensorOptions` is implicitly constructible with any arguments that a
/// `Device` can be constructed from. So we have,
114 ///
115 /// /* implicit */ TensorOptions(T&& device) : TensorOptions() {
116 /// this->set_device(device);
117 /// }
118 ///
119 /// template <typename... Args,
120 /// typename = std::enable_if_t<std::is_constructible<Device,
121 /// Args&&...>::value>>
122 /// /* implicit */ TensorOptions(Args&&... args)
123 /// : TensorOptions(Device(std::forward<Args>(args)...)) {}
124 ///
125 ///
126 /// But this will be problematic. Consider this: `TensorOptions({kCUDA, 1})`.
127 /// Compiler will complain about ambiguity between the copy constructor and the
128 /// `Device` constructor because `{kCUDA, 1}` can be converted to both a
/// `TensorOptions` and a `Device`.
130 ///
131 /// To get around this, we templatize the `Device` constructor. Since overload
132 /// resolution is done before template resolution, our problem is solved.
133
// NOTE(review): this re-declares computeDispatchKey (already declared at the
// top of this header). Harmless, but redundant.
DispatchKey computeDispatchKey(
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device);
138
struct C10_API TensorOptions {
  /// Default-constructs options with every axis unset; the getters below
  /// then fall back to the *_or_default helpers (strided, CPU, default
  /// dtype, requires_grad=false, pinned_memory=false).
  TensorOptions()
      : requires_grad_(false),
        pinned_memory_(false),
        has_device_(false),
        has_dtype_(false),
        has_layout_(false),
        has_requires_grad_(false),
        has_pinned_memory_(false),
        has_memory_format_(false) {}

  /// Constructs a `TensorOptions` object with the given layout.
  /* implicit */ TensorOptions(Layout layout) : TensorOptions() {
    this->set_layout(layout);
  }

  /// Constructs a `TensorOptions` object with the given device.
  /// See NOTE [ TensorOptions Constructors ] on why this is templatized.
  template <
      typename T,
      typename = std::enable_if_t<std::is_same_v<std::decay_t<T>, Device>>>
  /* implicit */ TensorOptions(T&& device) : TensorOptions() {
    this->set_device(std::forward<T>(device));
  }

  /// Constructs a `TensorOptions` object from arguments allowed in `Device`
  /// constructors.
  ///
  /// See NOTE [ TensorOptions Constructors ].
  ///
  /// NB: Ideally we only allow implicit constructors here. But there is no easy
  /// way to detect them. So we have this one that allows explicit
  /// constructors too.
  template <
      typename... Args,
      typename = std::enable_if_t<std::is_constructible_v<Device, Args&&...>>>
  /* implicit */ TensorOptions(Args&&... args)
      : TensorOptions(Device(std::forward<Args>(args)...)) {}

  /// Constructs a `TensorOptions` object with the given dtype.
  /* implicit */ TensorOptions(caffe2::TypeMeta dtype) : TensorOptions() {
    this->set_dtype(dtype);
  }

  /// legacy constructor to support ScalarType
  /* implicit */ TensorOptions(ScalarType dtype) : TensorOptions() {
    this->set_dtype(dtype);
  }

  /// Constructs a `TensorOptions` object with the given memory format.
  /* implicit */ TensorOptions(MemoryFormat memory_format) : TensorOptions() {
    set_memory_format(memory_format);
  }

  /// Return a copy of `TensorOptions` with `device` set to the given one, or
  /// cleared if `device` is `nullopt`.
  C10_NODISCARD TensorOptions
  device(std::optional<Device> device) const noexcept {
    TensorOptions r = *this;
    r.set_device(device);
    return r;
  }

  /// Return a copy of `TensorOptions` with `device` set to the given one.
  /// (This overload ensures that variadic template std::optional constructor
  /// for Device work correctly.)
  template <typename... Args>
  C10_NODISCARD TensorOptions device(Args&&... args) const noexcept {
    return device(
        std::optional<Device>(std::in_place, std::forward<Args>(args)...));
  }

  /// Return a copy of `TensorOptions`, but with device set to CUDA, and the
  /// device index set to the given one.
  ///
  /// TODO: This function encourages bad behavior (assuming CUDA is
  /// the only device that matters). Get rid of it / rename it.
  C10_NODISCARD TensorOptions
  device_index(c10::DeviceIndex device_index) const noexcept {
    return device(Device::Type::CUDA, device_index);
  }

  /// Return a copy of `TensorOptions` with `dtype` set to the given one.
  C10_NODISCARD TensorOptions
  dtype(std::optional<caffe2::TypeMeta> dtype) const noexcept {
    TensorOptions r = *this;
    r.set_dtype(dtype);
    return r;
  }

  // legacy function to support ScalarType
  C10_NODISCARD TensorOptions
  dtype(std::optional<ScalarType> dtype) const noexcept {
    TensorOptions r = *this;
    r.set_dtype(dtype);
    return r;
  }

  // Since dtype is taken...
  // NB: unlike the other chainable setters above, this one mutates *this in
  // place and returns a reference rather than a modified copy.
  template <typename T>
  TensorOptions& dtype() {
    dtype_ = caffe2::TypeMeta::Make<T>();
    has_dtype_ = true;
    return *this;
  }

  /// Sets the layout of the `TensorOptions`.
  C10_NODISCARD TensorOptions
  layout(std::optional<Layout> layout) const noexcept {
    TensorOptions r = *this;
    r.set_layout(layout);
    return r;
  }

  /// Sets the `requires_grad` property of the `TensorOptions`.
  C10_NODISCARD TensorOptions
  requires_grad(std::optional<bool> requires_grad) const noexcept {
    TensorOptions r = *this;
    r.set_requires_grad(requires_grad);
    return r;
  }

  /// Sets the `pinned_memory` property on the `TensorOptions`.
  C10_NODISCARD TensorOptions
  pinned_memory(std::optional<bool> pinned_memory) const noexcept {
    TensorOptions r = *this;
    r.set_pinned_memory(pinned_memory);
    return r;
  }

  /// Sets the `memory_format` property on `TensorOptions`.
  C10_NODISCARD TensorOptions
  memory_format(std::optional<MemoryFormat> memory_format) const noexcept {
    TensorOptions r = *this;
    r.set_memory_format(memory_format);
    return r;
  }

  /// Returns the device of the `TensorOptions`, falling back to the default
  /// device (CPU) if unset.
  Device device() const noexcept {
    return device_or_default(device_opt());
  }

  /// Returns whether the device is specified.
  bool has_device() const noexcept {
    return has_device_;
  }

  /// Returns the device of the `TensorOptions`, or `std::nullopt` if
  /// device is not specified.
  std::optional<Device> device_opt() const noexcept {
    return has_device_ ? std::make_optional(device_) : std::nullopt;
  }

  /// Returns the device index of the `TensorOptions`.
  c10::DeviceIndex device_index() const noexcept {
    return device().index();
  }

  /// Returns the dtype of the `TensorOptions`, falling back to the default
  /// dtype if unset.
  caffe2::TypeMeta dtype() const noexcept {
    return dtype_or_default(dtype_opt());
  }

  /// Returns whether the dtype is specified.
  bool has_dtype() const noexcept {
    return has_dtype_;
  }

  /// Returns the dtype of the `TensorOptions`, or `std::nullopt` if
  /// dtype is not specified.
  std::optional<caffe2::TypeMeta> dtype_opt() const noexcept {
    return has_dtype_ ? std::make_optional(dtype_) : std::nullopt;
  }

  /// Returns the layout of the `TensorOptions`, falling back to the default
  /// layout (strided) if unset.
  Layout layout() const noexcept {
    return layout_or_default(layout_opt());
  }

  /// Returns whether the layout is specified.
  bool has_layout() const noexcept {
    return has_layout_;
  }

  /// Returns the layout of the `TensorOptions`, or `std::nullopt` if
  /// layout is not specified.
  std::optional<Layout> layout_opt() const noexcept {
    return has_layout_ ? std::make_optional(layout_) : std::nullopt;
  }

  /// Returns the `requires_grad` property of the `TensorOptions`, defaulting
  /// to false if unset.
  bool requires_grad() const noexcept {
    return has_requires_grad_ ? requires_grad_ : false;
  }

  /// Returns whether the `requires_grad` is specified.
  bool has_requires_grad() const noexcept {
    return has_requires_grad_;
  }

  /// Returns the `requires_grad` property of the `TensorOptions`, or
  /// `std::nullopt` if `requires_grad` is not specified.
  std::optional<bool> requires_grad_opt() const noexcept {
    return has_requires_grad_ ? std::make_optional(requires_grad_)
                              : std::nullopt;
  }

  /// Returns the `pinned_memory` property of the `TensorOptions`, defaulting
  /// to false if unset.
  bool pinned_memory() const noexcept {
    return pinned_memory_or_default(pinned_memory_opt());
  }

  /// Returns whether the `pinned_memory` is specified.
  bool has_pinned_memory() const noexcept {
    return has_pinned_memory_;
  }

  /// Returns if the layout is sparse.
  /// NB: reads the raw layout_ field, so an *unset* layout (default-initialized
  /// to strided) reports false here rather than consulting layout_or_default.
  bool is_sparse() const {
    return layout_ == c10::Layout::Sparse;
  }

  /// Returns if the layout is sparse CSR, deprecated, use
  /// is_sparse_compressed() instead
  bool is_sparse_csr() const {
    return layout_ == c10::Layout::SparseCsr;
  }

  /// Returns if the layout is any of the compressed sparse formats
  /// (CSR/CSC/BSR/BSC); reads the raw layout_ field like is_sparse().
  bool is_sparse_compressed() const {
    return layout_ == c10::Layout::SparseCsr ||
        layout_ == c10::Layout::SparseCsc ||
        layout_ == c10::Layout::SparseBsr || layout_ == c10::Layout::SparseBsc;
  }

  // For compatibility with legacy tensor.type() comparisons
  bool type_equal(const TensorOptions& other) const {
    return computeDispatchKey() == other.computeDispatchKey() &&
        typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype());
  }

  /// Returns the `pinned_memory` property of the `TensorOptions`, or
  /// `std::nullopt` if `pinned_memory` is not specified.
  std::optional<bool> pinned_memory_opt() const noexcept {
    return has_pinned_memory_ ? std::make_optional(pinned_memory_)
                              : std::nullopt;
  }

  /// Returns whether the `memory_layout` is specified
  bool has_memory_format() const noexcept {
    return has_memory_format_;
  }

  // NB: memory_format() getter is PURPOSELY not defined, as the default
  // behavior of memory_format varies from function to function.

  /// Returns the `memory_layout` property of `TensorOptions`, or
  /// `std::nullopt` if `memory_format` is not specified.
  std::optional<MemoryFormat> memory_format_opt() const noexcept {
    return has_memory_format_ ? std::make_optional(memory_format_)
                              : std::nullopt;
  }

  // Resolves the ATen backend specified by the current construction axes.
  // TODO: Deprecate this
  Backend backend() const {
    return at::dispatchKeyToBackend(computeDispatchKey());
  }

  /// Return the right-biased merge of two TensorOptions. This has the
  /// effect of overwriting settings from self with specified options
  /// of options.
  ///
  /// NB: This merging operation does NOT respect device merges.
  /// For example, if you device({kCUDA, 1}).merge_in(kCUDA)
  /// you will get kCUDA in the end! Functions like Tensor.new_empty
  /// ensure the right device is selected anyway by way of a
  /// device guard.
  ///
  TensorOptions merge_in(TensorOptions options) const noexcept {
    TensorOptions merged = *this;
    if (options.has_device())
      merged.set_device(options.device_opt());
    if (options.has_dtype())
      merged.set_dtype(options.dtype_opt());
    if (options.has_layout())
      merged.set_layout(options.layout_opt());
    // NB: requires grad is right biased; not a logical AND/OR!
    if (options.has_requires_grad())
      merged.set_requires_grad(options.requires_grad_opt());
    if (options.has_pinned_memory())
      merged.set_pinned_memory(options.pinned_memory_opt());
    if (options.has_memory_format())
      merged.set_memory_format(options.memory_format_opt());
    return merged;
  }

  // TODO remove after TensorOptions rationalization
  TensorOptions merge_memory_format(
      std::optional<MemoryFormat> optional_memory_format) const noexcept {
    TensorOptions merged = *this;
    if (optional_memory_format.has_value()) {
      merged.set_memory_format(*optional_memory_format);
    }
    return merged;
  }

  // INVARIANT: computeDispatchKey returns only the subset of dispatch keys for
  // which dispatchKeyToBackend is injective, if it is defined at all (for
  // the most part, this just means that this function never returns an
  // Autograd key)
  DispatchKey computeDispatchKey() const {
    return c10::computeDispatchKey(
        optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt());
  }

 private:
  // These methods are currently private because I'm not sure if it's wise
  // to actually publish them. They are methods because I need them in
  // the constructor and the functional API implementation.
  //
  // If you really, really need it, you can make these public, but check if you
  // couldn't just do what you need with the functional API. Similarly, these
  // methods are not chainable, because if you wanted chaining, you probably
  // want to use the functional API instead. (It's probably OK to make
  // these chainable, because these functions are all explicitly annotated
  // with a ref-qualifier, the trailing &, that makes them illegal to call
  // on temporaries.)

  /// Mutably set the device of `TensorOptions`. A nullopt clears the axis
  /// (has_device_ becomes false; the stored device_ value is left stale).
  void set_device(std::optional<Device> device) & noexcept {
    if (device) {
      device_ = *device;
      has_device_ = true;
    } else {
      has_device_ = false;
    }
  }

  /// Mutably set the dtype of `TensorOptions`; nullopt clears the axis.
  void set_dtype(std::optional<caffe2::TypeMeta> dtype) & noexcept {
    if (dtype) {
      dtype_ = *dtype;
      has_dtype_ = true;
    } else {
      has_dtype_ = false;
    }
  }

  // legacy function to support ScalarType
  void set_dtype(std::optional<ScalarType> dtype) & noexcept {
    if (dtype) {
      dtype_ = scalarTypeToTypeMeta(*dtype);
      has_dtype_ = true;
    } else {
      has_dtype_ = false;
    }
  }

  /// Mutably set the layout of `TensorOptions`; nullopt clears the axis.
  void set_layout(std::optional<Layout> layout) & noexcept {
    if (layout) {
      layout_ = *layout;
      has_layout_ = true;
    } else {
      has_layout_ = false;
    }
  }

  /// Mutably set the `requires_grad` property of `TensorOptions`; nullopt
  /// clears the axis.
  void set_requires_grad(std::optional<bool> requires_grad) & noexcept {
    if (requires_grad) {
      requires_grad_ = *requires_grad;
      has_requires_grad_ = true;
    } else {
      has_requires_grad_ = false;
    }
  }

  /// Mutably set the `pinned_memory` property of `TensorOptions`; nullopt
  /// clears the axis.
  void set_pinned_memory(std::optional<bool> pinned_memory) & noexcept {
    if (pinned_memory) {
      pinned_memory_ = *pinned_memory;
      has_pinned_memory_ = true;
    } else {
      has_pinned_memory_ = false;
    }
  }

  /// Mutably set the `memory_format` property of `TensorOptions`; nullopt
  /// clears the axis.
  void set_memory_format(std::optional<MemoryFormat> memory_format) & noexcept {
    if (memory_format) {
      memory_format_ = *memory_format;
      has_memory_format_ = true;
    } else {
      has_memory_format_ = false;
    }
  }

  // WARNING: If you edit TensorOptions to add more options, you
  // may need to adjust the implementation of Tensor::options.
  // The criteria for whether or not Tensor::options must be adjusted
  // is whether or not the new option you added should preserved
  // by functions such as empty_like(); if it should be preserved,
  // you must adjust options().
  //
  // TODO: MemoryFormat is not implemented in this way

  // NB: We didn't use std::optional here, because then we can't pack
  // the has_***_ boolean fields.

  Device device_ = at::kCPU; // 16-bit
  caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make<float>(); // 16-bit
  Layout layout_ = at::kStrided; // 8-bit
  MemoryFormat memory_format_ = MemoryFormat::Contiguous; // 8-bit

  // Bitmask required here to get this to fit inside 32 bits (or even 64 bits,
  // for that matter)

  bool requires_grad_ : 1;
  bool pinned_memory_ : 1;

  bool has_device_ : 1;
  bool has_dtype_ : 1;
  bool has_layout_ : 1;
  bool has_requires_grad_ : 1;
  bool has_pinned_memory_ : 1;
  bool has_memory_format_ : 1;
};
568
// We should aspire to fit in one machine-size word; but a size greater than two
// words is too much. (We are doing terribly on 32-bit archs, where we require
// three machine size words to store tensor options. Eek!)
// This guards the hand-packed bitfield layout of TensorOptions above.
static_assert(
    sizeof(TensorOptions) <= sizeof(int64_t) * 2,
    "TensorOptions must fit in 128-bits");
575
576 /// Convenience function that returns a `TensorOptions` object with the `dtype`
577 /// set to the given one.
dtype(caffe2::TypeMeta dtype)578 inline TensorOptions dtype(caffe2::TypeMeta dtype) {
579 return TensorOptions().dtype(dtype);
580 }
581
582 // legacy function to support ScalarType
dtype(ScalarType dtype)583 inline TensorOptions dtype(ScalarType dtype) {
584 return TensorOptions().dtype(scalarTypeToTypeMeta(dtype));
585 }
586
587 /// Convenience function that returns a `TensorOptions` object with the `layout`
588 /// set to the given one.
layout(Layout layout)589 inline TensorOptions layout(Layout layout) {
590 return TensorOptions().layout(layout);
591 }
592
593 /// Convenience function that returns a `TensorOptions` object with the `device`
594 /// set to the given one.
device(Device device)595 inline TensorOptions device(Device device) {
596 return TensorOptions().device(device);
597 }
598
599 /// Convenience function that returns a `TensorOptions` object with the
600 /// `device` set to CUDA and the `device_index` set to the given one.
device_index(c10::DeviceIndex device_index)601 inline TensorOptions device_index(c10::DeviceIndex device_index) {
602 return TensorOptions().device_index(device_index);
603 }
604
605 /// Convenience function that returns a `TensorOptions` object with the
606 /// `requires_grad` set to the given one.
607 inline TensorOptions requires_grad(bool requires_grad = true) {
608 return TensorOptions().requires_grad(requires_grad);
609 }
610
611 /// Convenience function that returns a `TensorOptions` object with the
612 /// `memory_format` set to the given one.
memory_format(MemoryFormat memory_format)613 inline TensorOptions memory_format(MemoryFormat memory_format) {
614 return TensorOptions().memory_format(memory_format);
615 }
616
/// Pretty-prints `options` (implementation lives in the corresponding .cpp).
C10_API std::ostream& operator<<(
    std::ostream& stream,
    const TensorOptions& options);
620
621 template <typename T>
dtype()622 inline TensorOptions dtype() {
623 return dtype(caffe2::TypeMeta::Make<T>());
624 }
625
/// Returns a human-readable string representation of `options`, built via
/// the `operator<<` overload declared above.
///
/// Fix: this function uses std::ostringstream but the header only included
/// <iosfwd> (a forward declaration), relying on a transitive include of
/// <sstream>; the include block now pulls in <sstream> explicitly.
inline std::string toString(const TensorOptions& options) {
  std::ostringstream stream;
  stream << options;
  return stream.str();
}
631
632 // This is intended to be a centralized location by which we can determine
633 // what an appropriate DispatchKey for a tensor is.
computeDispatchKey(std::optional<ScalarType> dtype,std::optional<Layout> layout,std::optional<Device> device)634 inline DispatchKey computeDispatchKey(
635 std::optional<ScalarType> dtype,
636 std::optional<Layout> layout,
637 std::optional<Device> device) {
638 const auto layout_ = layout_or_default(layout);
639 const auto device_ = device_or_default(device);
640 switch (layout_) {
641 case Layout::Jagged:
642 case Layout::Strided: {
643 const auto dtype_ = dtype_or_default(dtype);
644 switch (device_.type()) {
645 #define DO_CASE(device, _) \
646 case c10::DeviceType::device: { \
647 if (isQIntType(dtype_)) { \
648 return DispatchKey::Quantized##device; \
649 } \
650 return DispatchKey::device; \
651 }
652 C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
653 #undef DO_CASE
654 case c10::DeviceType::FPGA:
655 return DispatchKey::FPGA;
656 case c10::DeviceType::MAIA:
657 return DispatchKey::MAIA;
658 case c10::DeviceType::Vulkan:
659 return DispatchKey::Vulkan;
660 case c10::DeviceType::Metal:
661 return DispatchKey::Metal;
662 case c10::DeviceType::MKLDNN:
663 case c10::DeviceType::OPENGL:
664 case c10::DeviceType::OPENCL:
665 case c10::DeviceType::IDEEP:
666 TORCH_INTERNAL_ASSERT(
667 0,
668 "This is a grandfathered Caffe2 device type ",
669 device_.type(),
670 ", it shouldn't ever convert to a DispatchKey. File a bug describing what you were doing if you think this is in error.");
671 default:
672 TORCH_CHECK_NOT_IMPLEMENTED(
673 false,
674 "Unsupported device type for dense layout: ",
675 device_.type());
676 }
677 }
678 case Layout::Sparse:
679 switch (device_.type()) {
680 #define DO_CASE(device, _) \
681 case c10::DeviceType::device: { \
682 return DispatchKey::Sparse##device; \
683 }
684 C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
685 #undef DO_CASE
686 default:
687 TORCH_CHECK_NOT_IMPLEMENTED(
688 false,
689 "Unsupported device type for sparse layout: ",
690 device_.type());
691 }
692 case Layout::Mkldnn:
693 switch (device_.type()) {
694 case c10::DeviceType::CPU:
695 return DispatchKey::MkldnnCPU;
696 default:
697 TORCH_CHECK_NOT_IMPLEMENTED(
698 false,
699 "Unsupported device type for mkldnn layout: ",
700 device_.type());
701 }
702 case Layout::SparseCsr:
703 case Layout::SparseCsc:
704 case Layout::SparseBsr:
705 case Layout::SparseBsc:
706 switch (device_.type()) {
707 #define DO_CASE(device, _) \
708 case c10::DeviceType::device: { \
709 return DispatchKey::SparseCsr##device; \
710 }
711 C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
712 #undef DO_CASE
713 default:
714 TORCH_CHECK_NOT_IMPLEMENTED(
715 false,
716 "Unsupported device type for ",
717 layout_,
718 " layout: ",
719 device_.type());
720 }
721 default:
722 TORCH_CHECK(false, "Unsupported layout: ", layout_);
723 }
724 }
725
// Inverse of computeDispatchKey for the layout axis. All Sparse* backend
// keys fall through to Layout::Sparse; SparseCsr* keys error out because
// they correspond to four distinct compressed layouts (CSR/CSC/BSR/BSC)
// and cannot be mapped back uniquely. Everything else is Strided (except
// MkldnnCPU).
inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) {
  switch (dispatch_key) {
#define DO_CASE(bc, _) case DispatchKey::Sparse##bc:
    C10_FORALL_BACKEND_COMPONENTS(DO_CASE, unused)
#undef DO_CASE
    return Layout::Sparse;
#define DO_CASE(bc, _) case DispatchKey::SparseCsr##bc:
    C10_FORALL_BACKEND_COMPONENTS(DO_CASE, unused)
#undef DO_CASE
    TORCH_CHECK(
        false, "Cannot map DispatchKey ", dispatch_key, " to a unique layout.");
    case DispatchKey::MkldnnCPU:
      return Layout::Mkldnn;
    default:
      return Layout::Strided;
  }
}
743
// Inverse of computeDispatchKey for the device axis: maps every
// per-backend functionality key (Dense/Sparse/Quantized/... x device)
// back to its device type, plus a few singleton keys. Throws for keys
// with no associated device.
inline c10::DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) {
  switch (dispatch_key) {
    // stuff that's real
#define DO_CASE(suffix, prefix)     \
  case DispatchKey::prefix##suffix: \
    return c10::DeviceType::suffix;
#define DO_CASES(_, prefix) C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, prefix)
    C10_FORALL_FUNCTIONALITY_KEYS(DO_CASES)
#undef DO_CASES
#undef DO_CASE

    case DispatchKey::MkldnnCPU:
      return c10::DeviceType::CPU;
    case DispatchKey::Vulkan:
      return c10::DeviceType::Vulkan;

    case DispatchKey::MAIA:
      return c10::DeviceType::MAIA;
    default:
      TORCH_CHECK(
          false,
          "DispatchKey ",
          dispatch_key,
          " doesn't correspond to a device");
  }
}
770
dispatchKeyToTensorOptions(DispatchKey dispatch_key)771 inline TensorOptions dispatchKeyToTensorOptions(DispatchKey dispatch_key) {
772 return TensorOptions()
773 .layout(dispatchKeyToLayout(dispatch_key))
774 .device(dispatchKeyToDeviceType(dispatch_key));
775 }
776
777 namespace detail {
backend_supports_empty_operator(const TensorOptions & options)778 inline bool backend_supports_empty_operator(const TensorOptions& options) {
779 // Quantized backends don't support at::empty().
780 // They have separate operators like at::empty_quantized() that take in
781 // extra information about how to quantize the tensor.
782 return !isQIntType(typeMetaToScalarType(options.dtype()));
783 }
784
785 } // namespace detail
786
787 } // namespace c10
788