xref: /aosp_15_r20/external/pytorch/c10/core/TensorOptions.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/Backend.h>
4 #include <c10/core/DefaultDtype.h>
5 #include <c10/core/Device.h>
6 #include <c10/core/DeviceType.h>
7 #include <c10/core/DispatchKey.h>
8 #include <c10/core/Layout.h>
9 #include <c10/core/MemoryFormat.h>
10 #include <c10/core/ScalarType.h>
11 #include <c10/core/ScalarTypeToTypeMeta.h>
12 
13 #include <c10/macros/Export.h>
14 #include <c10/macros/Macros.h>
15 #include <c10/util/Exception.h>
16 #include <optional>
17 
18 #include <cstdint>
19 #include <iosfwd>
20 #include <string>
21 #include <type_traits>
22 #include <utility>
23 
24 namespace c10 {
25 
// Forward declaration: maps the (dtype, layout, device) construction axes to
// the DispatchKey used for dispatch.  The inline definition appears later in
// this header, after TensorOptions.
DispatchKey computeDispatchKey(
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device);
30 
dtype_or_default(std::optional<ScalarType> dtype)31 inline ScalarType dtype_or_default(std::optional<ScalarType> dtype) {
32   return value_or_else(dtype, [] { return get_default_dtype_as_scalartype(); });
33 }
34 
dtype_or_default(std::optional<caffe2::TypeMeta> dtype)35 inline caffe2::TypeMeta dtype_or_default(
36     std::optional<caffe2::TypeMeta> dtype) {
37   return value_or_else(dtype, [] { return get_default_dtype(); });
38 }
39 
layout_or_default(std::optional<Layout> layout)40 inline Layout layout_or_default(std::optional<Layout> layout) {
41   return layout.value_or(kStrided);
42 }
43 
device_or_default(std::optional<Device> device)44 inline Device device_or_default(std::optional<Device> device) {
45   return value_or_else(device, [] { return Device(kCPU); });
46 }
47 
/// Returns the given pinned-memory flag if set, otherwise false.
inline bool pinned_memory_or_default(std::optional<bool> pinned_memory) {
  return pinned_memory.has_value() ? *pinned_memory : false;
}
51 
/// A class to encapsulate construction axes of a Tensor.  TensorOptions was
53 /// designed to support the Python style API for specifying construction options
54 /// on factory functions, e.g.,
55 ///
56 ///     torch.zeros(2, 3, dtype=torch.int32)
57 ///
58 /// Because C++ doesn't natively support keyword arguments, there must be
59 /// another way of specifying keyword-like arguments.  TensorOptions is a
60 /// builder class which can be used to construct this "dictionary" of keyword
61 /// arguments: functions which support TensorOptions conventionally take this
62 /// argument optionally as their last argument.
63 ///
64 /// WARNING: In PyTorch, there are `torch::` variants of factory functions,
65 /// e.g., torch::zeros for at::zeros.  These return Variables (while the
66 /// stock ATen functions return plain Tensors).  If you mix these functions
67 /// up, you WILL BE SAD.
68 ///
69 /// Rather than use the constructor of this class directly, you should prefer to
70 /// use the constructor functions, and then chain setter methods on top of them.
71 ///
72 ///     at::device(at::kCUDA).dtype(kInt)
73 ///     at::dtype(at::kInt)
74 ///
75 /// Additionally, anywhere a TensorOptions is expected, you can directly
76 /// pass at::kCUDA / at::kInt, and it will implicitly convert to a
77 /// TensorOptions.
78 ///
79 /// Here are some recommended ways to create a 2x2 tensor of zeros
80 /// with certain properties.  These all *implicitly* make use of
81 /// TensorOptions, even if they don't mention the class explicitly:
82 ///
83 ///     at::zeros({2,2}, at::kCUDA);
84 ///     at::zeros({2,2}, at::kLong);
85 ///     at::zeros({2,2}, at::device(at::kCUDA).dtype(at::kLong()));
86 ///     at::zeros({2,2}, at::device({at::kCUDA, 1})); // place on device 1
87 ///     at::zeros({2,2}, at::requires_grad());
88 ///
89 
90 /// NOTE [ TensorOptions Constructors ]
91 ///
92 /// TensorOptions is like a dictionary with entries from the set:
93 /// {requires_grad, device, dtype, layout}, where each entry may be
94 /// unspecified (i.e., is optional). It is used to specify the properties of
95 /// tensors in many places both in C++ internal and API, e.g., tensor factory
96 /// methods like `at::empty({10}, options)`, tensor conversions like
97 /// `tensor.to(...)`, etc.
98 ///
99 /// To provide a simple API that is consistent with Python, where one can do
100 /// `torch.empty(sizes, X)` with `X` being a `torch.device`, `torch.dtype`, or a
101 /// `torch.layout`, we want TensorOptions to be implicitly convertible from
102 /// `ScalarType dtype`, `Layout layout` and `Device device`. Therefore, we have
103 /// three implicit constructors from each of these three types.
104 ///
105 /// This is sufficient for `ScalarType` and `Layout` as they are simple Enum
106 /// classes. However, `Device` is an ordinary class with implicit constructors
107 /// `Device(DeviceType, DeviceIndex = -1)` and `Device(std::string)` to be
108 /// consistent with Python API, where strings are treated as equivalent with a
109 /// `torch.device` object (e.g., "cuda:1" can be passed to everywhere a
110 /// `torch.device("cuda:1")` is accepted). To support the syntax
111 /// `at::empty({10}, {kCUDA, 1})` and `tensor.to(kCUDA)`, we need to make sure
112 /// that `TensorOptions` is implicitly constructible with any arguments that a
/// `Device` can be constructed from. So we have,
114 ///
115 ///    /* implicit */ TensorOptions(T&& device) : TensorOptions() {
116 ///      this->set_device(device);
117 ///    }
118 ///
119 ///    template <typename... Args,
120 ///             typename = std::enable_if_t<std::is_constructible<Device,
121 ///             Args&&...>::value>>
122 ///    /* implicit */  TensorOptions(Args&&... args)
123 ///     : TensorOptions(Device(std::forward<Args>(args)...)) {}
124 ///
125 ///
126 /// But this will be problematic. Consider this: `TensorOptions({kCUDA, 1})`.
127 /// Compiler will complain about ambiguity between the copy constructor and the
128 /// `Device` constructor because `{kCUDA, 1}` can be converted to both a
/// `TensorOptions` and a `Device`.
130 ///
131 /// To get around this, we templatize the `Device` constructor. Since overload
132 /// resolution is done before template resolution, our problem is solved.
133 
// NOTE(review): a second, byte-identical forward declaration of
// computeDispatchKey stood here; it is redundant with the declaration near
// the top of this header and has been removed.
138 
struct C10_API TensorOptions {
  /// Constructs a `TensorOptions` with every axis unspecified.  The getters
  /// below substitute defaults for unset axes (CPU device, default dtype,
  /// strided layout, requires_grad=false, pinned_memory=false).
  TensorOptions()
      : requires_grad_(false),
        pinned_memory_(false),
        has_device_(false),
        has_dtype_(false),
        has_layout_(false),
        has_requires_grad_(false),
        has_pinned_memory_(false),
        has_memory_format_(false) {}

  /// Constructs a `TensorOptions` object with the given layout.
  /* implicit */ TensorOptions(Layout layout) : TensorOptions() {
    this->set_layout(layout);
  }

  /// Constructs a `TensorOptions` object with the given device.
  /// See NOTE [ TensorOptions Constructors ] on why this is templatized.
  template <
      typename T,
      typename = std::enable_if_t<std::is_same_v<std::decay_t<T>, Device>>>
  /* implicit */ TensorOptions(T&& device) : TensorOptions() {
    this->set_device(std::forward<T>(device));
  }

  /// Constructs a `TensorOptions` object from arguments allowed in `Device`
  /// constructors.
  ///
  /// See NOTE [ TensorOptions Constructors ].
  ///
  /// NB: Ideally we only allow implicit constructors here. But there is no easy
  ///     way to detect them. So we have this one that allows explicit
  ///     constructors too.
  template <
      typename... Args,
      typename = std::enable_if_t<std::is_constructible_v<Device, Args&&...>>>
  /* implicit */ TensorOptions(Args&&... args)
      : TensorOptions(Device(std::forward<Args>(args)...)) {}

  /// Constructs a `TensorOptions` object with the given dtype.
  /* implicit */ TensorOptions(caffe2::TypeMeta dtype) : TensorOptions() {
    this->set_dtype(dtype);
  }

  /// legacy constructor to support ScalarType
  /* implicit */ TensorOptions(ScalarType dtype) : TensorOptions() {
    this->set_dtype(dtype);
  }

  /// Constructs a `TensorOptions` object with the given memory format.
  /* implicit */ TensorOptions(MemoryFormat memory_format) : TensorOptions() {
    set_memory_format(memory_format);
  }

  /// Return a copy of `TensorOptions` with `device` set to the given one, or
  /// cleared if `device` is `nullopt`.
  C10_NODISCARD TensorOptions
  device(std::optional<Device> device) const noexcept {
    TensorOptions r = *this;
    r.set_device(device);
    return r;
  }

  /// Return a copy of `TensorOptions` with `device` set to the given one.
  /// (This overload ensures that variadic template std::optional constructor
  /// for Device work correctly.)
  template <typename... Args>
  C10_NODISCARD TensorOptions device(Args&&... args) const noexcept {
    return device(
        std::optional<Device>(std::in_place, std::forward<Args>(args)...));
  }

  /// Return a copy of `TensorOptions`, but with device set to CUDA, and the
  /// device index set to the given one.
  ///
  /// TODO: This function encourages bad behavior (assuming CUDA is
  /// the only device that matters).  Get rid of it / rename it.
  C10_NODISCARD TensorOptions
  device_index(c10::DeviceIndex device_index) const noexcept {
    return device(Device::Type::CUDA, device_index);
  }

  /// Return a copy of `TensorOptions` with `dtype` set to the given one,
  /// or cleared if `dtype` is `nullopt`.
  C10_NODISCARD TensorOptions
  dtype(std::optional<caffe2::TypeMeta> dtype) const noexcept {
    TensorOptions r = *this;
    r.set_dtype(dtype);
    return r;
  }

  // legacy function to support ScalarType
  C10_NODISCARD TensorOptions
  dtype(std::optional<ScalarType> dtype) const noexcept {
    TensorOptions r = *this;
    r.set_dtype(dtype);
    return r;
  }

  // Since dtype is taken...
  // NB: unlike the other setters, this mutates *this in place and is
  // chainable (returns a reference, not a copy).
  template <typename T>
  TensorOptions& dtype() {
    dtype_ = caffe2::TypeMeta::Make<T>();
    has_dtype_ = true;
    return *this;
  }

  /// Sets the layout of the `TensorOptions`.
  C10_NODISCARD TensorOptions
  layout(std::optional<Layout> layout) const noexcept {
    TensorOptions r = *this;
    r.set_layout(layout);
    return r;
  }

  /// Sets the `requires_grad` property of the `TensorOptions`.
  C10_NODISCARD TensorOptions
  requires_grad(std::optional<bool> requires_grad) const noexcept {
    TensorOptions r = *this;
    r.set_requires_grad(requires_grad);
    return r;
  }

  /// Sets the `pinned_memory` property on the `TensorOptions`.
  C10_NODISCARD TensorOptions
  pinned_memory(std::optional<bool> pinned_memory) const noexcept {
    TensorOptions r = *this;
    r.set_pinned_memory(pinned_memory);
    return r;
  }

  /// Sets the `memory_format` property on `TensorOptions`.
  C10_NODISCARD TensorOptions
  memory_format(std::optional<MemoryFormat> memory_format) const noexcept {
    TensorOptions r = *this;
    r.set_memory_format(memory_format);
    return r;
  }

  /// Returns the device of the `TensorOptions` (kCPU if unspecified).
  Device device() const noexcept {
    return device_or_default(device_opt());
  }

  /// Returns whether the device is specified.
  bool has_device() const noexcept {
    return has_device_;
  }

  /// Returns the device of the `TensorOptions`, or `std::nullopt` if
  /// device is not specified.
  std::optional<Device> device_opt() const noexcept {
    return has_device_ ? std::make_optional(device_) : std::nullopt;
  }

  /// Returns the device index of the `TensorOptions`.
  c10::DeviceIndex device_index() const noexcept {
    return device().index();
  }

  /// Returns the dtype of the `TensorOptions` (the global default dtype if
  /// unspecified).
  caffe2::TypeMeta dtype() const noexcept {
    return dtype_or_default(dtype_opt());
  }

  /// Returns whether the dtype is specified.
  bool has_dtype() const noexcept {
    return has_dtype_;
  }

  /// Returns the dtype of the `TensorOptions`, or `std::nullopt` if
  /// dtype is not specified.
  std::optional<caffe2::TypeMeta> dtype_opt() const noexcept {
    return has_dtype_ ? std::make_optional(dtype_) : std::nullopt;
  }

  /// Returns the layout of the `TensorOptions` (kStrided if unspecified).
  Layout layout() const noexcept {
    return layout_or_default(layout_opt());
  }

  /// Returns whether the layout is specified.
  bool has_layout() const noexcept {
    return has_layout_;
  }

  /// Returns the layout of the `TensorOptions`, or `std::nullopt` if
  /// layout is not specified.
  std::optional<Layout> layout_opt() const noexcept {
    return has_layout_ ? std::make_optional(layout_) : std::nullopt;
  }

  /// Returns the `requires_grad` property of the `TensorOptions` (false if
  /// unspecified).
  bool requires_grad() const noexcept {
    return has_requires_grad_ ? requires_grad_ : false;
  }

  /// Returns whether the `requires_grad` is specified.
  bool has_requires_grad() const noexcept {
    return has_requires_grad_;
  }

  /// Returns the `requires_grad` property of the `TensorOptions`, or
  /// `std::nullopt` if `requires_grad` is not specified.
  std::optional<bool> requires_grad_opt() const noexcept {
    return has_requires_grad_ ? std::make_optional(requires_grad_)
                              : std::nullopt;
  }

  /// Returns the `pinned_memory` property of the `TensorOptions` (false if
  /// unspecified).
  bool pinned_memory() const noexcept {
    return pinned_memory_or_default(pinned_memory_opt());
  }

  /// Returns whether the `pinned_memory` is specified.
  bool has_pinned_memory() const noexcept {
    return has_pinned_memory_;
  }

  /// Returns if the layout is sparse.
  /// NB: inspects layout_ directly, so an *unspecified* layout reads as its
  /// in-class initializer rather than going through layout_or_default.
  bool is_sparse() const {
    return layout_ == c10::Layout::Sparse;
  }

  /// Returns if the layout is sparse CSR, deprecated, use
  /// is_sparse_compressed() instead
  bool is_sparse_csr() const {
    return layout_ == c10::Layout::SparseCsr;
  }

  /// Returns if the layout is any of the compressed sparse formats
  /// (CSR, CSC, BSR, BSC).
  bool is_sparse_compressed() const {
    return layout_ == c10::Layout::SparseCsr ||
        layout_ == c10::Layout::SparseCsc ||
        layout_ == c10::Layout::SparseBsr || layout_ == c10::Layout::SparseBsc;
  }

  // For compatibility with legacy tensor.type() comparisons
  bool type_equal(const TensorOptions& other) const {
    return computeDispatchKey() == other.computeDispatchKey() &&
        typeMetaToScalarType(dtype_) == typeMetaToScalarType(other.dtype());
  }

  /// Returns the `pinned_memory` property of the `TensorOptions`, or
  /// `std::nullopt` if `pinned_memory` is not specified.
  std::optional<bool> pinned_memory_opt() const noexcept {
    return has_pinned_memory_ ? std::make_optional(pinned_memory_)
                              : std::nullopt;
  }

  /// Returns whether the `memory_format` is specified.
  bool has_memory_format() const noexcept {
    return has_memory_format_;
  }

  // NB: memory_format() getter is PURPOSELY not defined, as the default
  // behavior of memory_format varies from function to function.

  /// Returns the `memory_format` property of `TensorOptions`, or
  /// `std::nullopt` if `memory_format` is not specified.
  std::optional<MemoryFormat> memory_format_opt() const noexcept {
    return has_memory_format_ ? std::make_optional(memory_format_)
                              : std::nullopt;
  }

  // Resolves the ATen backend specified by the current construction axes.
  // TODO: Deprecate this
  Backend backend() const {
    return at::dispatchKeyToBackend(computeDispatchKey());
  }

  /// Return the right-biased merge of two TensorOptions.  This has the
  /// effect of overwriting settings from self with specified options
  /// of options.
  ///
  /// NB: This merging operation does NOT respect device merges.
  /// For example, if you device({kCUDA, 1}).merge_in(kCUDA)
  /// you will get kCUDA in the end!  Functions like Tensor.new_empty
  /// ensure the right device is selected anyway by way of a
  /// device guard.
  ///
  TensorOptions merge_in(TensorOptions options) const noexcept {
    TensorOptions merged = *this;
    if (options.has_device())
      merged.set_device(options.device_opt());
    if (options.has_dtype())
      merged.set_dtype(options.dtype_opt());
    if (options.has_layout())
      merged.set_layout(options.layout_opt());
    // NB: requires grad is right biased; not a logical AND/OR!
    if (options.has_requires_grad())
      merged.set_requires_grad(options.requires_grad_opt());
    if (options.has_pinned_memory())
      merged.set_pinned_memory(options.pinned_memory_opt());
    if (options.has_memory_format())
      merged.set_memory_format(options.memory_format_opt());
    return merged;
  }

  // TODO remove after TensorOptions rationalization
  TensorOptions merge_memory_format(
      std::optional<MemoryFormat> optional_memory_format) const noexcept {
    TensorOptions merged = *this;
    if (optional_memory_format.has_value()) {
      merged.set_memory_format(*optional_memory_format);
    }
    return merged;
  }

  // INVARIANT: computeDispatchKey returns only the subset of dispatch keys for
  // which dispatchKeyToBackend is injective, if it is defined at all  (for
  // the most part, this just means that this function never returns an
  // Autograd key)
  DispatchKey computeDispatchKey() const {
    return c10::computeDispatchKey(
        optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt());
  }

 private:
  // These methods are currently private because I'm not sure if it's wise
  // to actually publish them.  They are methods because I need them in
  // the constructor and the functional API implementation.
  //
  // If you really, really need it, you can make these public, but check if you
  // couldn't just do what you need with the functional API.  Similarly, these
  // methods are not chainable, because if you wanted chaining, you probably
  // want to use the functional API instead.  (It's probably OK to make
  // these chainable, because these functions are all explicitly annotated
  // with a ref-qualifier, the trailing &, that makes them illegal to call
  // on temporaries.)

  /// Mutably set the device of `TensorOptions`; nullopt clears the axis.
  void set_device(std::optional<Device> device) & noexcept {
    if (device) {
      device_ = *device;
      has_device_ = true;
    } else {
      has_device_ = false;
    }
  }

  /// Mutably set the dtype of `TensorOptions`; nullopt clears the axis.
  void set_dtype(std::optional<caffe2::TypeMeta> dtype) & noexcept {
    if (dtype) {
      dtype_ = *dtype;
      has_dtype_ = true;
    } else {
      has_dtype_ = false;
    }
  }

  // legacy function to support ScalarType
  void set_dtype(std::optional<ScalarType> dtype) & noexcept {
    if (dtype) {
      dtype_ = scalarTypeToTypeMeta(*dtype);
      has_dtype_ = true;
    } else {
      has_dtype_ = false;
    }
  }

  /// Mutably set the layout of `TensorOptions`; nullopt clears the axis.
  void set_layout(std::optional<Layout> layout) & noexcept {
    if (layout) {
      layout_ = *layout;
      has_layout_ = true;
    } else {
      has_layout_ = false;
    }
  }

  /// Mutably set the `requires_grad` property of `TensorOptions`;
  /// nullopt clears the axis.
  void set_requires_grad(std::optional<bool> requires_grad) & noexcept {
    if (requires_grad) {
      requires_grad_ = *requires_grad;
      has_requires_grad_ = true;
    } else {
      has_requires_grad_ = false;
    }
  }

  /// Mutably set the `pinned_memory` property of `TensorOptions`;
  /// nullopt clears the axis.
  void set_pinned_memory(std::optional<bool> pinned_memory) & noexcept {
    if (pinned_memory) {
      pinned_memory_ = *pinned_memory;
      has_pinned_memory_ = true;
    } else {
      has_pinned_memory_ = false;
    }
  }

  /// Mutably set the `memory_format` property of `TensorOptions`;
  /// nullopt clears the axis.
  void set_memory_format(std::optional<MemoryFormat> memory_format) & noexcept {
    if (memory_format) {
      memory_format_ = *memory_format;
      has_memory_format_ = true;
    } else {
      has_memory_format_ = false;
    }
  }

  // WARNING: If you edit TensorOptions to add more options, you
  // may need to adjust the implementation of Tensor::options.
  // The criteria for whether or not Tensor::options must be adjusted
  // is whether or not the new option you added should preserved
  // by functions such as empty_like(); if it should be preserved,
  // you must adjust options().
  //
  // TODO: MemoryFormat is not implemented in this way

  // NB: We didn't use std::optional here, because then we can't pack
  // the has_***_ boolean fields.

  Device device_ = at::kCPU; // 16-bit
  caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make<float>(); // 16-bit
  Layout layout_ = at::kStrided; // 8-bit
  MemoryFormat memory_format_ = MemoryFormat::Contiguous; // 8-bit

  // Bitmask required here to get this to fit inside 32 bits (or even 64 bits,
  // for that matter)

  bool requires_grad_ : 1;
  bool pinned_memory_ : 1;

  bool has_device_ : 1;
  bool has_dtype_ : 1;
  bool has_layout_ : 1;
  bool has_requires_grad_ : 1;
  bool has_pinned_memory_ : 1;
  bool has_memory_format_ : 1;
};
568 
// We should aspire to fit in one machine-size word; but a size greater than two
// words is too much.  (We are doing terribly on 32-bit archs, where we require
// three machine size words to store tensor options.  Eek!)
// The has_* flags are packed into bitfields above precisely to keep this
// assertion true.
static_assert(
    sizeof(TensorOptions) <= sizeof(int64_t) * 2,
    "TensorOptions must fit in 128-bits");
575 
576 /// Convenience function that returns a `TensorOptions` object with the `dtype`
577 /// set to the given one.
dtype(caffe2::TypeMeta dtype)578 inline TensorOptions dtype(caffe2::TypeMeta dtype) {
579   return TensorOptions().dtype(dtype);
580 }
581 
582 // legacy function to support ScalarType
dtype(ScalarType dtype)583 inline TensorOptions dtype(ScalarType dtype) {
584   return TensorOptions().dtype(scalarTypeToTypeMeta(dtype));
585 }
586 
587 /// Convenience function that returns a `TensorOptions` object with the `layout`
588 /// set to the given one.
layout(Layout layout)589 inline TensorOptions layout(Layout layout) {
590   return TensorOptions().layout(layout);
591 }
592 
593 /// Convenience function that returns a `TensorOptions` object with the `device`
594 /// set to the given one.
device(Device device)595 inline TensorOptions device(Device device) {
596   return TensorOptions().device(device);
597 }
598 
599 /// Convenience function that returns a `TensorOptions` object with the
600 /// `device` set to CUDA and the `device_index` set to the given one.
device_index(c10::DeviceIndex device_index)601 inline TensorOptions device_index(c10::DeviceIndex device_index) {
602   return TensorOptions().device_index(device_index);
603 }
604 
605 /// Convenience function that returns a `TensorOptions` object with the
606 /// `requires_grad` set to the given one.
607 inline TensorOptions requires_grad(bool requires_grad = true) {
608   return TensorOptions().requires_grad(requires_grad);
609 }
610 
611 /// Convenience function that returns a `TensorOptions` object with the
612 /// `memory_format` set to the given one.
memory_format(MemoryFormat memory_format)613 inline TensorOptions memory_format(MemoryFormat memory_format) {
614   return TensorOptions().memory_format(memory_format);
615 }
616 
/// Pretty-prints a `TensorOptions` to a stream (defined out of line; only
/// <iosfwd> is needed here for the declaration).
C10_API std::ostream& operator<<(
    std::ostream& stream,
    const TensorOptions& options);
620 
621 template <typename T>
dtype()622 inline TensorOptions dtype() {
623   return dtype(caffe2::TypeMeta::Make<T>());
624 }
625 
/// Renders a `TensorOptions` to a string via operator<<.
inline std::string toString(const TensorOptions& options) {
  std::ostringstream ss;
  ss << options;
  return ss.str();
}
631 
632 // This is intended to be a centralized location by which we can determine
633 // what an appropriate DispatchKey for a tensor is.
// Centralized mapping from (dtype, layout, device) to the backend DispatchKey
// for a tensor.  Unset axes fall back to their defaults (default dtype,
// strided layout, CPU device).
inline DispatchKey computeDispatchKey(
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device) {
  const auto layout_ = layout_or_default(layout);
  const auto device_ = device_or_default(device);
  switch (layout_) {
    // NB: Jagged deliberately shares the dense (Strided) dispatch path.
    case Layout::Jagged:
    case Layout::Strided: {
      const auto dtype_ = dtype_or_default(dtype);
      switch (device_.type()) {
// For each backend device type: quantized dtypes get the Quantized##device
// key, everything else the plain per-device key.
#define DO_CASE(device, _)                   \
  case c10::DeviceType::device: {            \
    if (isQIntType(dtype_)) {                \
      return DispatchKey::Quantized##device; \
    }                                        \
    return DispatchKey::device;              \
  }
        C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
#undef DO_CASE
        case c10::DeviceType::FPGA:
          return DispatchKey::FPGA;
        case c10::DeviceType::MAIA:
          return DispatchKey::MAIA;
        case c10::DeviceType::Vulkan:
          return DispatchKey::Vulkan;
        case c10::DeviceType::Metal:
          return DispatchKey::Metal;
        case c10::DeviceType::MKLDNN:
        case c10::DeviceType::OPENGL:
        case c10::DeviceType::OPENCL:
        case c10::DeviceType::IDEEP:
          TORCH_INTERNAL_ASSERT(
              0,
              "This is a grandfathered Caffe2 device type ",
              device_.type(),
              ", it shouldn't ever convert to a DispatchKey.  File a bug describing what you were doing if you think this is in error.");
        default:
          TORCH_CHECK_NOT_IMPLEMENTED(
              false,
              "Unsupported device type for dense layout: ",
              device_.type());
      }
    }
    case Layout::Sparse:
      switch (device_.type()) {
// COO sparse: one Sparse##device key per backend device type.
#define DO_CASE(device, _)              \
  case c10::DeviceType::device: {       \
    return DispatchKey::Sparse##device; \
  }
        C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
#undef DO_CASE
        default:
          TORCH_CHECK_NOT_IMPLEMENTED(
              false,
              "Unsupported device type for sparse layout: ",
              device_.type());
      }
    case Layout::Mkldnn:
      switch (device_.type()) {
        case c10::DeviceType::CPU:
          return DispatchKey::MkldnnCPU;
        default:
          TORCH_CHECK_NOT_IMPLEMENTED(
              false,
              "Unsupported device type for mkldnn layout: ",
              device_.type());
      }
    // All four compressed sparse layouts share the SparseCsr##device keys.
    case Layout::SparseCsr:
    case Layout::SparseCsc:
    case Layout::SparseBsr:
    case Layout::SparseBsc:
      switch (device_.type()) {
#define DO_CASE(device, _)                 \
  case c10::DeviceType::device: {          \
    return DispatchKey::SparseCsr##device; \
  }
        C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
#undef DO_CASE
        default:
          TORCH_CHECK_NOT_IMPLEMENTED(
              false,
              "Unsupported device type for ",
              layout_,
              " layout: ",
              device_.type());
      }
    default:
      TORCH_CHECK(false, "Unsupported layout: ", layout_);
  }
}
725 
// Inverse-ish of computeDispatchKey for the layout axis: maps a backend
// DispatchKey back to a Layout.  Anything not sparse or MKLDNN is Strided.
inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) {
  switch (dispatch_key) {
// Every Sparse##BackendComponent key falls through to a single return.
#define DO_CASE(bc, _) case DispatchKey::Sparse##bc:
    C10_FORALL_BACKEND_COMPONENTS(DO_CASE, unused)
#undef DO_CASE
    return Layout::Sparse;
// SparseCsr##bc keys are shared by CSR/CSC/BSR/BSC, so no unique layout
// exists; all of these cases fall through to the error below.
#define DO_CASE(bc, _) case DispatchKey::SparseCsr##bc:
    C10_FORALL_BACKEND_COMPONENTS(DO_CASE, unused)
#undef DO_CASE
    TORCH_CHECK(
        false, "Cannot map DispatchKey ", dispatch_key, " to a unique layout.");
    case DispatchKey::MkldnnCPU:
      return Layout::Mkldnn;
    default:
      return Layout::Strided;
  }
}
743 
// Maps a DispatchKey back to the device type it dispatches for; errors on
// keys that do not correspond to a device.
inline c10::DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) {
  switch (dispatch_key) {
    // stuff that's real
// Expands to one case per (functionality prefix, device suffix) pair,
// e.g. DispatchKey::SparseCPU -> c10::DeviceType::CPU.
#define DO_CASE(suffix, prefix)     \
  case DispatchKey::prefix##suffix: \
    return c10::DeviceType::suffix;
#define DO_CASES(_, prefix) C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, prefix)
    C10_FORALL_FUNCTIONALITY_KEYS(DO_CASES)
#undef DO_CASES
#undef DO_CASE

    case DispatchKey::MkldnnCPU:
      return c10::DeviceType::CPU;
    case DispatchKey::Vulkan:
      return c10::DeviceType::Vulkan;

    case DispatchKey::MAIA:
      return c10::DeviceType::MAIA;
    default:
      TORCH_CHECK(
          false,
          "DispatchKey ",
          dispatch_key,
          " doesn't correspond to a device");
  }
}
770 
dispatchKeyToTensorOptions(DispatchKey dispatch_key)771 inline TensorOptions dispatchKeyToTensorOptions(DispatchKey dispatch_key) {
772   return TensorOptions()
773       .layout(dispatchKeyToLayout(dispatch_key))
774       .device(dispatchKeyToDeviceType(dispatch_key));
775 }
776 
777 namespace detail {
backend_supports_empty_operator(const TensorOptions & options)778 inline bool backend_supports_empty_operator(const TensorOptions& options) {
779   // Quantized backends don't support at::empty().
780   // They have separate operators like at::empty_quantized() that take in
781   // extra information about how to quantize the tensor.
782   return !isQIntType(typeMetaToScalarType(options.dtype()));
783 }
784 
785 } // namespace detail
786 
787 } // namespace c10
788