#pragma once

#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/InferenceMode.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/core/ScalarTypeToTypeMeta.h>
#include <c10/core/Storage.h>
#include <c10/core/SymBool.h>
#include <c10/core/SymInt.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/SymbolicShapeMeta.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/core/impl/PyObjectSlot.h>
#include <c10/core/impl/SizesAndStrides.h>
#include <c10/macros/Export.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/DimVector.h>
#include <c10/util/Exception.h>
#include <c10/util/Flags.h>
#include <c10/util/accumulate.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/irange.h>
#include <c10/util/safe_numerics.h>
#include <c10/util/typeid.h>
#include <optional>

#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

// A global boolean variable to control whether we free memory when a Tensor
// is shrunk to a smaller size. When the flag is set, a Tensor keeps the
// memory allocated for the maximum capacity it has been reshaped to so far.
//
// This parameter is respected by "upper-case" methods which call Resize()
// (e.g., CopyFrom, ResizeLike); it is NOT respected by Tensor::resize_
// or ShrinkTo, both of which guarantee never to free memory.
C10_DECLARE_bool(caffe2_keep_on_shrink);

// Since we can have high variance in blob memory allocated across different
// inputs in the same run, we will shrink the blob only if the memory gain
// is larger than this flag in bytes. This only applies to functions which
// respect caffe2_keep_on_shrink.
C10_DECLARE_int64(caffe2_max_keep_on_shrink_memory);

namespace at {
class Tensor;
class TensorBase;
} // namespace at

namespace c10 {

/**
 * A utility function to convert vector<int> to vector<int64_t>.
 */
inline std::vector<int64_t> ToVectorint64_t(const ArrayRef<int>& src) {
  return std::vector<int64_t>(src.begin(), src.end());
}

/**
 * Return product of all dimensions starting from k
 */
inline int64_t size_from_dim_(int k, IntArrayRef dims) {
  int64_t r = 1;
  for (const auto i : c10::irange(k, dims.size())) {
    r *= dims[i];
  }
  return r;
}

// Product of all dims up to k (not including dims[k])
inline int64_t size_to_dim_(int k, IntArrayRef dims) {
  TORCH_CHECK(k >= 0 && static_cast<size_t>(k) <= dims.size());
  int64_t r = 1;
  for (const auto i : c10::irange(k)) {
    r *= dims[i];
  }
  return r;
}

// Product of all dims between k and l (not including dims[k] and dims[l])
inline int64_t size_between_dim_(int k, int l, IntArrayRef dims) {
  TORCH_CHECK((unsigned)l < dims.size() && (unsigned)k < dims.size());
  int64_t r = 1;
  if (k < l) {
    for (int i = k + 1; i < l; ++i) {
      r *= dims[i];
    }
  } else {
    for (int i = l + 1; i < k; ++i) {
      r *= dims[i];
    }
  }
  return r;
}

// Wrap around axis_index if it is negative, s.t., -1 is the last dim
inline int canonical_axis_index_(int axis_index, int ndims) {
  TORCH_CHECK(axis_index >= -ndims);
  TORCH_CHECK(axis_index < ndims);
  if (axis_index < 0) {
    return axis_index + ndims;
  }
  return axis_index;
}
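
// Illustrative example for the dimension helpers above (not part of the API):
// for dims = {2, 3, 4},
//   size_from_dim_(1, dims)       == 3 * 4 == 12
//   size_to_dim_(2, dims)         == 2 * 3 == 6
//   size_between_dim_(0, 2, dims) == 3
//   canonical_axis_index_(-1, 3)  == 2   // -1 wraps around to the last dim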

using PlacementDtor = void (*)(void*, size_t);

/*
 * A Context that will call an extra placement deleter during
 * destruction.
 *
 * Accepts an already constructed DataPtr and stores it as a member.
 * During destruction, we'll call the extra deleter on the underlying
 * data pointer before the DataPtr is destructed.
 * `data_ptr_` owns the memory.
 */
struct C10_API PlacementDeleteContext {
  DataPtr data_ptr_;
  PlacementDtor placement_dtor_;
  size_t size_;
  PlacementDeleteContext(
      DataPtr&& data_ptr,
      PlacementDtor placement_dtor,
      size_t size)
      : data_ptr_(std::move(data_ptr)),
        placement_dtor_(placement_dtor),
        size_(size) {}
  static DataPtr makeDataPtr(
      DataPtr&& data_ptr,
      PlacementDtor placement_dtor,
      size_t size,
      Device device);
  ~PlacementDeleteContext() {
    placement_dtor_(data_ptr_.get(), size_);
    // original memory will be freed when data_ptr_ is destructed
  }
};
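
// Illustrative sketch (an assumed usage pattern, not a prescribed one):
// makeDataPtr wraps an existing allocation so that `placement_dtor` is run on
// the buffer before the underlying memory is released, e.g. for
// non-trivially-destructible element types:
//
//   // hypothetical call site; `dtor` is a PlacementDtor that tears down the
//   // elements starting at the pointer, given the stored size
//   at::DataPtr guarded = PlacementDeleteContext::makeDataPtr(
//       std::move(raw_data_ptr), dtor, size, device);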

struct C10_API AutogradMetaInterface {
  virtual void set_requires_grad(
      bool requires_grad,
      at::TensorImpl* self_impl) = 0;
  virtual bool requires_grad() const = 0;
  virtual at::Tensor& mutable_grad() = 0;
  virtual const at::Tensor& grad() const = 0;
  virtual const at::Tensor& fw_grad(uint64_t level, const at::TensorBase& self)
      const = 0;
  virtual void set_fw_grad(
      const at::TensorBase& new_grad,
      const at::TensorBase& self,
      uint64_t level,
      bool is_inplace_op) = 0;
  virtual ~AutogradMetaInterface();
};

namespace impl {

// Unfortunately, the definition of AutogradMeta lives in a separate
// compilation unit from TensorImpl (libtorch.so versus libc10.so),
// which means that we cannot construct an AutogradMeta from TensorImpl,
// not even from the cpp file. So we have to indirect it through a factory
// function which will be initialized when we load libtorch.so.

struct C10_API AutogradMetaFactory {
  virtual ~AutogradMetaFactory() = default;
  virtual std::unique_ptr<AutogradMetaInterface> make() const = 0;
  // This method is the dumbest method. But I don't have access
  // to Tensor (not TensorImpl) which is undefined in this header.
  virtual const at::Tensor& undefined_tensor() const = 0;
};

C10_API void SetAutogradMetaFactory(AutogradMetaFactory* factory);
C10_API AutogradMetaFactory* GetAutogradMetaFactory();

struct C10_API AutogradMetaFactoryRegisterer {
  explicit AutogradMetaFactoryRegisterer(AutogradMetaFactory* factory) {
    SetAutogradMetaFactory(factory);
  }
};

} // namespace impl

struct C10_API NamedTensorMetaInterface {
  virtual ~NamedTensorMetaInterface() = default;
  virtual std::unique_ptr<NamedTensorMetaInterface> clone() const {
    TORCH_INTERNAL_ASSERT(
        false, "Not implemented: NamedTensorMetaInterface::clone");
  };
  virtual int64_t slow_dim() const {
    TORCH_INTERNAL_ASSERT(
        false, "Not implemented: NamedTensorMetaInterface::slow_dim");
  };
};

// For ease of copy pasting
#if 0
is_contiguous
is_channels_last_contiguous
is_channels_last_3d_contiguous
is_channels_last
is_channels_last_3d
is_non_overlapping_and_dense
#endif

/**
 * This structure is intended to hold additional metadata of the specific
 * device backend.
 **/
struct C10_API BackendMeta : intrusive_ptr_target {
  ~BackendMeta() override = default;
  virtual intrusive_ptr<BackendMeta> clone(
      const intrusive_ptr<BackendMeta>& ptr) const {
    return ptr;
  }
};

struct C10_API ExtraMeta {
  std::unique_ptr<c10::SymbolicShapeMeta> symbolic_shape_meta_ = nullptr;
  std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta_ = nullptr;
  intrusive_ptr<c10::BackendMeta> backend_meta_ = nullptr;
  std::optional<std::string> custom_data_ptr_error_msg_ = std::nullopt;
  std::optional<std::string> custom_storage_error_msg_ = std::nullopt;

  ExtraMeta() = default;
  ExtraMeta(const ExtraMeta& other) {
    if (other.symbolic_shape_meta_) {
      symbolic_shape_meta_ =
          std::make_unique<c10::SymbolicShapeMeta>(*other.symbolic_shape_meta_);
    }
    if (other.named_tensor_meta_) {
      named_tensor_meta_ = other.named_tensor_meta_->clone();
    }
    if (other.backend_meta_) {
      backend_meta_ = other.backend_meta_->clone(other.backend_meta_);
    }
    if (other.custom_data_ptr_error_msg_) {
      custom_data_ptr_error_msg_ = other.custom_data_ptr_error_msg_;
    }
    if (other.custom_storage_error_msg_) {
      custom_storage_error_msg_ = other.custom_storage_error_msg_;
    }
  }
  ExtraMeta& operator=(const ExtraMeta& other) = delete;
  ExtraMeta(ExtraMeta&& other) = delete;
  ExtraMeta& operator=(ExtraMeta&& other) = delete;

  ExtraMeta(
      std::unique_ptr<c10::SymbolicShapeMeta> symbolic_shape_meta,
      std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta,
      intrusive_ptr<c10::BackendMeta> backend_meta,
      std::optional<std::string> custom_data_ptr_error_msg = std::nullopt,
      std::optional<std::string> custom_storage_access_error_msg = std::nullopt)
      : symbolic_shape_meta_(std::move(symbolic_shape_meta)),
        named_tensor_meta_(std::move(named_tensor_meta)),
        backend_meta_(std::move(backend_meta)),
        custom_data_ptr_error_msg_(std::move(custom_data_ptr_error_msg)),
        custom_storage_error_msg_(std::move(custom_storage_access_error_msg)) {}

  std::unique_ptr<ExtraMeta> clone() const {
    return std::make_unique<ExtraMeta>(*this);
  }
};

// NOTE [ Version Counter Sharing ]
//
// Every Tensor has a version counter. Version counters are incremented
// whenever the data or size of a tensor changes through in-place Variable
// operations. Version counters are used to detect modifications to saved
// variables which would result in incorrect gradient calculations. Version
// counters may be shared between Variables:
//
// 1. A view shares the version counter of the base Variable,
// 2. `x.detach()` shares the version counter of `x`,
// 3. Unpacked saved variables share the version counter of the source.
//
// Version counters are not shared in these scenarios:
//
// 1. When we replace a `Variable`'s underlying `Tensor` by calling
//    `set_data(...)`,
// 2. `x.data` does not share the version counter of `x`. (See discussion at
//    https://github.com/pytorch/pytorch/issues/5396)
//
// Question: Why do we put the version counter in TensorImpl instead of
// AutogradMeta?
//
// Answer: After the Variable/Tensor merge, a tensor will not have AutogradMeta
// when its `requires_grad_` is false, but when we use this tensor in the
// forward pass of a function that requires saving this tensor for backward, we
// need to keep track of this tensor's version to make sure it's always valid
// in the autograd graph.
//
// To achieve this goal, we put the version counter in TensorImpl instead of
// AutogradMeta, and have it always be available. This allows us to have the
// optimization of not carrying AutogradMeta when a tensor doesn't require
// gradient.
//
// A hypothetical alternative way to achieve this goal is to initialize
// AutogradMeta and create the version counter for the non-requires-grad tensor
// only when it's saved for backward. However, since saving a tensor for
// backward happens in the forward pass, and our invariant is that the forward
// pass needs to be thread-safe, lazy-initializing AutogradMeta when saving a
// tensor can introduce race conditions when we are running the forward pass in
// multi-thread scenarios, thus making the forward pass not thread-safe anymore,
// which breaks the invariant.
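//
// Illustrative example of the sharing rules above (ATen-level sketch under
// standard autograd semantics; not taken from call sites in this file):
//
//   at::Tensor base = at::zeros({3});
//   at::Tensor view = base.slice(0, 1);   // shares base's version counter
//   view.add_(1);                         // bump is observed through both
//   at::Tensor det = base.detach();       // also shares the counter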
struct C10_API VariableVersion {
 private:
  struct VersionCounter : intrusive_ptr_target {
    VersionCounter(uint32_t version) : version_(version) {}
    std::atomic<uint32_t> version_;
  };
  c10::intrusive_ptr<VersionCounter> version_counter_;

 public:
  // Note [Disabled VariableVersion]
  // VariableVersion has an intrusive_ptr pointing to a VersionCounter struct
  // with an atomic variable. Thus `VariableVersion(/*version=*/0)` is not as
  // cheap as we expected. In some cases constructing a VariableVersion with
  // version 0 is not necessary, so we add a cheap constructor which
  // doesn't allocate the intrusive_ptr.
  // Example use cases are:
  //  - Inference tensors don't track version counters, so they'll just always
  //    have a disabled VariableVersion.
  //  - In the SavedVariable class we override version_counter_ inside its
  //    constructor so that we can use the cheap constructor there.
  enum Disabled { DISABLED };
  // It's okay to return true even for an inference tensor which
  // doesn't have the version counter enabled.
  // We want to be permissive here since in many cases (e.g. make_variable)
  // we can std::move a TensorImpl if there are no other uses, which saves us
  // an additional TensorImpl allocation.
  bool unique() const {
    return version_counter_ ? 1 == version_counter_.use_count() : true;
  }
  // NOTE: As of C++11 and 14, default-constructing a std::atomic variable
  // leaves it in a persistently undefined state. See
  // https://cplusplus.github.io/LWG/issue2334.
  VariableVersion(uint32_t version)
      : version_counter_(c10::make_intrusive<VersionCounter>(version)) {}
  VariableVersion(Disabled = DISABLED) {}

  bool enabled() const {
    return version_counter_;
  }

  // Note [Inplace update inference tensor]
  // 1. Inplace update to an inference tensor is forbidden in normal mode.
  //    For example:
  //      inference_tensor.copy_(normal_tensor_requires_grad)
  //    This inplace update makes inference_tensor have requires_grad=True and
  //    have a grad_fn. This is bad because views of `inference_tensor`
  //    created in InferenceMode won't be able to know the grad_fn since
  //    their ViewMeta were not recorded. To match the NoGradMode behavior
  //    that "inplace update to a view created in NoGradMode raises an error",
  //    we just ban inplace update to inference tensors since we can't tell
  //    if an inference tensor is a view created in InferenceMode.
  //
  //    Note that views of a normal tensor created in InferenceMode have proper
  //    ViewMeta so that they're aware of the grad_fn correctly.
  //
  // 2. Inplace update to an inference tensor in inference mode doesn't bump
  //    the version counter.
  //    * It either doesn't call bump() by skipping the ADInplaceOrView kernel,
  //      - e.g. inference_tensor.add_(1)
  //    * or bump() is a no-op for inference tensors.
  //      - e.g. inference_tensor.add_(normal_tensor)
  void bump() {
    // TODO: Replace the link to the documentation once it's available.
    TORCH_CHECK(
        version_counter_ || InferenceMode::is_enabled(),
        "Inplace update to inference tensor outside InferenceMode is not allowed. "
        "You can make a clone to get a normal tensor before doing inplace update. "
        "See https://github.com/pytorch/rfcs/pull/17 for more details.");
    if (version_counter_) {
      ++version_counter_->version_;
    }
  }

  void set_version(int64_t i) {
    TORCH_CHECK(
        version_counter_,
        "Tried to call torch.autograd._unsafe_set_version() on a tensor "
        "that does not have a version counter. Was it created in inference mode?");
    TORCH_CHECK(i >= 0, "Cannot set a version_counter to a value below 0: ", i);
    version_counter_->version_ = i;
  }

  // Inference tensors don't have a version counter, so it shouldn't be
  // accessed.
  uint32_t current_version() const {
    TORCH_CHECK(
        version_counter_, "Inference tensors do not track version counter.");
    return version_counter_->version_;
  }
};
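
// Illustrative sketch of the VariableVersion API above (assumed usage, not
// taken from call sites in this header):
//
//   c10::VariableVersion vv(/*version=*/0);   // allocates a VersionCounter
//   vv.bump();                                // an in-place op was observed
//   uint32_t v = vv.current_version();        // v == 1
//
//   c10::VariableVersion disabled;            // cheap, no allocation
//   bool on = disabled.enabled();             // on == false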

// Forward declaration of TensorImpl needed for forward declaration of
// C10_TensorImpl_Size_Check_Dummy_Class
struct C10_API TensorImpl;

/**
 * NOTE: Some TensorImpl methods are small and not overridden in the
 * PyTorch codebase itself, but may theoretically need to be
 * overridden by third-party TensorImpl subclasses. This macro allows
 * users that need maximum performance and don't need these extension
 * points to disable them with a build-time flag. (In particular,
 * XLA's XLATensorImpl currently overrides these methods, so we can't
 * enable this flag by default.)
 */
#ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
#define TENSORIMPL_MAYBE_VIRTUAL
#else
#define TENSORIMPL_MAYBE_VIRTUAL virtual
#endif

/**
 * The low-level representation of a tensor, which contains a pointer
 * to a storage (which contains the actual data) and metadata (e.g., sizes and
 * strides) describing this particular view of the data as a tensor.
 *
 * Some basic characteristics about our in-memory representation of
 * tensors:
 *
 *  - It contains a pointer to a storage struct (Storage/StorageImpl)
 *    which contains the pointer to the actual data and records the
 *    data type and device of the view. This allows multiple tensors
 *    to alias the same underlying data, which allows us to efficiently
 *    implement differing *views* on a tensor.
 *
 *  - The tensor struct itself records view-specific metadata about
 *    the tensor, e.g., sizes, strides and offset into storage.
 *    Each view of a storage can have a different size or offset.
 *
 *  - This class is intrusively refcounted. It is refcounted so that
 *    we can support prompt deallocation of large tensors; it is
 *    intrusively refcounted so that we can still perform reference
 *    counted operations on raw pointers, which is often more convenient
 *    when passing tensors across language boundaries.
 *
 *  - For backwards-compatibility reasons, a tensor may be in an
 *    uninitialized state. A tensor may be uninitialized in the following
 *    two ways:
 *
 *      - A tensor may be DTYPE UNINITIALIZED. A tensor of this
 *        form has an uninitialized dtype. This situation most
 *        frequently arises when a user writes Tensor x(CPU). The dtype
 *        is subsequently initialized when mutable_data<T>() is
 *        invoked for the first time.
 *
 *      - A tensor may be STORAGE UNINITIALIZED. A tensor of this form
 *        has non-zero size, but has a storage with a null data pointer.
 *        This situation most frequently arises when a user calls
 *        Resize() or FreeMemory(). This is because Caffe2 historically
 *        does lazy allocation: allocation of data doesn't occur until
 *        mutable_data<T>() is invoked. A tensor with zero size is
 *        always storage initialized, because no allocation is necessary
 *        in this case.
 *
 *    All combinations of these two uninitialized states are possible.
 *    Consider the following transcript in idiomatic Caffe2 API:
 *
 *      Tensor x(CPU); // x is storage-initialized, dtype-UNINITIALIZED
 *      x.Resize(4); // x is storage-UNINITIALIZED, dtype-UNINITIALIZED
 *      x.mutable_data<float>(); // x is storage-initialized, dtype-initialized
 *      x.FreeMemory(); // x is storage-UNINITIALIZED, dtype-initialized.
 *
 *    All other fields on tensor are always initialized. In particular,
 *    size is always valid. (Historically, a tensor declared as Tensor x(CPU)
 *    also had uninitialized size, encoded as numel == -1, but we have now
 *    decided to default to zero size, resulting in numel == 0).
 *
 *    Uninitialized storages MUST be uniquely owned, to keep our model
 *    simple. Thus, we will reject operations which could cause an
 *    uninitialized storage to become shared (or a shared storage to
 *    become uninitialized, e.g., from FreeMemory).
 *
 *    In practice, tensors which are storage-UNINITIALIZED and
 *    dtype-UNINITIALIZED are *extremely* ephemeral: essentially,
 *    after you do a Resize(), you basically always call mutable_data()
 *    immediately afterwards. Most functions are not designed to
 *    work if given a storage-UNINITIALIZED, dtype-UNINITIALIZED tensor.
 *
 *    We intend to eliminate all uninitialized states, so that every
 *    tensor is fully initialized in all fields. Please do not write new code
 *    that depends on these uninitialized states.
 */
struct C10_API TensorImpl : public c10::intrusive_ptr_target {
  TensorImpl() = delete;
  ~TensorImpl() override;
  // Note [Enum ImplType]
  // This enum is temporary. In the followup refactor we should
  // think about how to specialize TensorImpl creation for view
  // tensors. Currently we only special case its key_set_ but
  // there's also potential to share version_counter_ directly
  // without creating first and then override in as_view.
  enum ImplType { VIEW };

  /**
   * Construct a 1-dim 0-size tensor backed by the given storage.
   */
  TensorImpl(
      Storage&& storage,
      DispatchKeySet,
      const caffe2::TypeMeta data_type);

  // See Note [Enum ImplType]
  TensorImpl(
      ImplType,
      Storage&& storage,
      DispatchKeySet,
      const caffe2::TypeMeta data_type);

  /**
   * Construct a 1-dim 0 size tensor that doesn't have a storage.
   */
  TensorImpl(
      DispatchKeySet,
      const caffe2::TypeMeta data_type,
      std::optional<c10::Device> device_opt);

  // Legacy constructors so I don't have to go update call sites.
  // TODO: When Variable is added, delete these constructors
  TensorImpl(
      Storage&& storage,
      DispatchKey dispatch_key,
      const caffe2::TypeMeta data_type)
      : TensorImpl(
            std::move(storage),
            DispatchKeySet(dispatch_key),
            data_type) {}
  TensorImpl(
      DispatchKey dispatch_key,
      const caffe2::TypeMeta data_type,
      std::optional<c10::Device> device_opt)
      : TensorImpl(DispatchKeySet(dispatch_key), data_type, device_opt) {}

 private:
  // This constructor is private, because the data_type is redundant with
  // storage. Still, we pass it in separately because it's easier to write
  // the initializer list if we're not worried about storage being moved out
  // from under us.
  TensorImpl(
      Storage&& storage,
      DispatchKeySet,
      const caffe2::TypeMeta data_type,
      std::optional<c10::Device>);

 public:
  TensorImpl(const TensorImpl&) = delete;
  TensorImpl& operator=(const TensorImpl&) = delete;
  TensorImpl(TensorImpl&&) = delete;
  TensorImpl& operator=(TensorImpl&&) = delete;

  /**
   * Release (decref) storage, and any other external allocations. This
   * override is for `intrusive_ptr_target` and is used to implement weak
   * tensors.
   */
  void release_resources() override;

 public:
  /**
   * Return the DispatchKeySet corresponding to this Tensor, specifying
   * all of the DispatchKeys that this Tensor identifies as. This is the
   * information used to dispatch operations on this tensor.
   */
  DispatchKeySet key_set() const {
    return key_set_;
  }

 private:
  [[noreturn]] void throw_cannot_call_with_symbolic(const char* meth) const;

  // NOTE: The general recipe for customizable methods is that the fastpath
  // function (e.g., sizes()) does an unlikely policy test, and if it doesn't
  // trigger, it does the fast path implementation with no checks, going
  // directly to on-TensorImpl fields. In particular, you never need to
  // check ExtraMeta if the policy doesn't trigger, as non-trivial ExtraMeta
  // implies the policy will always match.
  //
  // The default implementations of methods are "safe": they do extra tests
  // to make sure the internal state is consistent no matter if you are
  // doing symbolic shapes or not. If you don't want the tests, directly
  // override the custom method (e.g., custom_sizes()) to do your preferred
  // behavior.

 public:
  /**
   * Return a reference to the sizes of this tensor. This reference remains
   * valid as long as the tensor is live and not resized.
   */
  IntArrayRef sizes() const {
    if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
      return sizes_custom();
    }
    return sizes_and_strides_.sizes_arrayref();
  }

  SymIntArrayRef sym_sizes() const {
    if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
      return sym_sizes_custom();
    }
    // Sizes guaranteed to be non-negative, so unchecked cast is OK
    return c10::fromIntArrayRefKnownNonNegative(
        sizes_and_strides_.sizes_arrayref());
  }

  IntArrayRef sizes_default() const {
    if (C10_UNLIKELY(has_symbolic_sizes_strides_)) {
      throw_cannot_call_with_symbolic("sizes");
    }
    return sizes_and_strides_.sizes_arrayref();
  }

  SymIntArrayRef sym_sizes_default() const {
    if (has_symbolic_sizes_strides_) {
      return symbolic_shape_meta().sizes_;
    } else {
      // Sizes guaranteed to be non-negative, so unchecked cast is OK
      return c10::fromIntArrayRefKnownNonNegative(sizes_default());
    }
  }

  // From https://stackoverflow.com/a/3057522/23845
  // TODO: does C++14 have a stdlib template for this?
  template <typename T>
  struct identity {
    typedef T type;
  };

  template <typename T>
  ArrayRef<T> generic_sizes() {
    return _generic_sizes(identity<T>());
  }

  ArrayRef<int64_t> _generic_sizes(identity<int64_t>) {
    return sizes();
  }
  ArrayRef<c10::SymInt> _generic_sizes(identity<c10::SymInt>) {
    return sym_sizes();
  }

  template <typename T>
  ArrayRef<T> generic_strides() {
    return _generic_strides(identity<T>());
  }

  ArrayRef<int64_t> _generic_strides(identity<int64_t>) {
    return strides();
  }
  ArrayRef<c10::SymInt> _generic_strides(identity<c10::SymInt>) {
    return sym_strides();
  }

  template <typename T>
  T generic_storage_offset() {
    return _generic_storage_offset(identity<T>());
  }

  int64_t _generic_storage_offset(identity<int64_t>) {
    return storage_offset();
  }
  c10::SymInt _generic_storage_offset(identity<c10::SymInt>) {
    return sym_storage_offset();
  }
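
  // Illustrative sketch (assumed usage, not from a real call site): the
  // generic_* helpers let templated code handle concrete and symbolic shapes
  // uniformly, e.g.:
  //
  //   template <typename T>
  //   T flat_extent(c10::TensorImpl& impl) {
  //     T n = 1;
  //     for (const auto& s : impl.generic_sizes<T>()) {
  //       n = n * s;
  //     }
  //     return n; // int64_t when T = int64_t, c10::SymInt when T = c10::SymInt
  //   }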

  /**
   * The number of elements in a tensor.
   *
   * WARNING: Previously, if you were using the Caffe2 API, you could
   * test numel() == -1 to see if a tensor was uninitialized. This
   * is no longer true; numel always accurately reports the product
   * of sizes of a tensor.
   */
  int64_t numel() const {
    if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
      return numel_custom();
    }
    return numel_;
  }

  c10::SymInt sym_numel() const {
    if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
      return sym_numel_custom();
    }
    return c10::SymInt(SymInt::UNCHECKED, numel_);
  }

  int64_t numel_default() const {
    if (C10_UNLIKELY(has_symbolic_sizes_strides_)) {
      throw_cannot_call_with_symbolic("numel");
    }
    return numel_;
  }

  c10::SymInt sym_numel_default() const {
    if (has_symbolic_sizes_strides_) {
      return symbolic_shape_meta().numel();
    } else {
      return c10::SymInt(SymInt::UNCHECKED, numel_);
    }
  }

  /**
   * Return the number of dimensions of this tensor. Note that 0-dimension
   * represents a Tensor that is a Scalar, e.g., one that has a single element.
   */
  int64_t dim() const {
    if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
      return dim_custom();
    }
    return static_cast<int64_t>(sizes_and_strides_.size());
  }

  int64_t dim_default() const {
    if (has_symbolic_sizes_strides_) {
      return static_cast<int64_t>(symbolic_shape_meta().sizes_.size());
    } else {
      return static_cast<int64_t>(sizes_and_strides_.size());
    }
  }

  /**
   * Return the offset in number of elements into the storage that this
   * tensor points to. Most tensors have storage_offset() == 0, but,
   * for example, an index into a tensor will have a non-zero storage_offset().
   *
   * WARNING: This is NOT computed in bytes.
   */
  int64_t storage_offset() const {
    // TODO: maybe this should be toggled by strides
    if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
      return storage_offset_custom();
    }
    return storage_offset_;
  }

  c10::SymInt sym_storage_offset() const {
    if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
      return sym_storage_offset_custom();
    }
    return c10::SymInt(SymInt::UNCHECKED, storage_offset_);
  }

  int64_t storage_offset_default() const {
    if (C10_UNLIKELY(has_symbolic_sizes_strides_)) {
      throw_cannot_call_with_symbolic("storage_offset");
    }
    return storage_offset_;
  }

  c10::SymInt sym_storage_offset_default() const {
    if (has_symbolic_sizes_strides_) {
      return symbolic_shape_meta().storage_offset_;
    } else {
      return c10::SymInt(SymInt::UNCHECKED, storage_offset_);
    }
  }

  /**
   * Return a reference to the strides of this tensor. This reference remains
   * valid as long as the tensor is live and not restrided.
   */
  IntArrayRef strides() const {
    if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
      return strides_custom();
    }
    return sizes_and_strides_.strides_arrayref();
  }

  c10::SymIntArrayRef sym_strides() const {
    if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
      return sym_strides_custom();
    }
    return c10::fromIntArrayRefKnownNonNegative(strides_default());
  }

  IntArrayRef strides_default() const {
    if (C10_UNLIKELY(has_symbolic_sizes_strides_)) {
      throw_cannot_call_with_symbolic("strides");
    }
    return sizes_and_strides_.strides_arrayref();
  }

  c10::SymIntArrayRef sym_strides_default() const {
    if (has_symbolic_sizes_strides_) {
      return symbolic_shape_meta().strides_;
    } else {
      return c10::fromIntArrayRefKnownNonNegative(strides_default());
    }
  }

  /**
   * Whether or not a tensor is laid out in contiguous memory.
   *
   * Tensors with non-trivial strides are not contiguous. See
   * compute_contiguous() for the exact definition of whether or not
   * a tensor is contiguous or not.
   */
  bool is_contiguous(
      at::MemoryFormat memory_format = at::MemoryFormat::Contiguous) const {
    if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
      return is_contiguous_custom(memory_format);
    }
    return is_contiguous_default(memory_format);
  }

  // These are factored into separate functions in case subclasses
  // want to use them
  bool is_contiguous_default(at::MemoryFormat memory_format) const {
    if (has_symbolic_sizes_strides_) {
      if (memory_format == at::MemoryFormat::ChannelsLast) {
        return symbolic_shape_meta().is_channels_last_contiguous().guard_bool(
            __FILE__, __LINE__);
      } else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
        return symbolic_shape_meta()
            .is_channels_last_3d_contiguous()
            .guard_bool(__FILE__, __LINE__);
      }
      return symbolic_shape_meta().is_contiguous().guard_bool(
          __FILE__, __LINE__);
    }

    if (memory_format == at::MemoryFormat::ChannelsLast) {
      return is_channels_last_contiguous_;
    } else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
      return is_channels_last_3d_contiguous_;
    }
    return is_contiguous_;
  }

  bool is_strides_like_default(at::MemoryFormat memory_format) const {
    if (has_symbolic_sizes_strides_) {
      if (memory_format == at::MemoryFormat::ChannelsLast) {
        return symbolic_shape_meta().is_channels_last().guard_bool(
            __FILE__, __LINE__);
      } else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
        return symbolic_shape_meta().is_channels_last_3d().guard_bool(
            __FILE__, __LINE__);
      } else {
        return false;
      }
    }

    if (memory_format == at::MemoryFormat::ChannelsLast) {
      return is_channels_last_;
    } else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
      return is_channels_last_3d_;
    } else {
      return false;
    }
  }

  bool is_non_overlapping_and_dense_default() const {
    if (has_symbolic_sizes_strides_) {
      return symbolic_shape_meta().is_non_overlapping_and_dense().guard_bool(
          __FILE__, __LINE__);
    } else {
      return is_non_overlapping_and_dense_;
    }
  }

  // NB: these dim accessor functions don't have _default(), as you can use
  // sizes_default/strides_default
  /**
   * Return the size of a tensor at some dimension, wrapping the dimension if
   * necessary.
   *
   * NOTE: if you know wrapping is unnecessary, do sizes()[d] instead; it will
   * be faster
   */
  int64_t size(int64_t d) const {
    if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
      return size_custom(d);
    }
    d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
    return sizes_and_strides_.size_at_unchecked(d);
  }

  c10::SymInt sym_size(int64_t d) const {
    if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomSizes))) {
      return sym_size_custom(d);
    }
    d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
    const auto sizes = this->sym_sizes();
    return sizes[d];
  }

  /**
   * Return the stride of a tensor at some dimension, wrapping the dimension
   * if necessary.
   *
   * NOTE: if you know wrapping is unnecessary, do strides()[d] instead; it
   * will be faster
   */
  int64_t stride(int64_t d) const {
    d = maybe_wrap_dim(d, dim(), false);
    if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
      // TODO: provide stride_custom, symmetrically with size_custom.
      // There is presently no user for it; only NestedTensor is using
      // size_custom overrideability
      return strides_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds)
    }
    // Intentionally don't call default, which also handles symbolic
    return sizes_and_strides_.stride_at_unchecked(d);
  }

  enum class SizesStridesPolicy : uint8_t {
    // Default behavior, e.g., dense tensor.
    //
    // Can override: nothing
    Default = 0,
    // Customizable strides behavior, e.g., sparse tensor,
    // mkldnn tensor.
    //
    // Can override: strides(), is_contiguous()
    CustomStrides = 1,
    // Customizable sizes behavior, e.g., nested tensor
    //
    // Can override: strides(), is_contiguous(), sizes(), dim(), numel()
    CustomSizes = 2
  };
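
  // Illustrative sketch (hypothetical subclass; the policy-raising setter is
  // assumed to live elsewhere in TensorImpl and is not shown in this excerpt):
  // a subclass that owns its own size bookkeeping would raise the policy to
  // CustomSizes in its constructor and override the matching *_custom hooks:
  //
  //   struct MyTensorImpl : public c10::TensorImpl {
  //     // constructor raises the policy to SizesStridesPolicy::CustomSizes
  //     c10::IntArrayRef sizes_custom() const override { return my_sizes_; }
  //     int64_t dim_custom() const override {
  //       return static_cast<int64_t>(my_sizes_.size());
  //     }
  //     std::vector<int64_t> my_sizes_;
  //   };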

 protected:
  inline bool matches_policy(SizesStridesPolicy policy) const {
    return sizes_strides_policy_ >= static_cast<uint8_t>(policy);
  }

  inline bool matches_custom(SizesStridesPolicy policy) const {
    return custom_sizes_strides_ >= static_cast<uint8_t>(policy);
  }

  inline bool matches_python_custom(SizesStridesPolicy policy) const {
    auto r = python_custom_sizes_strides_ >= static_cast<uint8_t>(policy);
    if (r) {
      TORCH_INTERNAL_ASSERT(is_python_dispatch())
    }
    return r;
  }

  /**
   * Customization points for the functions above. sizes_strides_policy_
   * must be set to enable these.
   *
   * NB: dim is overrideable separately from sizes because it is possible
   * for a tensor to have rank, but not well defined sizes.
   */
  // sizes_strides_policy_ >= CustomStrides
  virtual bool is_contiguous_custom(at::MemoryFormat memory_format) const;
  virtual bool is_strides_like_custom(at::MemoryFormat memory_format) const;
  virtual bool is_non_overlapping_and_dense_custom() const;
  // sizes_strides_policy_ >= CustomSizes
  // Currently this method only exists to be overwritten by subclasses such as
  // NestedTensorImpl.
  virtual int64_t size_custom(int64_t d) const {
    // TODO: We could add support to Python dispatch here.
    // TODO: We could call into aten::size.int instead of
    // sizes_custom()[d] and enable use of the dispatcher.
    d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
    return sizes_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds)
  }

  virtual c10::SymInt sym_size_custom(int64_t d) const {
    // TODO: We could add support to Python dispatch here.
    // TODO: We could call into aten::size.int instead of
    // sym_sizes_custom()[d] and enable use of the dispatcher.
    d = maybe_wrap_dim(d, dim(), /*wrap_scalar=*/false);
    return sym_sizes_custom()[d]; // unchecked (maybe_wrap_dim enforces bounds)
  }

  virtual IntArrayRef sizes_custom() const;
  virtual IntArrayRef strides_custom() const;
  virtual int64_t numel_custom() const;
  virtual int64_t storage_offset_custom() const;
  virtual int64_t dim_custom() const;
  virtual Device device_custom() const;
  virtual Layout layout_custom() const;

  virtual c10::SymIntArrayRef sym_sizes_custom() const;
  virtual c10::SymIntArrayRef sym_strides_custom() const;
  virtual c10::SymInt sym_numel_custom() const;
  virtual c10::SymInt sym_storage_offset_custom() const;

 public:
  /**
   * True if this tensor has storage. See storage() for details.
   */
#ifdef DEBUG
  // Allow subclasses to check that their storage_ is never getting set in
  // debug builds.
  virtual
#else
  TENSORIMPL_MAYBE_VIRTUAL
#endif
      bool
      has_storage() const
  // NOTE: we devirtualize this because it arguably shouldn't be an
  // error just to ask subclasses if they have storage.
  // This used to throw for most subclasses, but OpaqueTensorImpl
  // wanted it to successfully return false, so we went ahead and made
  // it a non-error.
#ifdef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
  {
    return storage_;
  }
#else
      ;
#endif

  /**
   * Return the underlying storage of a Tensor. Multiple tensors may share
   * a single storage. A Storage is an impoverished, Tensor-like class
   * which supports far fewer operations than Tensor.
   *
   * Avoid using this method if possible; try to use only Tensor APIs to
   * perform operations.
   */
  TENSORIMPL_MAYBE_VIRTUAL const Storage& storage() const {
    if (C10_UNLIKELY(storage_access_should_throw_)) {
      throw_storage_access_error();
    }
    return storage_;
  }

  /**
   * Return the underlying storage, unsafely assuming this is a basic strided
   * tensor. In cases where `storage` access would throw, this returns a
   * default-constructed Storage.
   */
  inline const Storage& unsafe_storage() const {
    return storage_;
  }

  bool unique_version() const {
    return version_counter_.unique();
  }

 protected:
  virtual Layout layout_impl() const {
    TORCH_CHECK(
        false, "layout_impl is only implemented for TensorImpl subclasses.");
  }

 public:
  // Whether a tensor is sparse COO or not.
  bool is_sparse() const {
    // NB: This method is not virtual and avoids dispatches for performance
    // reasons.
    return key_set_.has_all(c10::sparse_ks);
  }

  // Whether a tensor is sparse CSR or not.
  bool is_sparse_csr() const {
    return layout() == kSparseCsr;
  }

  // Whether a tensor is sparse CSR/CSC/BSR/BSC or not.
  bool is_sparse_compressed() const {
    return key_set_.has_all(c10::sparse_csr_ks);
  }

  bool is_quantized() const {
    // NB: This method is not virtual and avoids dispatches for performance
    // reasons.
    constexpr auto quantized_ks = DispatchKeySet(DispatchKey::Quantized);
    return key_set_.has_all(quantized_ks);
  }

  bool is_meta() const {
    // NB: This method is not virtual and avoids dispatches for performance
    // reasons.
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom().is_meta();
    }
    return device_opt_.has_value() && device_opt_->type() == kMeta;
  }

  bool is_cpu() const {
    // NB: This method is not virtual and avoids dispatches for performance
    // reasons.
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom().is_cpu();
    }
    // Note: we cannot rely on dispatch keys to determine the device type
    // of a tensor, because "wrapper" tensors (like FunctionalTensorWrapper)
    // don't include backend dispatch keys.
    return device_opt_.has_value() && device_opt_->type() == kCPU;
  }

  bool is_cuda() const {
    // NB: This method is not virtual and avoids dispatches for performance
    // reasons.
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom().is_cuda();
    }
    return device_opt_.has_value() && device_opt_->type() == kCUDA;
  }

  bool is_xpu() const {
    // NB: This method is not virtual and avoids dispatches for performance
    // reasons.
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom().is_xpu();
    }
    return device_opt_.has_value() && device_opt_->type() == kXPU;
  }

  bool is_ipu() const {
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom().is_ipu();
    }
    return device_opt_.has_value() && device_opt_->type() == kIPU;
  }

  bool is_xla() const {
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom().is_xla();
    }
    return device_opt_.has_value() && device_opt_->type() == kXLA;
  }

  bool is_mtia() const {
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom().is_mtia();
    }
    return device_opt_.has_value() && device_opt_->type() == kMTIA;
  }

  bool is_hpu() const {
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom().is_hpu();
    }
    return device_opt_.has_value() && device_opt_->type() == kHPU;
  }

  bool is_lazy() const {
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom().is_lazy();
    }
    return device_opt_.has_value() && device_opt_->type() == kLazy;
  }

  bool is_hip() const {
    // NB: This method is not virtual and avoids dispatches for performance
    // reasons.
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom().is_hip();
    }
    return device_opt_.has_value() && device_opt_->type() == kHIP;
  }

  bool is_ve() const {
    // NB: This method is not virtual and avoids dispatches for performance
    // reasons.
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom().is_ve();
    }
    return device_opt_.has_value() && device_opt_->type() == kVE;
  }

  bool is_privateuseone() const {
    // NB: This method is not virtual and avoids dispatches for performance
    // reasons.
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom().is_privateuseone();
    }
    return device_opt_.has_value() && device_opt_->type() == kPrivateUse1;
  }

  bool is_mkldnn() const {
    return key_set_.has_all(c10::mkldnn_ks);
  }

  bool is_vulkan() const {
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom().is_vulkan();
    }
    return device_opt_.has_value() && device_opt_->type() == kVulkan;
  }

  bool is_metal() const {
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom().is_metal();
    }
    return device_opt_.has_value() && device_opt_->type() == kMetal;
  }

  bool is_mps() const {
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom().is_mps();
    }
    return device_opt_.has_value() && device_opt_->type() == kMPS;
  }

  bool is_maia() const {
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom().is_maia();
    }
    return device_opt_.has_value() && device_opt_->type() == kMAIA;
  }

  bool is_nested() const {
    return key_set_.has(DispatchKey::NestedTensor);
  }

  // TODO: remove this once we don't automatically enable Autograd dispatch
  // keys in the TensorImpl constructor.
  // DON'T USE THIS API!! It's only created for testing purpose in
  // file aten/src/ATen/core/boxing/impl/test_helpers.h
  void remove_autograd_key() {
    key_set_ = key_set_ - autograd_dispatch_keyset;
  }

  // Inference tensors don't have the autograd or ADInplaceOrView key.
  // Invariant:
  //   Inference tensor has version_counter_.enabled() == false
  bool is_inference() {
    bool no_ADInplaceOrView = !key_set_.has_any(c10::inplace_or_view_ks);
    bool no_Autograd = !key_set_.has_any(c10::autograd_dispatch_keyset);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        no_ADInplaceOrView == no_Autograd,
        "ADInplaceOrView and Autograd keys must be on/off at the same time.");
    return no_ADInplaceOrView && no_Autograd;
  }

  DeviceIndex get_device() const {
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom().index();
    }
    return device_default().index();
  }

  Device device() const {
    if (C10_UNLIKELY(device_policy_)) {
      return device_custom();
    }
    return device_default();
  }

 protected:
  c10::Device device_default() const {
    TORCH_CHECK(device_opt_.has_value(), "tensor does not have a device");
    // See NOTE [std::optional operator usage in CUDA]
    return *device_opt_;
  }

 public:
  Layout layout() const {
    if (C10_UNLIKELY(layout_policy_)) {
      return layout_custom();
    }

    // NB: This method is not virtual and avoids dispatches for perf.
    // strided is also the most common layout type, so we check for
    // the strided case first.
    // This keyset must also be kept in sync with the logic in
    // is_sparse() / is_sparse_csr() / is_mkldnn()
    constexpr auto sparse_and_sparsecsr_and_mkldnn_ks =
        c10::sparse_ks | c10::sparse_csr_ks | c10::mkldnn_ks;
    if (!key_set_.has_any(sparse_and_sparsecsr_and_mkldnn_ks)) {
      return kStrided;
    } else if (is_sparse()) {
      return kSparse;
    } else if (is_sparse_compressed()) {
      // Typically, the tensor dispatch keys define the tensor layout
      // uniquely. This allows using a non-virtual layout method for
      // better performance. However, when a tensor's layout depends,
      // say, on tensor attributes, one must use this execution path
      // where the corresponding tensor impl class overwrites the virtual
      // layout_impl() method.
      //
      // TODO: implement layout() as a native function/method so that
      // __torch_dispatch__ users will be able to redefine the
      // layout() method.
      return layout_impl();
    } else {
      TORCH_INTERNAL_ASSERT(
          is_mkldnn(), "There is an error in the layout calculation logic.");
      return kMkldnn;
    }
  }

  /**
   * True if a tensor was auto-wrapped from a C++ or Python number.
   * For example, when you write 't + 2', 2 is auto-wrapped into a Tensor
   * with `is_wrapped_number_` set to true.
   *
   * Wrapped numbers do not participate in the result type computation for
   * mixed-type operations if there are any Tensors that are not wrapped
   * numbers. This is useful, because we want 't + 2' to work with
   * any type of tensor, not just LongTensor (which is what integers
   * in Python represent).
   *
   * Otherwise, they behave like their non-wrapped equivalents.
   * See [Result type computation] in TensorIterator.h.
   *
   * Why did we opt for wrapped numbers, as opposed to just having
   * an extra function add(Tensor, Scalar)? This helps greatly reduce
   * the amount of code we have to write for add, when actually
   * a Tensor-Scalar addition is really just a Tensor-Tensor
   * addition when the RHS is 0-dim (except for promotion behavior.)
   */
  bool is_wrapped_number() const {
    return is_wrapped_number_;
  }
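
  // Illustrative example of the wrapped-number rule above (ATen-level sketch
  // under standard type promotion; not taken from this header):
  //
  //   at::Tensor t = at::ones({3}, at::kByte);
  //   (t + 2).scalar_type();                        // kByte: 2 is a wrapped
  //                                                 // number, no promotion
  //   (t + at::ones({3}, at::kLong)).scalar_type(); // kLong: a real tensor
  //                                                 // operand does promote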

  /**
   * Set whether or not a tensor was auto-wrapped from a C++ or Python
   * number. You probably don't want to call this, unless you are
   * writing binding code.
   */
  void set_wrapped_number(bool value) {
    TORCH_INTERNAL_ASSERT(dim() == 0);
    is_wrapped_number_ = value;
  }

  /**
   * Returns true if Tensor supports as_strided and as_strided_backward.
   * This is used in autograd to perform an inplace update on view Tensors.
   * See Note [View + Inplace update for base tensor] and
   * [View + Inplace update for view tensor] for details.
   * Note this method only returns true for the XLA backend, where it
   * simulates strided Tensors to support most view ops, but it cannot
   * fully support the general `as_strided` case.
   * It can be expanded as needed in the future, e.g. sparse Tensor.
   */
  inline bool support_as_strided() const {
    if (is_nested()) {
      return false;
    }
    if (key_set_.has(DispatchKey::Functionalize)) {
      return false;
    }
    return device().supports_as_strided();
  }

  // ~~~~~ Autograd API ~~~~~
  // Some methods below are defined in TensorImpl.cpp because Tensor is an
  // incomplete type.

  /**
   * Set whether or not a tensor requires gradient.
   */
  void set_requires_grad(bool requires_grad);

  /**
   * True if a tensor requires gradient. Tensors which require gradient
   * have history tracked for any operations performed on them, so that
   * we can automatically differentiate back to them. A tensor that
   * requires gradient and has no history is a "leaf" tensor, which we
   * accumulate gradients into.
   */
  bool requires_grad() const;

  /**
   * Return a mutable reference to the gradient. This is conventionally
   * used as `t.grad() = x` to set a gradient to a completely new tensor.
   */
  at::Tensor& mutable_grad();

  /**
   * Return the accumulated gradient of a tensor. This gradient is written
   * into when performing backwards, when this tensor is a leaf tensor.
   */
  const at::Tensor& grad() const;

  /**
   * Whether or not the imaginary part of the tensor should be negated
   */
  inline bool is_conj() const {
    constexpr auto conjugate_ks = DispatchKeySet(DispatchKey::Conjugate);
    return key_set_.has_all(conjugate_ks);
  }

  /**
   * Set whether or not to take the conjugate of the tensor (flip the imaginary
   * bit).
   */
  void _set_conj(bool value) {
    if (value) {
      key_set_ = key_set_.add(DispatchKey::Conjugate);
      TORCH_INTERNAL_ASSERT(isComplexType(typeMetaToScalarType(dtype())));
    } else {
      key_set_ = key_set_.remove(DispatchKey::Conjugate);
    }
  }

  /**
   * XXX: do not use, private api!
   * Update the backend component related keys to the backend component
   * corresponding to this device.
   */
  void _change_backend_component_keys(c10::Device device);

  /**
   * Whether or not the tensor is a zerotensor
   */
  inline bool _is_zerotensor() const {
    constexpr auto zerotensor_ks = DispatchKeySet(DispatchKey::ZeroTensor);
    return key_set_.has_all(zerotensor_ks);
  }

  /**
   * Set whether or not the tensor is a zero tensor
   */
  void _set_zero(bool value) {
    if (value) {
      TORCH_INTERNAL_ASSERT(
          false,
          "Please call `torch._efficientzerotensor` if you want to create a tensor with no storage.");
    } else {
      key_set_ = key_set_.remove(DispatchKey::ZeroTensor);
    }
  }

  /**
   * Whether or not the tensor should be negated
   */
  inline bool is_neg() const {
    constexpr auto negative_ks = DispatchKeySet(DispatchKey::Negative);
    return key_set_.has_all(negative_ks);
  }

  /**
   * Set whether or not the tensor should be negated (flip the negative bit).
   */
  void _set_neg(bool value) {
    if (value) {
      key_set_ = key_set_.add(DispatchKey::Negative);
    } else {
      key_set_ = key_set_.remove(DispatchKey::Negative);
    }
  }

  /**
   * Return the accumulated gradient of a tensor. This gradient is computed
   * using forward mode AD.
   *
   * This is an internal API that should never be used by end users.
   *
   * The API is as follows:
   * - "level" allows specifying the level of forward AD nesting for which the
   *   gradient should be returned. Note that since levels are not fully
   *   supported yet, this argument should be 0. See documentation for
   *   torch::autograd::enter_dual_level for more details about forward AD
   *   nesting.
   * - "self" should represent the Tensor whose forward grad is accessed. It
   *   is required when dealing with views.
   */
  const at::Tensor& _fw_grad(uint64_t level, const at::TensorBase& self) const;

  /**
   * Sets the forward gradient for this Tensor.
   * The given Tensor might not be used directly and its content will be
   * copied.
   *
   * This is an internal API that should never be used by end users.
   *
   * The API is as follows:
   * - "new_grad" is a Tensor containing the new value of the gradient that
   *   should be set
   * - "self" should represent the Tensor whose forward grad is accessed. It
   *   is required when dealing with views.
   * - "level" allows specifying the level of forward AD nesting for which the
   *   gradient should be set. Note that since levels are not fully supported
   *   yet, this argument should be 0. See documentation for
   *   torch::autograd::enter_dual_level for more details about forward AD
   *   nesting.
   * - "is_inplace_op" is a boolean flag that tells if this gradient was
   *   generated by an inplace operation or an out of place one. This allows
   *   better error checking.
   */
  void _set_fw_grad(
      const at::TensorBase& new_grad,
      const at::TensorBase& self,
      uint64_t level,
      bool is_inplace_op);
1494
1495 /**
1496 * Return a typed data pointer to the actual data which this tensor refers to.
1497 * This checks that the requested type (from the template parameter) matches
1498 * the internal type of the tensor.
1499 *
1500 * It is invalid to call data() on a dtype-uninitialized tensor, even if
1501 * the size is 0.
1502 *
1503 * WARNING: If a tensor is not contiguous, you MUST use strides when
1504 * performing index calculations to determine the location of elements in
1505 * the tensor. We recommend using 'TensorAccessor' to handle this computation
1506 * for you; this class is available from 'Tensor'.
1507 */
1508 template <typename T>
1509   const T* data_dtype_initialized() const {
1510 return data_dtype_initialized_impl<const T>(
1511 [this] { return static_cast<const T*>(storage_.data()); });
1512 }
1513
1514 /**
1515 * Return a mutable typed data pointer to the actual data which this
1516 * tensor refers to. This checks that the requested type (from the
1517 * template parameter) matches the internal type of the tensor.
1518 *
1519 * It is invalid to call data() on a dtype-uninitialized tensor, even if
1520 * the size is 0.
1521 *
1522 * WARNING: If a tensor is not contiguous, you MUST use strides when
1523 * performing index calculations to determine the location of elements in
1524 * the tensor. We recommend using 'TensorAccessor' to handle this computation
1525 * for you; this class is available from 'Tensor'.
1526 */
1527 template <typename T>
1528   T* mutable_data_dtype_initialized() {
1529 return data_dtype_initialized_impl<T>(
1530 [this] { return static_cast<T*>(storage_.mutable_data()); });
1531 }
1532
1533 private:
1534 // Shared implementation of data_dtype_initialized() and
1535 // mutable_data_dtype_initialized().
1536 template <typename T, typename Func>
1537   T* data_dtype_initialized_impl(const Func& get_data) const {
1538 TORCH_CHECK(
1539 data_type_.Match<std::remove_const_t<T>>(),
1540 "Tensor type mismatch, caller expects elements to be ",
1541 caffe2::TypeMeta::TypeName<std::remove_const_t<T>>(),
1542 ", while tensor contains ",
1543 data_type_.name(),
1544 ". ");
1545 return data_ptr_impl_impl<T>(get_data);
1546 }
1547
1548 public:
1549 /**
1550 * More efficient helper for Tensor::data_ptr(). Like data<T>(), but
1551 * does not do a type check. Unlike the untemplated data(), does
1552 * check has_storage() and storage_initialized().
1553 */
1554 template <typename T>
1555   inline const T* data_ptr_impl() const {
1556 return data_ptr_impl_impl<const T>(
1557 [this] { return static_cast<const T*>(storage_.data()); });
1558 }
1559
1560 /**
1561 * More efficient helper for Tensor::data_ptr(). Like data<T>(), but
1562 * does not do a type check. Unlike the untemplated data(), does
1563 * check has_storage() and storage_initialized().
1564 */
1565 template <typename T>
1566   inline T* mutable_data_ptr_impl() {
1567 return data_ptr_impl_impl<T>(
1568 [this] { return static_cast<T*>(storage_.mutable_data()); });
1569 }
1570
1571 private:
1572   // Shared implementation of data_ptr_impl() and
1573   // mutable_data_ptr_impl().
1574 template <typename T, typename Func>
1575   __ubsan_ignore_pointer_overflow__ T* data_ptr_impl_impl(
1576 const Func& get_data) const {
1577 if (C10_UNLIKELY(!has_storage())) {
1578 throw_data_ptr_access_error();
1579 }
1580 TORCH_CHECK(
1581 storage_initialized(),
1582 "The tensor has a non-zero number of elements, but its data is not allocated yet.\n"
1583 "If you're using torch.compile/export/fx, it is likely that we are erroneously "
1584 "tracing into a custom kernel. To fix this, please wrap the custom kernel into "
1585 "an opaque custom op. Please see the following for details: "
1586 "https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html\n"
1587 "If you're using Caffe2, Caffe2 uses a lazy allocation, so you will need to call "
1588 "mutable_data() or raw_mutable_data() to actually allocate memory.");
1589 // Caller does the type check.
1590     // Note: storage_offset_ can be non-zero even for zero-element tensors
1591     // (for example one created as `torch.empty(5)[10:]`), which triggers
1592     // applying a non-zero offset to a null pointer in UBSan
1593 return get_data() + storage_offset_;
1594 }
1595
1596 public:
1597 /**
1598 * Return a const void* data pointer to the actual data which this
1599 * tensor refers to.
1600 *
1601 * It is invalid to call data() on a dtype-uninitialized tensor, even if the
1602 * size is 0.
1603 *
1604    * WARNING: The data pointed to by this tensor may not be contiguous; do NOT
1605 * assume that itemsize() * numel() is sufficient to compute the bytes that
1606 * can be validly read from this tensor.
1607 */
1608   inline const void* data() const {
1609 return data_impl<const void>(
1610 [this] { return static_cast<const char*>(storage_.data()); });
1611 }
1612
1613 /**
1614 * Return a void* data pointer to the actual data which this tensor refers to.
1615 *
1616 * It is invalid to call mutable_data() on a dtype-uninitialized
1617 * tensor, even if the size is 0.
1618 *
1619    * WARNING: The data pointed to by this tensor may not be contiguous; do NOT
1620 * assume that itemsize() * numel() is sufficient to compute the bytes that
1621 * can be validly read from this tensor.
1622 */
1623   inline void* mutable_data() {
1624 return data_impl<void>(
1625 [this] { return static_cast<char*>(storage_.mutable_data()); });
1626 }
1627
1628 private:
1629 /// Shared implementation of data() and mutable_data().
1630 ///
1631 /// get_data must return a byte-addressed pointer, e.g. char*,
1632 /// std::byte const*, etc.
1633 template <typename Void, typename Func>
1634   Void* data_impl(const Func& get_data) const {
1635 if (C10_UNLIKELY(!has_storage())) {
1636 throw_data_ptr_access_error();
1637 }
1638 TORCH_CHECK(
1639 dtype_initialized(),
1640 "Cannot access data pointer of Tensor that doesn't have initialized dtype "
1641 "(e.g., caffe2::Tensor x(CPU), prior to calling mutable_data<T>() on x)");
1642 auto* data = get_data();
1643 static_assert(
1644 sizeof(*data) == 1, "get_data must return a byte-addressed pointer.");
1645 // Computing an offset into an empty tensor would be UB, since an empty
1646 // tensor's storage will be nullptr, and adding a nonzero offset to nullptr
1647 // is UB. So we skip the offset computation in this case.
1648 if (is_empty()) {
1649 return nullptr;
1650 }
1651 return data + data_type_.itemsize() * storage_offset_;
1652 }
1653
1654 public:
1655 /**
1656 * Returns the TypeMeta of a tensor, which describes what data type
1657 * it is (e.g., int, float, ...)
1658 */
1659   const caffe2::TypeMeta dtype() const {
1660 return data_type_;
1661 }
1662
1663 /**
1664 * Return the size of a single element of this tensor in bytes.
1665 */
1666   size_t itemsize() const {
1667 TORCH_CHECK(
1668 dtype_initialized(),
1669 "Cannot report itemsize of Tensor that doesn't have initialized dtype "
1670 "(e.g., caffe2::Tensor x(CPU), prior to calling mutable_data<T>() on x)");
1671 return data_type_.itemsize();
1672 }
1673
1674   void set_backend_meta(intrusive_ptr<c10::BackendMeta> backend_meta) {
1675 get_extra_meta().backend_meta_ = std::move(backend_meta);
1676 }
1677
1678   c10::BackendMeta* get_backend_meta() {
1679 if (!extra_meta_) {
1680 return nullptr;
1681 }
1682 return extra_meta_->backend_meta_.get();
1683 }
1684
1685   intrusive_ptr<c10::BackendMeta> get_backend_meta_intrusive_ptr() const {
1686 if (!extra_meta_) {
1687 return nullptr;
1688 }
1689 return extra_meta_->backend_meta_;
1690 }
1691
1692   void release_storage_and_set_meta_custom_data_ptr_error_msg_(
1693 std::optional<std::string> s) {
1694 storage_ = {};
1695 set_storage_access_should_throw();
1696 get_extra_meta().custom_data_ptr_error_msg_ = s;
1697 get_extra_meta().custom_storage_error_msg_ = std::move(s);
1698 }
1699
1700 protected:
1701 /**
1702 * Returns the human-readable name of the actual type of this object (e.g.,
1703 * TensorImpl, BatchedTensorImpl, etc.). Used for error messages.
1704 */
1705   virtual const char* tensorimpl_type_name() const {
1706 return "TensorImpl";
1707 }
1708
1709 private:
1710 [[noreturn]] void throw_storage_access_error() const;
1711 [[noreturn]] void throw_data_ptr_access_error() const;
1712
1713   ExtraMeta& get_extra_meta() {
1714 if (!extra_meta_) {
1715 extra_meta_ = std::make_unique<ExtraMeta>();
1716 }
1717 return *extra_meta_;
1718 }
1719
1720   c10::SymbolicShapeMeta& symbolic_shape_meta() {
1721 TORCH_INTERNAL_ASSERT(extra_meta_ && extra_meta_->symbolic_shape_meta_);
1722 return *extra_meta_->symbolic_shape_meta_;
1723 }
1724
1725   const c10::SymbolicShapeMeta& symbolic_shape_meta() const {
1726 TORCH_INTERNAL_ASSERT(extra_meta_ && extra_meta_->symbolic_shape_meta_);
1727 return *extra_meta_->symbolic_shape_meta_;
1728 }
1729
1730 public:
1731 /**
1732 * True if a tensor has no elements (e.g., numel() == 0).
1733 */
1734   inline bool is_empty() const {
1735 return numel() == 0;
1736 }
1737
1738 // if we are going to use sym sizes, we should be setting sym strides at the
1739 // same time, otherwise it's very easy to misuse this API
1740 void set_sizes_and_strides(
1741 c10::SymIntArrayRef sizes,
1742 c10::SymIntArrayRef strides,
1743 std::optional<c10::SymInt> storage_offset = std::nullopt);
1744 // This is renamed to avoid breaking overload BC
1745 void generic_set_sizes_contiguous(c10::SymIntArrayRef sizes);
1746   void generic_set_sizes_contiguous(c10::IntArrayRef sizes) {
1747 set_sizes_contiguous(sizes);
1748 }
1749
1750 /**
1751 * Change the size at some dimension. This DOES NOT update strides;
1752 * thus, most changes to size will not preserve contiguity. You probably
1753 * also want to call set_stride() when you call this.
1754 *
1755 * TODO: This should be jettisoned in favor of `set_sizes_and_strides`,
1756 * which is harder to misuse.
1757 */
1758   virtual void set_size(int64_t dim, int64_t new_size) {
1759 TORCH_CHECK(
1760 allow_tensor_metadata_change(),
1761 "set_size ",
1762 err_msg_tensor_metadata_change_not_allowed);
1763 TORCH_CHECK(
1764 !matches_policy(SizesStridesPolicy::CustomSizes),
1765 "set_size() called on tensor with dynamic shapes or customized size behavior")
1766 sizes_and_strides_.size_at(dim) = new_size;
1767 refresh_numel();
1768 refresh_contiguous();
1769 }
1770
1771 /**
1772 * Change the stride at some dimension.
1773 *
1774 * TODO: This should be jettisoned in favor of `set_sizes_and_strides`,
1775 * which is harder to misuse.
1776 */
1777   virtual void set_stride(int64_t dim, int64_t new_stride) {
1778 TORCH_CHECK(
1779 allow_tensor_metadata_change(),
1780 "set_stride ",
1781 err_msg_tensor_metadata_change_not_allowed);
1782 TORCH_CHECK(
1783 !has_symbolic_sizes_strides_,
1784 "set_stride() called on tensor with symbolic shape")
1785 sizes_and_strides_.stride_at_unchecked(dim) = new_stride;
1786 refresh_contiguous();
1787 }
1788
1789 /**
1790 * Set the offset into the storage of this tensor.
1791 *
1792 * WARNING: This does NOT check if the tensor is in bounds for the new
1793 * location at the storage; the caller is responsible for checking this
1794 * (and resizing if necessary.)
1795 */
1796   virtual void set_storage_offset(int64_t storage_offset) {
1797 TORCH_CHECK(
1798 allow_tensor_metadata_change(),
1799 "set_storage_offset ",
1800 err_msg_tensor_metadata_change_not_allowed);
1801 // TODO: this should probably consult policy
1802 TORCH_CHECK(
1803 !has_symbolic_sizes_strides_,
1804 "set_storage_offset() called on tensor with symbolic shape")
1805 storage_offset_ = storage_offset;
1806 }
1807
1808 /**
1809 * Like set_sizes_and_strides but assumes contiguous strides.
1810 *
1811 * WARNING: This function does not check if the requested
1812 * sizes/strides are in bounds for the storage that is allocated;
1813 * this is the responsibility of the caller
1814 */
1815   void set_sizes_contiguous(IntArrayRef new_size) {
1816 TORCH_CHECK(
1817 allow_tensor_metadata_change(),
1818 "set_sizes_contiguous ",
1819 err_msg_tensor_metadata_change_not_allowed);
1820 TORCH_CHECK(
1821 !matches_policy(SizesStridesPolicy::CustomStrides),
1822 "tried to directly modify sizes for customized tensor");
1823 sizes_and_strides_.set_sizes(new_size);
1824
1825 refresh_numel();
1826 empty_tensor_restride(
1827 MemoryFormat::Contiguous); // calls refresh_contiguous()
1828 }
1829
1830 /**
1831 * Set the sizes and strides of a tensor.
1832 *
1833 * WARNING: This function does not check if the requested
1834 * sizes/strides are in bounds for the storage that is allocated;
1835 * this is the responsibility of the caller
1836 */
1837 void set_sizes_and_strides(
1838 IntArrayRef new_size,
1839 IntArrayRef new_stride,
1840 std::optional<int64_t> storage_offset = std::nullopt) {
1841 TORCH_CHECK(
1842 allow_tensor_metadata_change(),
1843 "set_sizes_and_strides ",
1844 err_msg_tensor_metadata_change_not_allowed);
1845 TORCH_CHECK(
1846 !has_symbolic_sizes_strides_,
1847 "set_sizes_and_strides() called on tensor with symbolic shape")
1848 TORCH_CHECK(
1849 new_size.size() == new_stride.size(),
1850 "dimensionality of sizes (",
1851 new_size.size(),
1852 ") must match dimensionality of strides (",
1853 new_stride.size(),
1854 ")");
1855 const auto new_dim = new_size.size();
1856 bool overflowed = false;
1857 sizes_and_strides_.set_sizes(new_size);
1858
1859 if (new_dim > 0) {
1860 for (size_t dim = new_dim - 1;; dim--) {
1861 if (new_stride[dim] >= 0) {
1862 sizes_and_strides_.stride_at_unchecked(dim) = new_stride[dim];
1863 } else {
1864 // XXX: This behavior is surprising and may need to be removed to
1865 // support negative strides. Some pytorch functions rely on it:
1866 // for example, torch.cat (run TestTorch.test_cat_empty).
1867 if (dim == new_dim - 1) {
1868 sizes_and_strides_.stride_at_unchecked(dim) = 1;
1869 } else {
1870 // Keep stride monotonically increasing to match NumPy.
1871 overflowed |= c10::mul_overflows(
1872 sizes_and_strides_.stride_at_unchecked(dim + 1),
1873 std::max<int64_t>(
1874 sizes_and_strides_.size_at_unchecked(dim + 1), 1),
1875 std::addressof(sizes_and_strides_.stride_at_unchecked(dim)));
1876 }
1877 }
1878 if (dim == 0)
1879 break;
1880 }
1881 TORCH_CHECK(!overflowed, "Stride calculation overflowed");
1882 }
1883
1884 refresh_numel();
1885 refresh_contiguous();
1886
1887 if (storage_offset.has_value()) {
1888 storage_offset_ = *storage_offset;
1889 }
1890 }
1891
1892 /**
1893 * Set whether a tensor allows changes to its metadata (e.g. sizes / strides /
1894 * storage / storage_offset). See NOTE [ Metadata Change for a Detached Tensor
1895 * ] for details.
1896 */
1897   void set_allow_tensor_metadata_change(bool value [[maybe_unused]]) {
1898 // TODO: at some point, we should kill this field completely.
1899 allow_tensor_metadata_change_ = true;
1900 }
1901
1902 /**
1903 * True if a tensor allows changes to its metadata (e.g. sizes / strides /
1904 * storage / storage_offset). See NOTE [ Metadata Change for a Detached Tensor
1905 * ] for details.
1906 */
1907   bool allow_tensor_metadata_change() const {
1908 return allow_tensor_metadata_change_;
1909 }
1910
1911 /**
1912 * Set the pointer to autograd metadata.
1913 */
1914 void set_autograd_meta(
1915 std::unique_ptr<c10::AutogradMetaInterface> autograd_meta);
1916
1917 /**
1918 * Return the pointer to autograd metadata. May return nullptr if the
1919 * tensor does not track gradients.
1920 */
1921 c10::AutogradMetaInterface* autograd_meta() const;
1922
1923 /**
1924 * Set the pointer to named tensor metadata.
1925 */
1926   void set_named_tensor_meta(
1927 std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta) {
1928 TORCH_WARN_ONCE(
1929 "Named tensors and all their associated APIs are an experimental feature ",
1930 "and subject to change. Please do not use them for anything important ",
1931 "until they are released as stable.");
1932 #ifdef DEBUG
1933 if (named_tensor_meta) {
1934 TORCH_INTERNAL_ASSERT(named_tensor_meta->slow_dim() == dim());
1935 }
1936 #endif
1937 if (named_tensor_meta) {
1938 get_extra_meta().named_tensor_meta_ = std::move(named_tensor_meta);
1939 key_set_ = key_set_.add(DispatchKey::Named);
1940 } else {
1941 if (extra_meta_) {
1942 extra_meta_->named_tensor_meta_ = nullptr;
1943 }
1944 key_set_ = key_set_.remove(DispatchKey::Named);
1945 }
1946 }
1947
1948   void set_python_dispatch(bool k) {
1949 if (k) {
1950 key_set_ = key_set_.add(c10::python_ks);
1951 } else {
1952 key_set_ = key_set_ - c10::python_ks;
1953 }
1954 }
1955
1956   bool is_python_dispatch() const {
1957 return key_set_.has_all(c10::python_ks);
1958 }
1959
1960 /**
1961 * Return the pointer to named tensor metadata.
1962 */
1963   const c10::NamedTensorMetaInterface* named_tensor_meta() const {
1964 if (!extra_meta_) {
1965 return nullptr;
1966 }
1967 return extra_meta_->named_tensor_meta_.get();
1968 }
1969
1970   c10::NamedTensorMetaInterface* named_tensor_meta() {
1971 if (!extra_meta_) {
1972 return nullptr;
1973 }
1974 return extra_meta_->named_tensor_meta_.get();
1975 }
1976
1977   bool has_named_tensor_meta() const {
1978 if (!extra_meta_) {
1979 return false;
1980 }
1981 return extra_meta_->named_tensor_meta_ != nullptr;
1982 }
1983
1984 // NOTE [ TensorImpl Shallow-Copying ]
1985 //
1986 // TensorImpl shallow-copying is used when we want to have two Variables share
1987 // the same tensor metadata (e.g. sizes / strides / storage pointer /
1988 // storage_offset), but each with a different autograd history. Example call
1989 // sites:
1990 //
1991 // 1. `var_detached = var.detach()` uses `shallow_copy_and_detach()` to create
1992 // `var_detached` that shares the same tensor metadata with `var`, but with a
1993 // completely new autograd history.
1994 // 2. `var.set_data(tensor)` uses `shallow_copy_from()` to copy tensor
1995 // metadata from `tensor` into `var`, while keeping `var`'s original
1996 // AutogradMeta.
1997 //
1998 // Functions that shallow-copy a TensorImpl (such as
1999 // `shallow_copy_and_detach()` / `shallow_copy_from()` /
2000 // `copy_tensor_metadata()`) copy the tensor metadata fields (e.g. sizes /
2001 // strides / storage pointer / storage_offset) by value. However, the
2002 // following fields are not copied:
2003 //
2004 // 1. the AutogradMeta pointer, because it is unique for each Variable.
2005 // 2. the version counter, because the destination TensorImpl's version
2006 // counter is either set to the passed-in `version_counter` (in
2007 // `shallow_copy_and_detach()` and `copy_tensor_metadata()`), or it is kept
2008 // intact (in `shallow_copy_from()`). See NOTE [ Version Counter Sharing ] for
2009 // details.
2010 //
2011 // In `shallow_copy_and_detach()` and `copy_tensor_metadata()`, the passed-in
2012 // `allow_tensor_metadata_change` determines whether the TensorImpl
2013 // shallow-copy allows changes to its metadata (e.g. sizes / strides / storage
2014 // / storage_offset). See NOTE [ Metadata Change for a Detached Tensor ] for
2015 // details.
2016 //
2017 // In `shallow_copy_from()`, we don't check the destination TensorImpl's
2018 // `allow_tensor_metadata_change_`, because `shallow_copy_from()` is used for
2019 // implementing functions such as `var.set_data(tensor)`, which changes
2020 // `var`'s tensor metadata and expects its `allow_tensor_metadata_change_` to
2021 // be ignored.
2022
2023 /**
2024 * One TensorImpl can be copied to another TensorImpl if they have the same
2025    * DispatchKeySet. The special cases (for legacy reasons) are that dense
2026    * backends (e.g. CPU and CUDA) are compatible with each other, and sparse
2027    * backends (e.g. SparseCPU and SparseCUDA) are compatible with each other.
2028 */
2029   inline bool has_compatible_shallow_copy_type(DispatchKeySet from) {
2030 auto is_dense = [](DispatchKeySet ts) {
2031 constexpr auto dense_backends = DispatchKeySet(
2032 {BackendComponent::CPUBit,
2033 BackendComponent::CUDABit,
2034 BackendComponent::MPSBit,
2035 BackendComponent::HIPBit,
2036 BackendComponent::XPUBit,
2037 BackendComponent::HPUBit});
2038 constexpr auto dense_k = DispatchKeySet(DispatchKey::Dense);
2039 return ts.has_any(dense_k) && ts.has_any(dense_backends);
2040 };
2041 auto is_sparse = [](DispatchKeySet ts) {
2042 constexpr auto sparse_backends = DispatchKeySet(
2043 {BackendComponent::CPUBit,
2044 BackendComponent::CUDABit,
2045 BackendComponent::HIPBit,
2046 BackendComponent::XPUBit});
2047 constexpr auto sparse_k = DispatchKeySet(DispatchKey::Sparse);
2048 return ts.has_any(sparse_k) && ts.has_any(sparse_backends);
2049 };
2050 auto is_sparse_compressed = [](DispatchKeySet ts) {
2051 constexpr auto sparse_compressed_k =
2052 DispatchKeySet(DispatchKey::SparseCsr);
2053 return ts.has_any(sparse_compressed_k);
2054 };
2055 return (key_set_ == from) || (is_dense(key_set_) && is_dense(from)) ||
2056 (is_sparse(key_set_) && is_sparse(from)) ||
2057 (is_sparse_compressed(key_set_) && is_sparse_compressed(from));
2059 }
2060
2061 private:
2062 template <typename VariableVersion>
2063 c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach_core(
2064 VariableVersion&& version_counter,
2065 bool allow_tensor_metadata_change) const;
2066
2067 public:
2068 /**
2069 * Return a TensorImpl that is a shallow-copy of this TensorImpl.
2070 *
2071 * For usage of `version_counter` and `allow_tensor_metadata_change`,
2072 * see NOTE [ TensorImpl Shallow-Copying ].
2073 */
2074 virtual c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
2075 const c10::VariableVersion& version_counter,
2076 bool allow_tensor_metadata_change) const;
2077
2078 /**
2079 * Return a TensorImpl that is a shallow-copy of this TensorImpl.
2080 *
2081 * For usage of `version_counter` and `allow_tensor_metadata_change`,
2082 * see NOTE [ TensorImpl Shallow-Copying ].
2083 */
2084 virtual c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
2085 c10::VariableVersion&& version_counter,
2086 bool allow_tensor_metadata_change) const;
2087
2088 /**
2089 * Shallow-copies data from another TensorImpl into this TensorImpl.
2090 *
2091 * For why this function doesn't check this TensorImpl's
2092 * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
2093 */
2094   virtual void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {
2095 copy_tensor_metadata(
2096 /*src_impl=*/impl.get(),
2097 /*dest_impl=*/this,
2098 /*version_counter=*/version_counter(),
2099 /*allow_tensor_metadata_change=*/allow_tensor_metadata_change());
2100 }
2101
2102   // Inference tensors don't have a version counter;
2103   // set_version_counter is a no-op for them.
2104   void set_version_counter(const c10::VariableVersion& version_counter) {
2105 TORCH_CHECK(
2106 !(is_inference() && version_counter.enabled()),
2107 "Cannot set version_counter for inference tensor");
2108 version_counter_ = version_counter;
2109 }
2110
2111   void set_version_counter(c10::VariableVersion&& version_counter) {
2112 TORCH_CHECK(
2113 !(is_inference() && version_counter.enabled()),
2114 "Cannot set version_counter for inference tensor");
2115 version_counter_ = std::move(version_counter);
2116 }
2117
2118   const c10::VariableVersion& version_counter() const noexcept {
2119 return version_counter_;
2120 }
2121
2122   void bump_version() {
2123 version_counter_.bump();
2124 }
2125
2126   impl::PyObjectSlot* pyobj_slot() {
2127 return &pyobj_slot_;
2128 }
2129
2130   const impl::PyObjectSlot* pyobj_slot() const {
2131 return &pyobj_slot_;
2132 }
2133
2134 private:
2135 // See NOTE [std::optional operator usage in CUDA]
2136 // We probably don't want to expose this publicly until
2137 // the note is addressed.
2138   std::optional<c10::Device> device_opt() const {
2139 return device_opt_;
2140 }
2141
2142 public:
2143 /**
2144 * The device type of a Tensor, e.g., DeviceType::CPU or DeviceType::CUDA.
2145 */
2146   DeviceType device_type() const {
2147 // TODO: A useful internal assert would be to show that device_opt_ is null
2148 // only if you are an undefined tensor
2149 TORCH_CHECK(
2150 device_opt_.has_value(),
2151 "device_type cannot be run on undefined Tensor");
2152 // See NOTE [std::optional operator usage in CUDA]
2153 return (*device_opt_).type();
2154 }
2155
2156 /**
2157 * @brief Extends the outer-most dimension of this tensor by num elements,
2158 * preserving the existing data.
2159 *
2160 * The underlying data may be reallocated in order to accommodate the new
2161    * elements, in which case this tensor's capacity is grown by a factor of
2162    * growthPct. This ensures that Extend runs with amortized O(1) time
2163    * complexity.
2164 *
2165 * This op is auto-asynchronous if the underlying device (CUDA) supports it.
2166 */
2167 void Extend(int64_t num, float growthPct);
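  // A minimal sketch of amortized appends via Extend (assuming a legacy
  // caffe2-style tensor whose dtype was already fixed by an earlier
  // mutable_data<float>() call; kBatchRows is a hypothetical row count):
  //
  //   impl->Resize(0, 16);                      // start with 0 rows, 16 cols
  //   impl->mutable_data<float>();              // fix the dtype
  //   impl->Extend(kBatchRows, /*growthPct=*/40.f);
  //   // ... fill the new rows, then call Extend() again for the next batch.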
2168
2169 /**
2170 * @brief Reserve space for the underlying tensor.
2171 *
2172 * This must be called after Resize(), since we only specify the first
2173    * dimension. This does not copy over the old data to the newly allocated space.
2174 */
2175 void ReserveSpace(int64_t outer_dim);
2176
2177 /**
2178 * @brief Resizes a tensor.
2179 *
2180 * Resize takes in a vector of ints specifying the dimensions of the tensor.
2181 * You can pass in an empty vector to specify that it is a scalar (i.e.
2182 * containing one single item).
2183 *
2184 * The underlying storage may be deleted after calling Resize: if the new
2185 * shape leads to a different number of items in the tensor, the old memory
2186 * is deleted and new memory will be allocated next time you call
2187 * mutable_data(). However, if the shape is different but the total number of
2188 * items is the same, the underlying storage is kept.
2189 *
2190 * This method respects caffe2_keep_on_shrink. Consult the internal logic
2191 * of this method to see exactly under what circumstances this flag matters.
2192 */
2193 template <typename... Ts>
2194   void Resize(Ts... dim_source) {
2195 bool size_changed = SetDims(dim_source...);
2196 if (size_changed) {
2197 HandleResize();
2198 }
2199 }
2200
2201 template <typename T>
2202   void Resize(const std::vector<T>& dim_source) {
2203 Resize(ArrayRef<T>(dim_source));
2204 }
2205
2206 /**
2207    * Resizes the tensor without touching the underlying storage.
2208    * This requires the total size of the tensor to remain constant.
2209 */
2210 void Reshape(const std::vector<int64_t>& dims);
2211
2212 /**
2213 * Release whatever memory the tensor was holding but keep size and type
2214    * information. A subsequent call to mutable_data will trigger a new memory
2215 * allocation.
2216 */
2217 void FreeMemory();
2218
2219 /**
2220 * @brief Shares the data with another tensor.
2221 *
2222 * To share data between two tensors, the sizes of the two tensors must be
2223 * equal already. The reason we do not implicitly do a Resize to make the two
2224 * tensors have the same shape is that we want to allow tensors of different
2225 * shapes but the same number of items to still be able to share data. This
2226    * allows one to e.g. have an n-dimensional Tensor and a flattened version
2227 * sharing the same underlying storage.
2228 *
2229 * The source tensor should already have its data allocated.
2230 */
2231 // To be deprecated
2232 void ShareData(const TensorImpl& src);
2233
2234 void ShareExternalPointer(
2235 DataPtr&& data_ptr,
2236 const caffe2::TypeMeta data_type,
2237 size_t size_bytes);
2238
2239 /**
2240 * Returns a mutable raw pointer of the underlying storage. Since we will need
2241 * to know the type of the data for allocation, a TypeMeta object is passed in
2242    * to specify the necessary information. This is conceptually equivalent to
2243 * calling mutable_data<T>() where the TypeMeta parameter meta is derived from
2244 * the type T. This function differs from mutable_data<T>() in the sense that
2245 * the type T can be specified during runtime via the TypeMeta object.
2246 *
2247 * If the existing data does not match the desired type, it will be deleted
2248 * and a new storage will be created.
2249 */
2250   inline void* raw_mutable_data(const caffe2::TypeMeta& meta) {
2251 // For 0-size tensors it's fine to return any pointer (including nullptr)
2252 if (data_type_ == meta && storage_initialized()) {
2253 return static_cast<void*>(
2254 static_cast<char*>(storage_.mutable_data()) +
2255 storage_offset_ * meta.itemsize());
2256 } else {
2257 bool had_special_dtor = data_type_.placementDelete() != nullptr;
2258 storage_offset_ = 0;
2259 data_type_ = meta;
2260 // NB: device is not changed
2261
2262 // We can reuse the existing buffer if the current data does not have
2263 // a special destructor and the new data doesn't have a special
2264 // constructor.
2265 if (numel_ == 0 ||
2266 (meta.placementNew() == nullptr && !had_special_dtor &&
2267 (storage_.nbytes() >= (numel_ * data_type_.itemsize())))) {
2268 TORCH_INTERNAL_ASSERT(
2269             storage_offset_ == 0); // because we just set storage_offset_ = 0 above
2270 return storage_.mutable_data();
2271 }
2272 Allocator* allocator = storage_.allocator();
2273 // Storage might have nullptr allocator in rare cases, for example, if
2274 // an external memory segment has been wrapped with Tensor and we don't
2275 // know how to reallocate it. However, in order to preserve legacy C2
2276 // behavior, we allow reallocating the memory using default allocator.
2277 if (allocator == nullptr) {
2278 allocator = GetAllocator(storage_.device_type());
2279 }
2280 if (meta.placementNew()) {
2281 // For types that need placement new, we will call it, as well as
2282 // making sure that when the data is freed, it calls the right
2283 // destruction procedure.
2284 auto size = numel_;
2285 auto dtor = data_type_.placementDelete();
2286 auto data_ptr = allocator->allocate(numel_ * data_type_.itemsize());
2287 storage_.set_data_ptr_noswap(PlacementDeleteContext::makeDataPtr(
2288 std::move(data_ptr), dtor, size, storage_.device()));
2289 data_type_.placementNew()(storage_.mutable_data(), numel_);
2290 } else {
2291         // For fundamental types, plain new and delete are simpler.
2292 storage_.set_data_ptr_noswap(
2293 allocator->allocate(numel_ * data_type_.itemsize()));
2294 }
2295 storage_.set_nbytes(numel_ * data_type_.itemsize());
2296 TORCH_INTERNAL_ASSERT(
2297 storage_offset_ == 0); // because we just reallocated
2298 device_opt_ = storage_.device();
2299 return storage_.mutable_data();
2300 }
2301 }
2302
2303 /**
2304 * Returns a typed pointer of the underlying storage.
2305 *
2306 * For fundamental types, we reuse possible existing storage if there
2307 * is sufficient capacity.
2308 */
2309 template <typename T>
2310   inline T* mutable_data() {
2311 if (storage_initialized() && data_type_.Match<T>()) {
2312 return static_cast<T*>(storage_.mutable_data()) + storage_offset_;
2313 }
2314 // Check it here statically - otherwise TypeMeta would throw the runtime
2315 // error in attempt to invoke TypeMeta::ctor()
2316 static_assert(
2317 std::is_default_constructible<T>::value,
2318         "Tensor can't hold non-default-constructible types");
2319 return static_cast<T*>(raw_mutable_data(caffe2::TypeMeta::Make<T>()));
2320 }
2321
2322 /**
2323 * True if a tensor is storage initialized. A tensor may become
2324 * storage UNINITIALIZED after a Resize() or FreeMemory()
2325 */
2326   bool storage_initialized() const {
2327 TORCH_CHECK(
2328 has_storage(),
2329 "cannot call storage_initialized on tensor that does not have storage");
2330 return storage_.data() || numel_ == 0;
2331 }
2332
2333 /**
2334 * True if a tensor is dtype initialized. A tensor allocated with
2335 * Caffe2-style constructors is dtype uninitialized until the
2336 * first time mutable_data<T>() is called.
2337 */
2338   bool dtype_initialized() const noexcept {
2339 return data_type_ != caffe2::TypeMeta();
2340 }
2341
2342   void set_storage_keep_dtype(at::Storage storage) {
2343 TORCH_CHECK(
2344 allow_tensor_metadata_change(),
2345 "set_storage ",
2346 err_msg_tensor_metadata_change_not_allowed);
2347 storage_ = std::move(storage);
2348 device_opt_ = storage_.device();
2349 }
2350
2351   void set_storage_and_dtype(
2352 at::Storage storage,
2353 const caffe2::TypeMeta data_type) {
2354 set_storage_keep_dtype(std::move(storage));
2355 data_type_ = data_type;
2356 }
2357
2358 void empty_tensor_restride_symint(MemoryFormat memory_format);
2359
2360 /**
2361 * Set the strides of the tensor to match memory_format
2362 *
2363    * WARNING: This function doesn't rearrange data and assumes the tensor is
2364    * memory contiguous.
2365 */
2366   void empty_tensor_restride(MemoryFormat memory_format) {
2367 if (has_symbolic_sizes_strides_) {
2368 empty_tensor_restride_symint(memory_format);
2369 return;
2370 }
2371 #ifdef DEBUG
2372 TORCH_INTERNAL_ASSERT(
2373 compute_numel() == numel_,
2374 "If you are seeing this error, that means empty_tensor_restride was "
2375 "called before setting correct numel");
2376 #endif
2377 switch (memory_format) {
2378 case MemoryFormat::Contiguous: {
2379 // dim_ is a virtual call, don't repeat it
2380 const auto dim_ = dim();
2381 sizes_and_strides_.resize(dim_);
2382 if (dim_ > 0) {
2383 bool overflowed = false;
2384 const auto last_idx = dim_ - 1;
2385 sizes_and_strides_.stride_at_unchecked(last_idx) = 1;
2386 for (auto i = last_idx - 1; i >= 0; --i) {
2387 overflowed |= c10::mul_overflows(
2388 sizes_and_strides_.stride_at_unchecked(i + 1),
2389 std::max<int64_t>(
2390 sizes_and_strides_.size_at_unchecked(i + 1), 1),
2391 std::addressof(sizes_and_strides_.stride_at_unchecked(i)));
2392 }
2393 TORCH_CHECK(!overflowed, "Stride calculation overflowed");
2394 }
2395 break;
2396 }
2397 case MemoryFormat::ChannelsLast: {
2398 TORCH_CHECK(
2399 dim() == 4, "required rank 4 tensor to use channels_last format");
2400 set_sizes_and_strides(sizes(), get_channels_last_strides_2d(sizes()));
2401 break;
2402 }
2403 case MemoryFormat::ChannelsLast3d: {
2404 TORCH_CHECK(
2405 dim() == 5,
2406 "required rank 5 tensor to use channels_last_3d format");
2407 set_sizes_and_strides(sizes(), get_channels_last_strides_3d(sizes()));
2408 break;
2409 }
2410 case MemoryFormat::Preserve:
2411 TORCH_CHECK(false, "unsupported memory format ", memory_format);
2412         // No need to break: TORCH_CHECK(false) terminates control flow, and
2413         // keeping the break commented out avoids unreachable-code warnings.
2414 // break;
2415 case MemoryFormat::NumOptions:
2416 TORCH_INTERNAL_ASSERT(false, "invalid memory format ", memory_format);
2417 }
2418 // recompute contiguous flag, as currently NHWC/NCHW flags are not mutually
2419 // exclusive see #24090
2420 refresh_contiguous();
2421 }
2422
2423   bool is_strides_like(at::MemoryFormat memory_format) const {
2424 if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
2425 return is_strides_like_custom(memory_format);
2426 }
2427 return is_strides_like_default(memory_format);
2428 }
2429
2430   bool is_strides_like_channels_last() const {
2431 return is_strides_like(at::MemoryFormat::ChannelsLast);
2432 }
2433
2434   bool is_strides_like_channels_last_3d() const {
2435 return is_strides_like(at::MemoryFormat::ChannelsLast3d);
2436 }
2437
2438   bool is_non_overlapping_and_dense() const {
2439 if (C10_UNLIKELY(matches_policy(SizesStridesPolicy::CustomStrides))) {
2440 return is_non_overlapping_and_dense_custom();
2441 }
2442 return is_non_overlapping_and_dense_default();
2443 }
2444
2445 // if this returns true, then it is guaranteed that this tensor has symbolic
2446 // sizes/strides
2447   bool has_symbolic_sizes_strides() const {
2448 return has_symbolic_sizes_strides_;
2449 }
2450
2451 private:
2452 void HandleResize();
2453
2454   // The Caffe2 Resize() method supports being called both as Resize({2, 2})
2455   // and variadically as Resize(2, 2). These overloads provide all of the
2456 // supported calling configurations, while being overloads (and not templates)
2457 // so that implicit conversions still work.
2458 //
2459 // SetDims on ArrayRef is internally implemented as a template, so we can
2460 // handle both ArrayRefs of different types (there are some uses of
2461 // Resize in Caffe2 which pass in int, not int64_t.)
2462
2463 template <
2464 typename T,
2465 typename = typename std::enable_if_t<std::is_integral_v<T>>>
2466   bool SetDimsTemplate(ArrayRef<T> src) {
2467 TORCH_CHECK(
2468 !has_symbolic_sizes_strides_,
2469 "SetDims() called on tensor with symbolic shape")
2470
2471 auto old_numel = numel_;
2472 sizes_and_strides_.resize(src.size());
2473 int64_t new_numel = 1;
2474 for (const auto i : c10::irange(src.size())) {
2475 new_numel *= src[i];
2476 sizes_and_strides_.size_at_unchecked(i) = src[i];
2477 }
2478 numel_ = new_numel;
2479 empty_tensor_restride(MemoryFormat::Contiguous);
2480 return numel_ != old_numel;
2481 }
2482
2483   bool SetDims(ArrayRef<int64_t> s) {
2484 return SetDimsTemplate(s);
2485 }
2486
2487   bool SetDims(ArrayRef<int> s) {
2488 return SetDimsTemplate(s);
2489 }
2490
2491   bool SetDims(ArrayRef<size_t> s) {
2492 return SetDimsTemplate(s);
2493 }
2494
2495   bool SetDims() {
2496 return SetDims(IntArrayRef{});
2497 }
2498
2499   bool SetDims(const int64_t d0) {
2500 return SetDims(IntArrayRef{d0});
2501 }
2502
2503   bool SetDims(const int64_t d0, const int64_t d1) {
2504 return SetDims(IntArrayRef{d0, d1});
2505 }
2506
2507   bool SetDims(const int64_t d0, const int64_t d1, const int64_t d2) {
2508 return SetDims(IntArrayRef{d0, d1, d2});
2509 }
2510
2511   bool SetDims(
2512 const int64_t d0,
2513 const int64_t d1,
2514 const int64_t d2,
2515 const int64_t d3) {
2516 return SetDims(IntArrayRef{d0, d1, d2, d3});
2517 }
2518
2519 /**
2520 * Compute the number of elements based on the sizes of a tensor.
2521 */
2522 // NB: This is ONLY called when sizes_and_strides_ is used directly; if
2523 // we are virtualizing, then numel calls are virtualized as well, and this
2524 // should never get called
2525   int64_t compute_numel() const {
2526 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!has_symbolic_sizes_strides_);
2527 #if C10_HAS_BUILTIN_OVERFLOW() && !defined(C10_MOBILE)
2528 // Use overflow checks if supported by the compiler
2529 return safe_compute_numel();
2530 #else
2531 return c10::multiply_integers(sizes_and_strides_.sizes_arrayref());
2532 #endif
2533 }
2534
2535 /**
2536 * Compute the number of elements based on the sizes of a
2537 * tensor. Catches integer overflow that may occur when a tensor
2538 * using a sparse layout has multiple dimensions with large sizes.
2539 */
2540   int64_t safe_compute_numel() const {
2541 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!has_symbolic_sizes_strides_);
2542 uint64_t n = 1;
2543 bool overflows =
2544 c10::safe_multiplies_u64(sizes_and_strides_.sizes_arrayref(), &n);
2545 constexpr auto numel_max = std::min(
2546 static_cast<uint64_t>(std::numeric_limits<int64_t>::max()),
2547 static_cast<uint64_t>(std::numeric_limits<size_t>::max()));
2548
2549 overflows |= (n > numel_max);
2550 TORCH_CHECK(!overflows, "numel: integer multiplication overflow");
2551 return static_cast<int64_t>(n);
2552 }
2553
2554 /**
2555 * Compute whether or not a tensor is contiguous based on the sizes and
2556 * strides of a tensor.
2557 */
2558 bool compute_contiguous(identity<bool>) const;
2559
2560 bool compute_channels_last_contiguous_2d(identity<bool>) const;
2561
2562 bool compute_channels_last_contiguous_3d(identity<bool>) const;
2563
2564 bool compute_strides_like_channels_last_2d(identity<bool>) const;
2565
2566 bool compute_strides_like_channels_last_3d(identity<bool>) const;
2567
2568 bool compute_non_overlapping_and_dense(identity<bool>) const;
2569
2570 protected:
2571 /**
2572 * Recompute the cached numel of a tensor. Call this if you modify
2573 * sizes.
2574 *
2575 * For tensors with sparse layouts, use safe_refresh_numel() instead
2576 * because it will catch integer overflow that may occur for tensors
2577 * with sparse layouts and large dimensions.
2578 *
2579 * NB: We may uselessly recompute cached numel even in situations where
2580 * it is completely never used (e.g., if CustomSizes for Python). However,
2581 * we still must keep it up to date in case the Python overload
2582 * returns None (in which case we will consult the field here). This also
2583 * implies that sizes/strides will never be complete garbage; in the
2584 * very worst case scenario, it will reflect a 1-dim zero size tensor.
2585 */
2586   void refresh_numel() {
2587 if (has_symbolic_sizes_strides_) {
2588 symbolic_shape_meta().refresh_numel();
2589 } else {
2590 numel_ = compute_numel();
2591 }
2592 }
2593
2594 /**
2595 * Recompute the cached numel of a tensor. Call this if you modify
2596 * sizes. Use only for tensors with sparse layouts because only
2597    * sparse tensors are likely to have sizes that may lead to integer
2598 * overflow when computing numel.
2599 */
2600   void safe_refresh_numel() {
2601 if (has_symbolic_sizes_strides_) {
2602 // NB: sym numel is done with symbolic integers, which handle overflow
2603 // checking
2604 symbolic_shape_meta().refresh_numel();
2605 } else {
2606 numel_ = safe_compute_numel();
2607 }
2608 }
2609
2610 private:
2611 // NB: the TypeId argument prevents confusion where you pass a true/false
2612 // literal and pick the wrong overload
2613
2614   void _set_is_contiguous(identity<bool>, bool b) {
2615 is_contiguous_ = b;
2616 }
2617
2618   void _set_is_channels_last_contiguous(identity<bool>, bool b) {
2619 is_channels_last_contiguous_ = b;
2620 }
2621
2622   void _set_is_channels_last_3d_contiguous(identity<bool>, bool b) {
2623 is_channels_last_3d_contiguous_ = b;
2624 }
2625
2626   void _set_is_channels_last(identity<bool>, bool b) {
2627 is_channels_last_ = b;
2628 }
2629
2630   void _set_is_channels_last_3d(identity<bool>, bool b) {
2631 is_channels_last_3d_ = b;
2632 }
2633
2634   void _set_is_non_overlapping_and_dense(identity<bool>, bool b) {
2635 is_non_overlapping_and_dense_ = b;
2636 }
2637
2638 // These are little wrappers over the real compute_ functions that
2639 // can make use of other contiguity fields to short circuit.
2640
2641   bool compute_is_non_overlapping_and_dense_dim4(identity<bool> type_id) {
2642 return is_contiguous_ || is_channels_last_contiguous_ ||
2643 compute_non_overlapping_and_dense(type_id);
2644 }
2645
2646   bool compute_channels_last_contiguous_3d_dim5(identity<bool> type_id) {
2647 return !is_channels_last_contiguous_ &&
2648 compute_channels_last_contiguous_3d(type_id);
2649 }
2650
2651   bool compute_channels_last_2d_dim5(identity<bool> type_id) {
2652 return !is_channels_last_3d_contiguous_ &&
2653 compute_strides_like_channels_last_2d(type_id);
2654 }
2655
2656   bool compute_channels_last_3d_dim5(identity<bool> type_id) {
2657 return !is_channels_last_ && compute_strides_like_channels_last_3d(type_id);
2658 }
2659
2660   bool compute_is_non_overlapping_and_dense_dim5(identity<bool> type_id) {
2661 return is_contiguous_ || is_channels_last_contiguous_ ||
2662 is_channels_last_3d_contiguous_ ||
2663 compute_non_overlapping_and_dense(type_id);
2664 }
2665
2666   bool compute_is_non_overlapping_and_dense_anydim(identity<bool> type_id) {
2667 return is_contiguous_ || compute_non_overlapping_and_dense(type_id);
2668 }
2669
2670 template <typename T>
2671   void _refresh_contiguous() {
2672 auto type_id = identity<T>();
2673 // Note:
2674 // Dim 0, 1, 2 will never be a channels last 2d/3d format
2675     // Dim 3+ may be a channels last 2d format (Dim 4 only at this
2676     // point); Dim 4+ may be a channels last 3d format (Dim 5 only at
2677     // this point)
2678 switch (dim()) {
2679 case 4: {
2680 _set_is_contiguous(type_id, compute_contiguous(type_id));
2681 _set_is_channels_last_contiguous(
2682 type_id, compute_channels_last_contiguous_2d(type_id));
2683 _set_is_channels_last_3d_contiguous(type_id, false);
2684 _set_is_channels_last(
2685 type_id, compute_strides_like_channels_last_2d(type_id));
2686 _set_is_channels_last_3d(type_id, false);
2687 _set_is_non_overlapping_and_dense(
2688 type_id, compute_is_non_overlapping_and_dense_dim4(type_id));
2689 break;
2690 }
2691 case 5: {
2692 _set_is_contiguous(type_id, compute_contiguous(type_id));
2693 _set_is_channels_last_contiguous(
2694 type_id, compute_channels_last_contiguous_2d(type_id));
2695 _set_is_channels_last_3d_contiguous(
2696 type_id, compute_channels_last_contiguous_3d_dim5(type_id));
2697 _set_is_channels_last(type_id, compute_channels_last_2d_dim5(type_id));
2698 _set_is_channels_last_3d(
2699 type_id, compute_channels_last_3d_dim5(type_id));
2700 _set_is_non_overlapping_and_dense(
2701 type_id, compute_is_non_overlapping_and_dense_dim5(type_id));
2702 break;
2703 }
2704 default:
2705         // is_channels_last_ and is_channels_last_3d_ are suggested
2706         // memory_formats. Being channels_last_contiguous doesn't necessarily
2707         // mean the tensor is strided like channels_last: the strides on the
2708         // channel dimension may suggest the desired memory_format, but they
2709         // don't affect how the data is actually stored.
2710 _set_is_contiguous(type_id, compute_contiguous(type_id));
2711 _set_is_channels_last_contiguous(type_id, false);
2712 _set_is_channels_last_3d_contiguous(type_id, false);
2713 _set_is_channels_last(type_id, false);
2714 _set_is_channels_last_3d(type_id, false);
2715 _set_is_non_overlapping_and_dense(
2716 type_id, compute_is_non_overlapping_and_dense_anydim(type_id));
2717 break;
2718 }
2719 }
2720
2721 protected:
2722 /**
2723 * Recompute the cached contiguity of a tensor. Call this if you modify sizes
2724 * or strides.
2725 */
2726   void refresh_contiguous() {
2727 if (has_symbolic_sizes_strides_) {
2728 symbolic_shape_meta().refresh_contiguous();
2729 } else {
2730 _refresh_contiguous<bool>();
2731 }
2732 }
2733
2734 /**
2735 * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
2736 * storage_offset) from one TensorImpl to another TensorImpl.
2737 *
2738 * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
2739 * [ TensorImpl Shallow-Copying ].
2740 */
2741 static void copy_tensor_metadata(
2742 const TensorImpl* src_impl,
2743 TensorImpl* dest_impl,
2744 const c10::VariableVersion& version_counter,
2745 bool allow_tensor_metadata_change);
2746
2747 /**
2748 * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
2749 * storage_offset) from one TensorImpl to another TensorImpl.
2750 *
2751 * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE
2752 * [ TensorImpl Shallow-Copying ].
2753 */
2754 static void copy_tensor_metadata(
2755 const TensorImpl* src_impl,
2756 TensorImpl* dest_impl,
2757 c10::VariableVersion&& version_counter,
2758 bool allow_tensor_metadata_change);
2759
2760 private:
2761 static void copy_tensor_metadata_except_version_counter(
2762 const TensorImpl* src_impl,
2763 TensorImpl* dest_impl,
2764 bool allow_tensor_metadata_change);
2765
2766 protected:
2767 // Error message to show when the user tries to change tensor metadata on
2768 // Tensor created from .data or .detach().
2769 //
2770 // See NOTE [ Metadata Change for a Detached Tensor ] for details.
2771 static const char* const err_msg_tensor_metadata_change_not_allowed;
2772
2773 static void copy_generic_tensor_metadata(
2774 const TensorImpl* src_impl,
2775 TensorImpl* dest_impl);
2776
2777 public:
2778   void set_storage_access_should_throw() {
2779 storage_access_should_throw_ = true;
2780 }
2781
2782 public:
2783   void set_custom_sizes_strides(SizesStridesPolicy policy) {
2784 custom_sizes_strides_ = static_cast<uint8_t>(policy);
2785 refresh_sizes_strides_policy();
2786 }
2787
2788   void set_python_custom_sizes_strides(SizesStridesPolicy policy) {
2789 python_custom_sizes_strides_ = static_cast<uint8_t>(policy);
2790 refresh_sizes_strides_policy();
2791 }
2792
2793   void set_custom_device(bool custom_device) {
2794 custom_device_ = custom_device;
2795 refresh_device_policy();
2796 }
2797
2798   void set_custom_layout(bool custom_layout) {
2799 custom_layout_ = custom_layout;
2800 refresh_layout_policy();
2801 }
2802
2803   void set_python_custom_device(bool custom_device) {
2804 python_custom_device_ = custom_device;
2805 refresh_device_policy();
2806 }
2807
2808   void set_python_custom_layout(bool custom_layout) {
2809 python_custom_layout_ = custom_layout;
2810 refresh_layout_policy();
2811 }
2812
2813 protected:
2814   void refresh_sizes_strides_policy() {
2815 if (has_symbolic_sizes_strides_) {
2816 sizes_strides_policy_ =
2817 static_cast<uint8_t>(SizesStridesPolicy::CustomSizes);
2818 } else {
2819 sizes_strides_policy_ =
2820 std::max(custom_sizes_strides_, python_custom_sizes_strides_);
2821 }
2822 }
2823
2824   void refresh_device_policy() {
2825 device_policy_ = custom_device_ || python_custom_device_;
2826 }
2827
2828   void refresh_layout_policy() {
2829 layout_policy_ = custom_layout_ || python_custom_layout_;
2830 }
2831
2832 protected:
2833 Storage storage_;
2834
2835 private:
2836 // This pointer points to an AutogradMeta struct that stores autograd-specific
2837 // fields (such as grad_ / grad_fn_ / grad_accumulator_). This pointer always
2838 // has unique ownership (meaning only one TensorImpl can own it at a time).
2839 //
2840 // autograd_meta_ can be nullptr, as an optimization. When this occurs, it is
2841 // equivalent to having an autograd_meta_ pointing to a default constructed
2842 // AutogradMeta; intuitively, tensors which don't require grad will have this
2843 // field set to null.
2844 //
2845 // This means accessors on autograd_meta_ have to be careful to test if they
2846 // got a nullptr, and handle default behavior appropriately in that case.
2847 //
2848 // Note that we don't enforce the invariant that if the AutogradMeta is
2849   // default constructed, it is nullptr (to do this, we'd have to continuously
2850   // check if an AutogradMeta became, by mutation, equal to the default
2851   // constructed form). This might be useful, but it seems rare enough for a
2852   // requires_grad=True variable to turn back into the requires_grad=False
2853   // version that we don't bother. So there are three representable states:
2854 //
2855 // 1. autograd_meta_ == nullptr
2856 // 2. autograd_meta_ is default constructed (semantically, same as (1))
2857 // 3. autograd_meta_ has nontrivial information content
2858 //
2859 std::unique_ptr<c10::AutogradMetaInterface> autograd_meta_ = nullptr;
2860
2861 protected:
2862 std::unique_ptr<c10::ExtraMeta> extra_meta_ = nullptr;
2863
2864 c10::VariableVersion version_counter_;
2865
2866 impl::PyObjectSlot pyobj_slot_;
2867
2868 c10::impl::SizesAndStrides sizes_and_strides_;
2869
2870 int64_t storage_offset_ = 0;
2871 // If sizes and strides are empty, the numel is 1!! However, most of the
2872 // time, we will immediately set sizes to {0} and reset numel to 0.
2873 // (Can't do that in the default initializers, because there's no way to
2874 // spell "allocate a one-element array" for strides_).
2875 int64_t numel_ = 1;
2876
2877 // INVARIANT: When storage is non-null, this type meta must
2878 // agree with the type meta in storage
2879 caffe2::TypeMeta data_type_;
2880
2881 // NOTE [std::optional operator usage in CUDA]
2882 // Our optional definition doesn't compile in .cu file if `value()` or
2883 // `operator->` are used. Instead, we always use `operator*`.
2884 // See https://github.com/pytorch/pytorch/issues/18496 for more info.
2885 // If this is too burdensome to maintain, we can just
2886 // manually implement this with an additional bool.
2887
2888 // INVARIANT: When storage is non-null, this Device must
2889 // agree with the type meta in storage.
2890 //
2891 // INVARIANT: device_opt_ is only nullopt for undefined tensors
2892 // (which do not have a device.)
2893 std::optional<c10::Device> device_opt_;
2894
2895 // default member initializers for bit-fields only available with -std=c++2a
2896 // or -std=gnu++2a
2897   inline void init_bitfields() {
2898 is_contiguous_ = true;
2899 is_channels_last_ = false;
2900 is_channels_last_contiguous_ = false;
2901 is_channels_last_3d_ = false;
2902 is_channels_last_3d_contiguous_ = false;
2903 is_non_overlapping_and_dense_ = true;
2904 is_wrapped_number_ = false;
2905 allow_tensor_metadata_change_ = true;
2906 reserved_ = false;
2907 sizes_strides_policy_ = static_cast<uint8_t>(SizesStridesPolicy::Default);
2908 custom_sizes_strides_ = static_cast<uint8_t>(SizesStridesPolicy::Default);
2909 python_custom_sizes_strides_ =
2910 static_cast<uint8_t>(SizesStridesPolicy::Default);
2911 python_custom_device_ = false;
2912 python_custom_layout_ = false;
2913 custom_device_ = false;
2914 custom_layout_ = false;
2915 device_policy_ = false;
2916 layout_policy_ = false;
2917 storage_access_should_throw_ = false;
2918 has_symbolic_sizes_strides_ = false;
2919 }
2920
2921 // Tensor is contiguous
2922 bool is_contiguous_ : 1;
2923
2924 // Tensor is a subclass that does not permit storage access.
2925 bool storage_access_should_throw_ : 1;
2926
2927   // Tensor is stored in the channels last 2d memory format, i.e. the
2928   // dimension order is (N)CHW and C-strides < W-strides < H-strides (< N-strides)
2929   // (if the size of any dimension equals 1, that dimension's stride value
2930   // is not taken into account).
2931 bool is_channels_last_ : 1;
2932
2933   // A channels last contiguous tensor is a channels last tensor which
2934   // occupies a contiguous memory block.
2935 bool is_channels_last_contiguous_ : 1;
2936
2937   // Tensor is stored in the channels last 3d memory format, i.e. the
2938   // dimension order is (N)CDHW and C-strides < W-strides < H-strides < D-strides
2939   // (< N-strides) (if the size of any dimension equals 1, that dimension's
2940   // stride value is not taken into account).
2941 bool is_channels_last_3d_ : 1;
2942
2943   // A channels last 3d contiguous tensor is a channels last 3d tensor which
2944   // occupies a contiguous memory block.
2945 bool is_channels_last_3d_contiguous_ : 1;
2946
2947   // A dense tensor is a tensor that stores its values in a contiguous block
2948   // of memory. A non-overlapping tensor is a tensor in which each element
2949   // occupies its own distinct memory location.
2950 bool is_non_overlapping_and_dense_ : 1;
2951
2952 bool is_wrapped_number_ : 1;
2953
2954 // NOTE [ Metadata Change for a Detached Tensor ]
2955 //
2956 // Normally, a user is allowed to change the tensor metadata
2957 // (e.g. sizes / strides / storage / storage_offset) of a tensor.
2958 // However, if the tensor is created by `t1_detached = t1.data` in Python
2959 // or `t1_detached = t1.detach()` in Python/C++, those changes to the
2960 // tensor metadata of `t1_detached` will not be propagated back to the
2961 // original tensor `t1`. In order to make such changes explicitly illegal,
2962 // we created the `allow_tensor_metadata_change_` flag, to prevent users
2963 // from changing metadata of the detached tensor and expecting the original
2964 // tensor to also be updated.
2965 //
2966 // NOTE: For a full list of tensor metadata fields, please see
2967 // `copy_tensor_metadata()` in TensorImpl and its subclasses to find
2968 // which fields are copied by value.
2969 bool allow_tensor_metadata_change_ : 1;

  // We decided to keep reserved_, and it will
  // live in Tensor after the split.
  // The logic is that if Extend() or ReserveSpace() were ever called,
  // then subsequent Resize()s will not free up Storage.
  bool reserved_ : 1;
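
  // Illustrative caffe2-style usage of the invariant above (assumed sizes):
  //
  //   caffe2::Tensor t = ...;
  //   t.Resize(1000);        // allocates storage for 1000 elements
  //   t.ReserveSpace(1000);  // sets reserved_
  //   t.Resize(10);          // shrinks the sizes, but the storage is kept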

  // Call _custom() virtual methods for
  // strides()/is_contiguous()/sizes()/dim()/numel()
  // This is a combination of the custom_sizes_strides_ /
  // python_custom_sizes_strides_ bits below and has_symbolic_sizes_strides_.
  uint8_t sizes_strides_policy_ : 2;

  // Whether or not sizes_and_strides_ contains a symbolic value.
  bool has_symbolic_sizes_strides_ : 1;

  // Call _custom() virtual method for
  // strides()/is_contiguous()/sizes()/dim()/numel()
  uint8_t custom_sizes_strides_ : 2;

  // Combination of the custom_* and python_custom_* bits below.
  bool device_policy_ : 1;
  bool layout_policy_ : 1;

  // Call _custom() virtual method for device()
  bool custom_device_ : 1;

  // Call _custom() virtual method for layout()
  bool custom_layout_ : 1;

  // Call into Python for
  // strides()/is_contiguous()/sizes()/dim()/numel()
  uint8_t python_custom_sizes_strides_ : 2;

  // Call into Python for device()
  bool python_custom_device_ : 1;

  // Call into Python for layout()
  bool python_custom_layout_ : 1;
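
  // A plausible sketch (an assumption, not necessarily the real helper) of
  // how such a "combined" policy could be derived from the bits above: take
  // the strongest of the C++-custom and Python-custom settings, and force
  // the custom-sizes tier whenever symbolic sizes/strides are present.
  //
  //   uint8_t combined_sizes_strides_policy_sketch() const {
  //     uint8_t p = std::max<uint8_t>(
  //         custom_sizes_strides_, python_custom_sizes_strides_);
  //     if (has_symbolic_sizes_strides_) {
  //       p = static_cast<uint8_t>(SizesStridesPolicy::CustomSizes);
  //     }
  //     return p;
  //   }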

  // The set of DispatchKeys which describe this tensor. NB: this
  // does NOT include Autograd (historically, it did, but
  // not anymore!)
  //
  // INVARIANT: extra_meta_->named_tensor_meta_ != nullptr <==>
  //            key_set_.has(DispatchKey::Named)
  DispatchKeySet key_set_;

 private:
  // C10_TensorImpl_Size_Check_Dummy_Class needs to be friends with
  // TensorImpl so it can inspect the size of private fields
  template <
      size_t cplusplus,
      size_t clang_ver_major,
      size_t gcc_ver,
      size_t gcc_ver_minor,
      size_t nvcc,
      size_t cuda_version,
      size_t cuda_version_major,
      size_t ptr_size>
  friend class C10_TensorImpl_Size_Check_Dummy_Class;
};

// Note [TensorImpl size constraints]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Changed the size of TensorImpl? If the size went down, good for
// you! Adjust the documentation below and the expected size.
// Did it go up? Read on...
//
// Struct size matters. In some production systems at Facebook, we have
// 400M live tensors during a training run. Do the math: every 64-bit
// word you add to Tensor is an extra 3.2 gigabytes in RAM.
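//
// (Worked out: 400,000,000 tensors x 8 bytes per extra 64-bit word
//  = 3,200,000,000 bytes = 3.2 GB.)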
//
// If you are a Facebook employee, you can check whether the run in question
// has tipped you over that point using the command here:
// https://fburl.com/q5enpv98
//
// For reference, we OOMed at 160 bytes (20 words) per TensorImpl.
// That figure does not count the overhead from the out-of-line strides
// allocation or StorageImpl space, and it is from before we inlined sizes
// and strides directly into TensorImpl as SmallVectors.
//
// Our memory usage on 32-bit systems is suboptimal, but we're not checking
// for it at the moment (to help avoid rage-inducing cycles when the
// 32-bit number is wrong).
//
// Current breakdown:
//
//    vtable pointer
//    strong refcount           TODO: pack these into one word
//    weak refcount
//    storage pointer
//    autograd metadata pointer
//    named tensor metadata pointer
//    version counter pointer
//    PyObjectSlot
//    SizesAndStrides size/pointer
//    SizesAndStrides sizes (pre-allocated 0)
//    SizesAndStrides sizes (pre-allocated 1)
//    SizesAndStrides sizes (pre-allocated 2)
//    SizesAndStrides sizes (pre-allocated 3)
//    SizesAndStrides sizes (pre-allocated 4)
//    SizesAndStrides strides (pre-allocated 0)
//    SizesAndStrides strides (pre-allocated 1)
//    SizesAndStrides strides (pre-allocated 2)
//    SizesAndStrides strides (pre-allocated 3)
//    SizesAndStrides strides (pre-allocated 4)
//    storage offset
//    numel
//    data type, device, is_contiguous, storage_access_should_throw_, bitfields
//    DispatchKeySet
//

// Various preprocessor macros we use to check that the
// TensorImpl size hasn't changed unexpectedly. We undef
// these later.
#ifndef __NVCC__
#define C10_NVCC 0
#else
#define C10_NVCC __NVCC__
#endif

#ifndef __CUDA_VER_MAJOR__
#define C10_CUDA_VERSION_MAJOR 0
#else
#define C10_CUDA_VERSION_MAJOR __CUDA_VER_MAJOR__
#endif

#ifndef CUDA_VERSION
#define C10_CUDA_VERSION 0
#else
#define C10_CUDA_VERSION CUDA_VERSION
#endif

#ifndef __clang_major__
#define C10_CLANG_MAJOR_VERSION 0
#else
#define C10_CLANG_MAJOR_VERSION __clang_major__
#endif

#ifndef __GNUC__
#define C10_GCC_VERSION 0
#else
#define C10_GCC_VERSION __GNUC__
#endif

#ifndef __GNUC_MINOR__
#define C10_GCC_VERSION_MINOR 0
#else
#define C10_GCC_VERSION_MINOR __GNUC_MINOR__
#endif

// We use a templatized class both to contain the logic of checking the sizes
// and to provide compile-time information that might be useful in figuring
// out why the sizes may have changed.
// All the compile-time information is given by the template fields, which are
// always printed by the compiler when the static_assert fails.
template <
    size_t cplusplus = __cplusplus,
    size_t clang_ver_major = C10_CLANG_MAJOR_VERSION,
    size_t gcc_ver = C10_GCC_VERSION,
    size_t gcc_ver_minor = C10_GCC_VERSION_MINOR,
    size_t nvcc = C10_NVCC,
    size_t cuda_version = C10_CUDA_VERSION,
    size_t cuda_version_major = C10_CUDA_VERSION_MAJOR,
    size_t ptr_size = sizeof(void*)>
class C10_TensorImpl_Size_Check_Dummy_Class : private TensorImpl {
  // Names of (non-bitfield) fields in TensorImpl; used to provide
  // compile-time info about fields whose size changes unexpectedly.
  enum class FieldNameEnum {
    storage_,
    autograd_meta_,
    extra_meta_,
    version_counter_,
    pyobj_slot_,
    sizes_and_strides_,
    storage_offset_,
    numel_,
    data_type_,
    device_opt_,
    key_set_,
    TOTAL_SIZE
  };

  // Provides a compile-time equality check that reveals what numbers
  // were used and on which quantity.
  template <size_t Actual, size_t Expected, FieldNameEnum FieldName>
  constexpr static bool are_equal() {
    static_assert(
        Actual == Expected,
        "Actual and Expected sizes of a field did not match!");
    return true;
  }

  // Provides a compile-time <= check that reveals what numbers
  // were used and on which quantity.
  template <size_t Actual, size_t Expected, FieldNameEnum FieldName>
  constexpr static bool is_le() {
    static_assert(
        Actual <= Expected,
        "Actual size of a field exceeded the Expected size!");
    return true;
  }
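
  // When one of these checks fires, the template arguments show up in the
  // compiler diagnostic, e.g. (hypothetical output, sizes made up):
  //
  //   error: static assertion failed: Actual and Expected sizes of a field
  //   did not match!
  //   ... in instantiation of 'are_equal<24, 16, FieldNameEnum::autograd_meta_>' ...
  //
  // which names both the offending field and its actual size.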

 public:
  // Compile-time check that TensorImpl field sizes are as expected.
  //
  // Observed total sizes and associated versions:
  // If you find a flag that predicts when unique_ptr has 16 bytes
  // on 64-bit systems, or when sizes_and_strides_ is 84 vs. 88 bytes
  // on 32-bit systems, you get a cookie!
  //  Length | LLVM | GCC  |  C++   | CUDA
  //   192   |  ?   | 11.2 | 201703 | 11040
  //   208   |  ?   | 11.2 | 201703 | 11040
  //   208   |  ?   | 11.2 | 201402 | 11040
  //   192   |  ?   | 11.2 | 201402 | 11040
  //   160   |  12  | 4.2  | 201703 | 0
  //
  // To keep things clean, we split on systems here.

#if UINTPTR_MAX == 0xFFFFFFFF
  // This is a 32-bit system
  static constexpr bool check_sizes() {
    constexpr size_t tsize = 20 * sizeof(int64_t);

    // clang-format off
    are_equal<sizeof(storage_),         4,  FieldNameEnum::storage_>();
    are_equal<sizeof(autograd_meta_),   4,  FieldNameEnum::autograd_meta_>();
    are_equal<sizeof(extra_meta_),      4,  FieldNameEnum::extra_meta_>();
    are_equal<sizeof(version_counter_), 4,  FieldNameEnum::version_counter_>();
    are_equal<sizeof(pyobj_slot_),      8,  FieldNameEnum::pyobj_slot_>();
    is_le<sizeof(sizes_and_strides_),   88, FieldNameEnum::sizes_and_strides_>();
    are_equal<sizeof(storage_offset_),  8,  FieldNameEnum::storage_offset_>();
    are_equal<sizeof(numel_),           8,  FieldNameEnum::numel_>();
    are_equal<sizeof(data_type_),       2,  FieldNameEnum::data_type_>();
    are_equal<sizeof(device_opt_),      3,  FieldNameEnum::device_opt_>();
    are_equal<sizeof(key_set_),         8,  FieldNameEnum::key_set_>();
    is_le<sizeof(TensorImpl),       tsize,  FieldNameEnum::TOTAL_SIZE>();
    // clang-format on

    return true;
  }
#else
  // This is a 64-bit system
  static constexpr bool check_sizes() {
    constexpr size_t tsize = 26 * sizeof(int64_t);

    // clang-format off
    are_equal<sizeof(storage_),         8,  FieldNameEnum::storage_>();
    // On some systems involving NVCC the size of unique_ptr is 16 bytes. We
    // haven't figured out how to detect those via preprocessor macros yet,
    // so we use <= comparisons for the relevant fields.
    is_le<sizeof(autograd_meta_),       16, FieldNameEnum::autograd_meta_>();
    is_le<sizeof(extra_meta_),          16, FieldNameEnum::extra_meta_>();
    are_equal<sizeof(version_counter_), 8,  FieldNameEnum::version_counter_>();
    are_equal<sizeof(pyobj_slot_),      16, FieldNameEnum::pyobj_slot_>();
    are_equal<sizeof(sizes_and_strides_), 88, FieldNameEnum::sizes_and_strides_>();
    are_equal<sizeof(storage_offset_),  8,  FieldNameEnum::storage_offset_>();
    are_equal<sizeof(numel_),           8,  FieldNameEnum::numel_>();
    are_equal<sizeof(data_type_),       2,  FieldNameEnum::data_type_>();
    are_equal<sizeof(device_opt_),      3,  FieldNameEnum::device_opt_>();
    are_equal<sizeof(key_set_),         8,  FieldNameEnum::key_set_>();
    is_le<sizeof(TensorImpl),       tsize,  FieldNameEnum::TOTAL_SIZE>();
    // clang-format on

    return true;
  }
#endif
};

// We use a class to encapsulate size-checking logic with
// templates to capture sizes and flags. We call this within
// a static_assert to prove there is no run-time behaviour.
// Since the methods we call either return true or fail their
// own static_asserts, we should never see the error message
// below. We still have to provide one for C++ < 17, though.
static_assert(
    C10_TensorImpl_Size_Check_Dummy_Class<>::check_sizes(),
    "You should not see this message.");

// Clean up after ourselves
#undef C10_NVCC
#undef C10_CUDA_VERSION_MAJOR
#undef C10_CUDA_VERSION
#undef C10_CLANG_MAJOR_VERSION
#undef C10_GCC_VERSION
#undef C10_GCC_VERSION_MINOR

} // namespace c10