xref: /aosp_15_r20/external/pytorch/c10/core/TensorImpl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/core/TensorImpl.h>

#include <c10/core/Contiguity.h>
#include <c10/core/CopyBytes.h>
#include <c10/core/InferenceMode.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/core/impl/PyInterpreter.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/Logging.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <optional>

#include <utility>

C10_DEFINE_bool(
    caffe2_keep_on_shrink,
    true,
    "If set, keeps memory when a tensor is shrinking its size.");

C10_DEFINE_int64(
    caffe2_max_keep_on_shrink_memory,
    LLONG_MAX,
    "The maximum memory in bytes to keep on shrink; if the difference between "
    "tensor sizes is bigger than this, the tensor will be reset.");
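// Illustrative note: when the binary is built with gflags support, these knobs
// can be toggled on the command line of a hypothetical program, e.g.
//   ./caffe2_app --caffe2_keep_on_shrink=false \
//                --caffe2_max_keep_on_shrink_memory=1048576
// In code they are read back as FLAGS_caffe2_keep_on_shrink and
// FLAGS_caffe2_max_keep_on_shrink_memory (see TensorImpl::HandleResize below).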

namespace c10 {

const char* const TensorImpl::err_msg_tensor_metadata_change_not_allowed =
    "is not allowed on a Tensor created from .data or .detach().\n"
    "If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)\n"
    "without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.\n"
    "For example, change:\n"
    "    x.data.set_(y)\n"
    "to:\n"
    "    with torch.no_grad():\n"
    "        x.set_(y)";

at::Tensor& TensorImpl::mutable_grad() {
  if (!autograd_meta_)
    autograd_meta_ = impl::GetAutogradMetaFactory()->make();
  return autograd_meta_->mutable_grad();
}

const at::Tensor& TensorImpl::grad() const {
  // Yes, I know this looks really weird.  But I don't really have a choice as
  // long as this function returns a const reference to Tensor.  I'm not
  // really sure how I would have designed this API differently, but it
  // is not so easy to fix right now because the mutable counterpart of
  // this function must keep working so that "x.grad() = ..." keeps working
  // (part of public API).
  if (!autograd_meta_)
    return impl::GetAutogradMetaFactory()->undefined_tensor();
  return autograd_meta_->grad();
}
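
// Illustrative usage of the pair above (a sketch; `impl` stands for a
// hypothetical TensorImpl*):
//   at::Tensor& g = impl->mutable_grad();  // lazily allocates AutogradMeta,
//                                          // which is why "x.grad() = ..."
//                                          // style assignment keeps working
//   const at::Tensor& cg = impl->grad();   // never allocates; yields a shared
//                                          // undefined tensor when no
//                                          // AutogradMeta is present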

const at::Tensor& TensorImpl::_fw_grad(
    uint64_t level,
    const at::TensorBase& self) const {
  // See TensorImpl::grad() above for explanation about the line below
  if (!autograd_meta_)
    return impl::GetAutogradMetaFactory()->undefined_tensor();
  return autograd_meta_->fw_grad(level, self);
}

void TensorImpl::_set_fw_grad(
    const at::TensorBase& new_grad,
    const at::TensorBase& self,
    uint64_t level,
    bool is_inplace_op) {
  if (!autograd_meta_)
    autograd_meta_ = impl::GetAutogradMetaFactory()->make();
  autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op);
}

TensorImpl::~TensorImpl() = default;

TensorImpl::TensorImpl(
    Storage&& storage,
    DispatchKeySet key_set,
    const caffe2::TypeMeta data_type)
    // Use std::forward to suppress static analyzer false positive.
    : TensorImpl(
          std::forward<Storage>(storage),
          key_set,
          data_type,
          storage.device()) {}

// [Note: Python key removal]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// In most constructors for TensorImpl, you will see Python and
// PythonTLSSnapshot keys are removed from the passed in DispatchKeySet.  Why?
//
// INVARIANT: Python and PythonTLSSnapshot dispatch keys are set iff PyObject
// for the Tensor has a nontrivial __torch_dispatch__ implementation.
//
// When a fresh TensorImpl is created, there is *no* PyObject (this only gets
// initialized lazily at the first point in time the Tensor passes into Python).
// So we would violate the invariant.
//
// In practice, what will happen shortly afterwards is that the TensorImpl
// will get its PyObject initialized by Tensor._make_subclass; at this point
// the Python and PythonTLSSnapshot dispatch keys will be set and all is well.
// The point is to delay the dispatch key setting until that point.

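// For illustration (a sketch), the removal is plain DispatchKeySet arithmetic:
//   DispatchKeySet in = DispatchKeySet(DispatchKey::CPU) | c10::python_ks;
//   DispatchKeySet out = in - c10::python_ks;  // backend key survives; the
//                                              // Python/PythonTLSSnapshot
//                                              // functionality keys do not
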
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TensorImpl::TensorImpl(
    ImplType type,
    Storage&& storage,
    DispatchKeySet key_set,
    const caffe2::TypeMeta data_type)
    : storage_(std::move(storage)),

      numel_(0),
      data_type_(data_type),
      device_opt_(storage_.device()),
      key_set_(key_set - c10::python_ks) { // See [Note: Python key removal]
  init_bitfields();
  // Inference tensor doesn't have version counter.
  if (!is_inference()) {
    version_counter_ = VariableVersion(/*version=*/0);
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TensorImpl::TensorImpl(
    DispatchKeySet key_set,
    const caffe2::TypeMeta data_type,
    std::optional<c10::Device> device_opt)
    : TensorImpl({}, key_set, data_type, device_opt) {}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TensorImpl::TensorImpl(
    Storage&& storage,
    DispatchKeySet key_set,
    const caffe2::TypeMeta data_type,
    std::optional<c10::Device> device_opt)
    : storage_(std::move(storage)),

      numel_(0),
      data_type_(data_type),
      device_opt_(device_opt) {
  init_bitfields();

  if (!key_set.empty()) {
    TORCH_INTERNAL_ASSERT(
        data_type == ScalarType::Undefined || device_opt_.has_value());
    // UndefinedTensorImpl is a singleton, so we skip logging it
    C10_LOG_API_USAGE_ONCE("tensor.create");
  }

  // XXX: if updating keyset logic here also update
  // _change_backend_component_keys
  bool inference_mode = c10::InferenceMode::is_enabled();

  // TODO: be more explicit about the full key set at call sites so we
  // don't have to keep recomputing it here
  auto k = key_set.highestBackendKey();

  key_set = key_set | getAutocastRelatedKeySetFromBackend(k);

  // See [Note: Python key removal]
  key_set = key_set - c10::python_ks;

  // Inference tensor doesn't have autograd related keys.
  if (inference_mode) {
    // See Note [Expected TLS state in InferenceMode] for why we exclude
    // Autograd & ADInplaceOrView keys. Normally key_set only contains backend
    // keys but we do the subtraction here to make sure.
    key_set_ = key_set - c10::autograd_dispatch_keyset_with_ADInplaceOrView;
  } else {
    // TODO: Ideally we only add AutogradBackend key when the tensor requires
    // grad.
    //       See Note [Dream: skip VariableType kernel when requires_grad=false]
    key_set_ = key_set | getAutogradRelatedKeySetFromBackend(k);
  }

  // Inference tensor doesn't have version counter.
  if (!is_inference()) {
    version_counter_ = VariableVersion(/*version=*/0);
  }
  // we would also like to check that non-cpu devices have an index, but some
  // Caffe2 operators create Storages with default devices.
}
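
// Rough illustration of the composition above (a sketch): for a dense CPU
// tensor constructed outside InferenceMode, `k` is the CPU backend component,
// so key_set_ typically ends up holding the CPU backend key plus autocast- and
// autograd-related functionality keys (e.g. AutocastCPU, AutogradCPU and
// ADInplaceOrView); inside InferenceMode the autograd-related keys are
// subtracted out instead.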

void TensorImpl::_change_backend_component_keys(c10::Device device) {
  BackendComponent new_backend = toBackendComponent(device.type());
  BackendComponent old_backend = key_set_.highestBackendKey();

  // Following the logic in TensorImpl::TensorImpl, update the
  // BackendComponent-related keys to correspond to the new device.

  // TODO: Autocast should be a per-backend functionality key; once that change
  // is made this key swap will not be necessary.
  auto key_set =
      key_set_ - c10::getAutocastRelatedKeySetFromBackend(old_backend);
  key_set = key_set | c10::getAutocastRelatedKeySetFromBackend(new_backend);

  // See note [Removing keys from DispatchKeySet Only Affects Functionality
  // Keys]
  key_set = key_set.remove_backend(old_backend);
  key_set_ = key_set | DispatchKeySet(new_backend);
}

void TensorImpl::HandleResize() {
  // If needed, we will free the data. The next mutable_data() call
  // will create the data storage.
  bool reset_tensor = false;
  if (reserved_) {
    // If the tensor is reserved, don't free its memory unless nbytes()
    // is smaller than the new size.
    reset_tensor =
        storage_.nbytes() < (storage_offset_ + numel_) * data_type_.itemsize();
  } else {
    reset_tensor = storage_.nbytes() <
            (storage_offset_ + numel_) * data_type_.itemsize() ||
        !FLAGS_caffe2_keep_on_shrink ||
        storage_.nbytes() - (storage_offset_ + numel_) * data_type_.itemsize() >
            static_cast<size_t>(FLAGS_caffe2_max_keep_on_shrink_memory);
  }

  if (reset_tensor && storage_initialized()) {
    FreeMemory();
  }
}
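
// Worked example for HandleResize() above (illustrative numbers): for a
// non-reserved tensor with storage_.nbytes() == 4096 that is resized so that
// (storage_offset_ + numel_) * itemsize == 1024, the storage is large enough,
// so the memory is kept only if FLAGS_caffe2_keep_on_shrink is true and the
// slack of 4096 - 1024 = 3072 bytes does not exceed
// FLAGS_caffe2_max_keep_on_shrink_memory; otherwise FreeMemory() runs and the
// next mutable_data() call reallocates.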
227 
compute_contiguous(identity<bool>) const228 bool TensorImpl::compute_contiguous(identity<bool>) const {
229   if (is_sparse()) {
230     return false;
231   }
232   return _compute_contiguous<int64_t>(
233       sizes_and_strides_.sizes_arrayref(),
234       sizes_and_strides_.strides_arrayref(),
235       numel_);
236 }
237 
compute_channels_last_contiguous_2d(identity<bool>) const238 bool TensorImpl::compute_channels_last_contiguous_2d(identity<bool>) const {
239   if (is_sparse()) {
240     return false;
241   }
242   return _compute_channels_last_contiguous_2d<int64_t>(
243       sizes_and_strides_.sizes_arrayref(),
244       sizes_and_strides_.strides_arrayref());
245 }
246 
compute_channels_last_contiguous_3d(identity<bool>) const247 bool TensorImpl::compute_channels_last_contiguous_3d(identity<bool>) const {
248   if (is_sparse()) {
249     return false;
250   }
251   return _compute_channels_last_contiguous_3d<int64_t>(
252       sizes_and_strides_.sizes_arrayref(),
253       sizes_and_strides_.strides_arrayref());
254 }
255 
compute_strides_like_channels_last_2d(identity<bool>) const256 bool TensorImpl::compute_strides_like_channels_last_2d(identity<bool>) const {
257   if (is_sparse()) {
258     return false;
259   }
260   return is_channels_last_strides_2d<int64_t>(
261       sizes_and_strides_.sizes_arrayref(),
262       sizes_and_strides_.strides_arrayref());
263 }
264 
compute_strides_like_channels_last_3d(identity<bool>) const265 bool TensorImpl::compute_strides_like_channels_last_3d(identity<bool>) const {
266   if (is_sparse()) {
267     return false;
268   }
269   return is_channels_last_strides_3d<int64_t>(
270       sizes_and_strides_.sizes_arrayref(),
271       sizes_and_strides_.strides_arrayref());
272 }
273 
compute_non_overlapping_and_dense(identity<bool>) const274 bool TensorImpl::compute_non_overlapping_and_dense(identity<bool>) const {
275   if (is_sparse()) {
276     return false;
277   }
278   return _compute_non_overlapping_and_dense<int64_t>(
279       sizes_and_strides_.sizes_arrayref(),
280       sizes_and_strides_.strides_arrayref());
281 }
282 
release_resources()283 void TensorImpl::release_resources() {
284   autograd_meta_.reset();
285   if (storage_) {
286     storage_ = {};
287   }
288   pyobj_slot_.maybe_destroy_pyobj();
289 }

#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
bool TensorImpl::has_storage() const {
  return storage_;
}
#endif

void TensorImpl::throw_cannot_call_with_symbolic(const char* meth) const {
  TORCH_CHECK_ALWAYS_SHOW_CPP_STACKTRACE(
      false, "Cannot call ", meth, "() on tensor with symbolic sizes/strides");
}

void TensorImpl::throw_storage_access_error() const {
  if (extra_meta_ && extra_meta_->custom_storage_error_msg_) {
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    TORCH_CHECK(false, *extra_meta_->custom_storage_error_msg_);
  }
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "Cannot access storage of ", tensorimpl_type_name());
}

void TensorImpl::throw_data_ptr_access_error() const {
  if (extra_meta_ && extra_meta_->custom_data_ptr_error_msg_) {
    // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
    TORCH_CHECK(false, *extra_meta_->custom_data_ptr_error_msg_);
  }
  TORCH_CHECK(
      false, "Cannot access data pointer of Tensor that doesn't have storage");
}

bool TensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_contiguous(
        this, memory_format);
  }
  return is_contiguous_default(memory_format);
}

bool TensorImpl::is_strides_like_custom(at::MemoryFormat memory_format) const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_strides_like(
        this, memory_format);
  }
  return is_strides_like_default(memory_format);
}

bool TensorImpl::is_non_overlapping_and_dense_custom() const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->is_non_overlapping_and_dense(
        this);
  }
  return is_non_overlapping_and_dense_default();
}

IntArrayRef TensorImpl::sizes_custom() const {
  if (C10_UNLIKELY(
          matches_python_custom(SizesStridesPolicy::CustomSizes) ||
          has_symbolic_sizes_strides_)) {
    return pyobj_slot_.load_pyobj_interpreter()->sizes(this);
  }
  return sizes_default();
}

c10::SymIntArrayRef TensorImpl::sym_sizes_custom() const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) {
    return pyobj_slot_.load_pyobj_interpreter()->sym_sizes(this);
  }
  return sym_sizes_default();
}

c10::SymInt TensorImpl::sym_numel_custom() const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) {
    return pyobj_slot_.load_pyobj_interpreter()->sym_numel(this);
  }
  return sym_numel_default();
}

c10::SymIntArrayRef TensorImpl::sym_strides_custom() const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomStrides))) {
    return pyobj_slot_.load_pyobj_interpreter()->sym_strides(this);
  }
  return sym_strides_default();
}

c10::Device TensorImpl::device_custom() const {
  if (C10_UNLIKELY(python_custom_device_)) {
    return pyobj_slot_.load_pyobj_interpreter()->device(this);
  }
  return device_default();
}

IntArrayRef TensorImpl::strides_custom() const {
  if (C10_UNLIKELY(
          matches_python_custom(SizesStridesPolicy::CustomStrides) ||
          has_symbolic_sizes_strides_)) {
    return pyobj_slot_.load_pyobj_interpreter()->strides(this);
  }
  return strides_default();
}

int64_t TensorImpl::dim_custom() const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) {
    return pyobj_slot_.load_pyobj_interpreter()->dim(this);
  }
  return dim_default();
}

int64_t TensorImpl::numel_custom() const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) {
    return pyobj_slot_.load_pyobj_interpreter()->numel(this);
  }
  return numel_default();
}

c10::Layout TensorImpl::layout_custom() const {
  if (C10_UNLIKELY(python_custom_layout_)) {
    return pyobj_slot_.load_pyobj_interpreter()->layout(this);
  }
  // TODO: fix this
  TORCH_CHECK(
      0, "Tensors of type ", tensorimpl_type_name(), " do not have layout")
  // return layout_default();
}

int64_t TensorImpl::storage_offset_custom() const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) {
    // TODO: fix this
    return pyobj_slot_.load_pyobj_interpreter()
        ->sym_storage_offset(this)
        .guard_int(__FILE__, __LINE__);
  }
  return storage_offset_default();
}

c10::SymInt TensorImpl::sym_storage_offset_custom() const {
  if (C10_UNLIKELY(matches_python_custom(SizesStridesPolicy::CustomSizes))) {
    return pyobj_slot_.load_pyobj_interpreter()->sym_storage_offset(this);
  }
  return sym_storage_offset_default();
}

static void deletePlacementDeleteContext(void* ptr) {
  delete static_cast<PlacementDeleteContext*>(ptr);
}

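// Builds a DataPtr whose context is a heap-allocated PlacementDeleteContext
// and whose deleter is deletePlacementDeleteContext above; when the returned
// DataPtr is destroyed, the context is deleted, which is expected to run
// `placement_dtor` over `size` elements (see PlacementDeleteContext's
// destructor) before the wrapped allocation is released. Used for non-POD
// legacy Caffe2 types whose destructors must run on free.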
at::DataPtr PlacementDeleteContext::makeDataPtr(
    at::DataPtr&& data_ptr,
    PlacementDtor placement_dtor,
    size_t size,
    at::Device device) {
  auto* ptr = data_ptr.get();
  return {
      ptr,
      new PlacementDeleteContext(std::move(data_ptr), placement_dtor, size),
      &deletePlacementDeleteContext,
      device};
}

AutogradMetaInterface::~AutogradMetaInterface() = default;

// Setting requires_grad to true on an inference tensor outside InferenceMode
// is forbidden.  Ideally it would also be illegal inside InferenceMode, but
// there is no way to directly allocate a tensor with requires_grad = true
// from a C++ constructor, so set_requires_grad is widely used in the C++
// frontend. Forbidding it inside InferenceMode would force users to delete
// this setter code, which is not ideal.
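//
// Illustrative sketch of that constraint (make_inference_tensor() is a
// hypothetical factory):
//   at::Tensor t;
//   {
//     c10::InferenceMode guard;
//     t = make_inference_tensor();                       // t.is_inference() == true
//     t.unsafeGetTensorImpl()->set_requires_grad(true);  // OK inside InferenceMode
//   }
//   t.unsafeGetTensorImpl()->set_requires_grad(true);    // throws outside InferenceMode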
void TensorImpl::set_requires_grad(bool requires_grad) {
  TORCH_CHECK(
      !(requires_grad && is_inference() && !c10::InferenceMode::is_enabled()),
      "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
  if (!requires_grad && !autograd_meta_)
    return;
  if (!autograd_meta_)
    autograd_meta_ = impl::GetAutogradMetaFactory()->make();
  // NB: In principle, setting requires_grad to false could result in
  // the AutogradMeta becoming equal to a default constructed state,
  // in which case we could apply the nullptr AutogradMeta optimization
  // (see autograd_meta_ docs).  But we don't do this right now.  Note
  // that it is unsound to unconditionally reset AutogradMeta to nullptr
  // when you set requires_grad to False, as there may be nontrivial
  // information content in the other fields; for example, we may
  // have set the string name for a Variable, or there may be hooks
  // registered for it.
  autograd_meta_->set_requires_grad(requires_grad, this);
}

bool TensorImpl::requires_grad() const {
  if (!autograd_meta_)
    return false;
  return autograd_meta_->requires_grad();
}

void TensorImpl::set_autograd_meta(
    std::unique_ptr<c10::AutogradMetaInterface> autograd_meta) {
  // NB: autograd_meta may be null!  That just means it's the
  // default-constructed state.
  autograd_meta_ = std::move(autograd_meta);
}

c10::AutogradMetaInterface* TensorImpl::autograd_meta() const {
  // NB: Might return null!
  return autograd_meta_.get();
}

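// Shared implementation behind both shallow_copy_and_detach overloads below.
// Three cases are visible in the body: when a TorchDispatchMode is active (and
// the Python key is not excluded from TLS), detach is delegated to that mode's
// Python interpreter; otherwise, when the tensor itself carries the Python
// key, detach is delegated to the tensor's own pyobj interpreter; otherwise
// the TensorImpl metadata is copied directly via copy_tensor_metadata without
// touching any PyObject.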
template <typename VariableVersion>
c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach_core(
    VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  c10::intrusive_ptr<TensorImpl> r;
  const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
  // TODO: do we have to exclude after Python dispatch key set?
  if (mode_stack_len > 0 &&
      !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
    const auto& cur_torch_dispatch_mode_state =
        c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
    r = cur_torch_dispatch_mode_state->pyinterpreter()->detach(this);
  } else if (
      key_set_.has(DispatchKey::Python) &&
      !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
    r = (pyobj_slot_.load_pyobj_interpreter())->detach(this);
  }
  if (r) {
    if (!r->is_inference()) {
      r->set_version_counter(std::forward<VariableVersion>(version_counter));
    }
    r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
    return r;
  }
  // otherwise just copy the TensorImpl and not the PyObject.  Since
  // the interpreter is dead no one can call us out on it
  auto impl = c10::make_intrusive<TensorImpl>(
      // No need to populate Storage; copy_tensor_metadata will do it for us.
      key_set_,
      data_type_,
      device_opt_);
  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/std::forward<VariableVersion>(version_counter),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  return impl;
}

c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  return shallow_copy_and_detach_core(
      version_counter, allow_tensor_metadata_change);
}

c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  return shallow_copy_and_detach_core(
      std::move(version_counter), allow_tensor_metadata_change);
}

// This function copies all of the metadata from the src tensor except for:
// - key_set_
// - storage_
// - storage_access_should_throw_
// - sizes_strides_policy_
// - version_counter_
// - allow_tensor_metadata_change_
// The idea is that if we have a "wrapper tensor" (like in functionalization),
// all of the above are properties that the wrapper will want to customize,
// while everything else should be mirrored between the wrapper and the inner
// tensor.
void TensorImpl::copy_generic_tensor_metadata(
    const TensorImpl* src_impl,
    TensorImpl* dest_impl) {
  dest_impl->sizes_and_strides_ = src_impl->sizes_and_strides_;
  dest_impl->has_symbolic_sizes_strides_ =
      src_impl->has_symbolic_sizes_strides_;

  dest_impl->storage_offset_ = src_impl->storage_offset_;
  dest_impl->data_type_ = src_impl->data_type_;
  dest_impl->device_opt_ = src_impl->device_opt_;
  dest_impl->is_contiguous_ = src_impl->is_contiguous_;
  dest_impl->is_channels_last_contiguous_ =
      src_impl->is_channels_last_contiguous_;
  dest_impl->is_channels_last_3d_contiguous_ =
      src_impl->is_channels_last_3d_contiguous_;
  dest_impl->is_channels_last_ = src_impl->is_channels_last_;
  dest_impl->is_channels_last_3d_ = src_impl->is_channels_last_3d_;
  dest_impl->is_non_overlapping_and_dense_ =
      src_impl->is_non_overlapping_and_dense_;
  dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_;
  dest_impl->reserved_ = src_impl->reserved_;
  dest_impl->numel_ = src_impl->numel_;
  if (src_impl->extra_meta_ != nullptr) {
    dest_impl->extra_meta_ = src_impl->extra_meta_->clone();
  } else if (dest_impl->extra_meta_ != nullptr) {
    // Clear dest_impl's extra meta data: with shallow_copy_from, the dest impl
    // is a real tensor impl that may already carry extra meta data, which
    // would contaminate the newly copied metadata.
    dest_impl->extra_meta_.reset(nullptr);
  }

  // NB: symbolic sizes and strides are copied, as is the custom sizes/strides
  // policy, but the python policy is NOT (there is no Python object to
  // dispatch to!)
  // NB: the subclass-relevant policy doesn't have to be copied; the
  // constructor sets this up

  dest_impl->refresh_sizes_strides_policy();
  dest_impl->refresh_layout_policy();
  dest_impl->refresh_device_policy();
}

void TensorImpl::copy_tensor_metadata_except_version_counter(
    const TensorImpl* src_impl,
    TensorImpl* dest_impl,
    bool allow_tensor_metadata_change) {
  // First call the generic copy function
  copy_generic_tensor_metadata(src_impl, dest_impl);
  // Then copy everything else (see the comment at copy_generic_tensor_metadata
  // for the list of metadata that it does not directly copy).
  dest_impl->storage_ = src_impl->storage_;
  // Copying tensor metadata doesn't change the PyObject (maybe
  // it should), which means that we have to preserve whatever the
  // original Python keyset was (as it's associated with the PyObject
  // being a tensor subclass or not)
  dest_impl->key_set_ = (src_impl->key_set_ - c10::python_ks) |
      (dest_impl->key_set_ & c10::python_ks);
  dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
  dest_impl->storage_access_should_throw_ =
      src_impl->storage_access_should_throw_;
}

void TensorImpl::copy_tensor_metadata(
    const TensorImpl* src_impl,
    TensorImpl* dest_impl,
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) {
  copy_tensor_metadata_except_version_counter(
      src_impl, dest_impl, allow_tensor_metadata_change);
  // TODO: In the ideal end state, it's okay to set disabled version_counter
  // on inference tensor since it's a no-op. This requires refactor on call
  // sites.
  if (!dest_impl->is_inference()) {
    dest_impl->set_version_counter(version_counter);
  }
}

void TensorImpl::copy_tensor_metadata(
    const TensorImpl* src_impl,
    TensorImpl* dest_impl,
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) {
  copy_tensor_metadata_except_version_counter(
      src_impl, dest_impl, allow_tensor_metadata_change);
  if (!dest_impl->is_inference()) {
    dest_impl->set_version_counter(std::move(version_counter));
  }
}

// Legacy Caffe2 operations

void TensorImpl::Extend(int64_t num, float growthPct) {
  TORCH_CHECK(sizes_and_strides_.size() >= 1u);
  TORCH_CHECK(num >= 0, "`num` must be non-negative for Extend");
  TORCH_CHECK(
      is_contiguous_,
      "Right now Extend is only supported for contiguous Tensor.");
  TORCH_CHECK(
      !has_symbolic_sizes_strides_,
      "Extend() called on tensor with symbolic shape")

  using SizesVector = SmallVector<int64_t, 5>;
  IntArrayRef sizes_and_strides = sizes_and_strides_.sizes_arrayref();
  SizesVector newDims(sizes_and_strides.begin(), sizes_and_strides.end());
  newDims[0] += num;
  if (!storage_.data()) {
    Resize(newDims);
    return;
  }
  const auto newNumel = c10::multiply_integers(newDims.begin(), newDims.end());
  if (newNumel * data_type_.itemsize() <= storage_.nbytes()) {
    sizes_and_strides_.set_sizes(newDims);
    numel_ = newNumel;
    return;
  }
  SizesVector newCapacity(sizes_and_strides.begin(), sizes_and_strides.end());
  newCapacity[0] = std::max(
      newDims[0],
      static_cast<int64_t>(std::ceil(
          static_cast<float>(sizes_and_strides_.size_at_unchecked(0)) *
          (1 + growthPct / 100))));
  auto oldData = std::move(storage_.mutable_data_ptr());
  auto oldSize = numel_;
  Resize(std::move(newCapacity));
  auto* newData = raw_mutable_data(data_type_);
  if (data_type_.copy()) {
    TORCH_CHECK(
        device_type() == DeviceType::CPU, "non-POD types work only on CPU");
    data_type_.copy()(oldData.get(), newData, oldSize);
  } else {
    // The following copy uses the current (thread local) stream for copying
    // and also takes the GPU id from the device() field passed in.
    //
    // TODO: Potentially more enforcements are necessary to avoid accidental
    // switch to sync copy if the currently set device is wrong.
    //
    // Specifically, we might need to switch to a different context device
    // here explicitly to avoid relying on user synchronizing things
    // properly.
    CopyBytes(
        oldSize * itemsize(),
        oldData.get(),
        device(),
        newData,
        device(),
        true); // non-blocking
  }
  reserved_ = true;
  sizes_and_strides_.set_sizes(newDims);
  numel_ = newNumel;
}
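
// Worked example for Extend() above (illustrative numbers): extending a
// contiguous [10, 3] tensor by num = 4 with growthPct = 40 gives
// newDims = [14, 3]; if the existing storage is too small, the capacity grows
// to max(14, ceil(10 * 1.4)) = 14 rows, the old bytes are copied into the new
// allocation, and reserved_ is set so that a later shrink keeps the memory
// (see HandleResize above).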

void TensorImpl::ReserveSpace(int64_t outer_dim) {
  TORCH_CHECK(
      is_contiguous_,
      "Right now ReserveSpace is only supported for contiguous Tensor.");
  TORCH_CHECK(
      !has_symbolic_sizes_strides_,
      "ReserveSpace() called on tensor with symbolic shape")

  TORCH_CHECK(storage_.unique(), "Can't call ReserveSpace on shared storage.");
  // TODO: eliminate newCapacity.
  IntArrayRef sizes_and_strides = sizes_and_strides_.sizes_arrayref();
  SmallVector<int64_t, 5> newCapacity(
      sizes_and_strides.begin(), sizes_and_strides.end());
  newCapacity[0] = outer_dim;
  auto newNumel = c10::multiply_integers(newCapacity);
  if (newNumel * data_type_.itemsize() <= storage_.nbytes()) {
    return;
  }
  // Old data is discarded
  storage_.mutable_data_ptr().clear();
  auto oldSize = numel_;
  SmallVector<int64_t, 5> oldDims(
      sizes_and_strides.begin(), sizes_and_strides.end());
  Resize(std::move(newCapacity));
  // Allocate new memory but don't copy over the data
  raw_mutable_data(data_type_);
  sizes_and_strides_.set_sizes(oldDims);
  numel_ = oldSize;
  reserved_ = true;
}

void TensorImpl::Reshape(const std::vector<int64_t>& dims) {
  TORCH_CHECK(
      is_contiguous_,
      "Right now Reshape is only supported for contiguous Tensor.");
  TORCH_CHECK(
      !has_symbolic_sizes_strides_,
      "Reshape() called on tensor with symbolic shape")

  int64_t new_size = 1;
  for (auto d : dims) {
    TORCH_CHECK(d >= 0);
    new_size *= d;
  }
  TORCH_CHECK(
      new_size == numel_,
      "New size and old size are not equal. You cannot use Reshape, "
      "but should use Resize."
      // TODO(jiayq): remove the following warning after pending diffs
      // stabilize.
      " The old caffe2 mixes Reshape and Resize but this behavior has "
      "been changed. If you find this error, most likely you will need "
      "to change corresponding code from Reshape to Resize.");
  sizes_and_strides_.set_sizes(dims);
  empty_tensor_restride(MemoryFormat::Contiguous);
}

void TensorImpl::FreeMemory() {
  // We'll detach from the old Storage and create a new one
  if (storage_.use_count() != 1 || !storage_.resizable() ||
      !storage_.allocator()) {
    storage_ = Storage::create_legacy(storage_.device());
  } else {
    storage_.reset_legacy();
  }
  storage_offset_ = 0;
}

void TensorImpl::ShareData(const TensorImpl& src) {
  // Right now, we are assuming the device_types are the same, since it is
  // inherently the same in the non-templatized code. We should probably add
  // an assert here, which might affect perf a little bit.
  TORCH_CHECK(
      src.numel_ == numel_,
      "Size mismatch - did you call reshape before sharing the data?");
  // It is possible that the source tensor hasn't called mutable_data() yet,
  // in which case ShareData() doesn't make much sense since we don't really
  // know what to share yet.
  // TODO: Add the assert after all uninitialized states are eliminated
  // TORCH_CHECK(src.dtype_initialized(),
  //            "Source tensor doesn't have a data type (did you call
  //            mutable_data<T> on the tensor?)");
  if (!src.dtype_initialized()) {
    C10_LOG_EVERY_MS(WARNING, 1000)
        << "Source tensor doesn't have a data type (did you call mutable_data<T> on the tensor?)";
  }
  TORCH_CHECK(
      src.storage_initialized(),
      "Source tensor has no content and has size > 0");
  // Finally, do sharing.
  /* Since we create new Storage whenever we need to change data_type/nbytes
   * this still keeps the original semantics
   */
  storage_ = src.storage();
  data_type_ = src.dtype();
  device_opt_ = src.device_opt();
  storage_offset_ = src.storage_offset();
}

void TensorImpl::ShareExternalPointer(
    DataPtr&& data_ptr,
    const caffe2::TypeMeta data_type,
    size_t size_bytes) {
  TORCH_CHECK(
      data_type != ScalarType::Undefined,
      "To share with a raw external pointer you need to pass in an "
      "initialized data_type(TypeMeta).");
  TORCH_CHECK(
      !has_symbolic_sizes_strides_,
      "ShareExternalPointer() called on tensor with symbolic shape");
  if (!size_bytes) {
    size_bytes = numel_ * data_type.itemsize();
  }
  if (storage_.unique()) {
    storage_.UniqueStorageShareExternalPointer(std::move(data_ptr), size_bytes);
    data_type_ = data_type;
    device_opt_ = storage_.device();
    storage_offset_ = 0;
  } else {
    // Create a new Storage
    storage_ = Storage(
        Storage::use_byte_size_t(),
        size_bytes,
        std::move(data_ptr),
        /*allocator=*/nullptr,
        /*resizable=*/false);
    data_type_ = data_type;
    device_opt_ = storage_.device();
    storage_offset_ = 0;
  }
}

static void clone_symvec(SymIntArrayRef src, SymDimVector& dst) {
  dst.clear();
  dst.reserve(src.size());
  for (const auto& i : src) {
    dst.emplace_back(i.clone());
  }
}

// NB: this doesn't check that the sizes/strides/offset are in bound for the
// storage, and furthermore, it CANNOT do so as in some cases we temporarily
// violate invariants by first setting sizes/strides, and then updating the
// storage
void TensorImpl::set_sizes_and_strides(
    c10::SymIntArrayRef sizes,
    c10::SymIntArrayRef strides,
    std::optional<c10::SymInt> storage_offset) {
  auto int_sizes = asIntArrayRefSlowOpt(sizes);
  auto int_strides = asIntArrayRefSlowOpt(strides);
  if (int_sizes && int_strides &&
      // NB: storage_offset guaranteed to be positive
      (!storage_offset.has_value() || !storage_offset->is_heap_allocated()) &&
      !has_symbolic_sizes_strides_) {
    set_sizes_and_strides(*int_sizes, *int_strides);
    if (storage_offset.has_value())
      set_storage_offset(storage_offset->as_int_unchecked());
    return;
  }
  TORCH_CHECK(
      allow_tensor_metadata_change(),
      "set_sizes_and_strides ",
      err_msg_tensor_metadata_change_not_allowed);

  has_symbolic_sizes_strides_ = true;
  refresh_sizes_strides_policy();
  if (!extra_meta_) {
    extra_meta_ = std::make_unique<ExtraMeta>();
    extra_meta_->symbolic_shape_meta_ =
        std::make_unique<c10::SymbolicShapeMeta>();
    extra_meta_->symbolic_shape_meta_->strides_valid_ = !is_sparse();
    if (!storage_offset.has_value()) {
      extra_meta_->symbolic_shape_meta_->storage_offset_ = storage_offset_;
    }
  }

  auto& sym_shape_meta{symbolic_shape_meta()};
  clone_symvec(sizes, sym_shape_meta.sizes_);
  clone_symvec(strides, sym_shape_meta.strides_);
  if (storage_offset.has_value())
    sym_shape_meta.storage_offset_ = storage_offset->clone();

  refresh_numel();
  refresh_contiguous();
}

void TensorImpl::generic_set_sizes_contiguous(SymIntArrayRef sizes) {
  auto int_sizes = asIntArrayRefSlowOpt(sizes);
  if (int_sizes.has_value()) {
    set_sizes_contiguous(*int_sizes);
    return;
  }

  TORCH_CHECK(
      allow_tensor_metadata_change(),
      "generic_set_sizes_contiguous ",
      err_msg_tensor_metadata_change_not_allowed);

  has_symbolic_sizes_strides_ = true;
  refresh_sizes_strides_policy();
  auto& extra_meta{get_extra_meta()};
  if (extra_meta.symbolic_shape_meta_ == nullptr) {
    extra_meta_->symbolic_shape_meta_ =
        std::make_unique<c10::SymbolicShapeMeta>();
    extra_meta_->symbolic_shape_meta_->strides_valid_ = !is_sparse();
  }

  clone_symvec(sizes, symbolic_shape_meta().sizes_);
  refresh_numel();
  empty_tensor_restride_symint(
      MemoryFormat::Contiguous); // calls refresh_contiguous()
}

void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) {
  TORCH_INTERNAL_ASSERT(has_symbolic_sizes_strides_);
  auto& sym_shape_meta{symbolic_shape_meta()};
  switch (memory_format) {
    case MemoryFormat::Contiguous: {
      // TODO: figure out if the non-symint version can also devirtualize;
      // the last time we tried it was probably a narrowing problem
      const auto dim_ = sym_shape_meta.dim();
      sym_shape_meta.strides_.resize(dim_);
      if (dim_ > 0) {
        const auto last_idx = dim_ - 1;
        sym_shape_meta.strides_[last_idx] = c10::SymInt(1);
        for (auto i = last_idx - 1; i >= 0; --i) {
          sym_shape_meta.strides_[i] = sym_shape_meta.strides_[i + 1] *
              sym_shape_meta.sizes_[i + 1].max(1);
        }
      }
      break;
    }
    case MemoryFormat::ChannelsLast: {
      TORCH_CHECK(
          dim() == 4, "required rank 4 tensor to use channels_last format");
      clone_symvec(
          get_channels_last_strides_2d(sym_sizes()), sym_shape_meta.strides_);
      break;
    }
    case MemoryFormat::ChannelsLast3d: {
      TORCH_CHECK(
          dim() == 5, "required rank 5 tensor to use channels_last_3d format");
      clone_symvec(
          get_channels_last_strides_3d(sym_sizes()), sym_shape_meta.strides_);
      break;
    }
    case MemoryFormat::Preserve:
      TORCH_CHECK(false, "unsupported memory format ", memory_format);
      // Cleaning warning messages, no need to break as TORCH_CHECK(false)
      // terminates flow.
      // break;
    case MemoryFormat::NumOptions:
      TORCH_INTERNAL_ASSERT(false, "invalid memory format ", memory_format);
  }
  // recompute contiguous flag, as currently NHWC/NCHW flags are not mutually
  // exclusive see #24090
  refresh_contiguous();
  // hard code some known true settings, for unbacked case
  // TODO: avoid chundering into the guards for computing these
  switch (memory_format) {
    case MemoryFormat::Contiguous: {
      sym_shape_meta.assume_contiguous();
      sym_shape_meta.assume_non_overlapping_and_dense();
      break;
    }
    case MemoryFormat::ChannelsLast: {
      sym_shape_meta.assume_channels_last_contiguous();
      sym_shape_meta.assume_channels_last();
      sym_shape_meta.assume_non_overlapping_and_dense();
      break;
    }
    case MemoryFormat::ChannelsLast3d: {
      sym_shape_meta.assume_channels_last_3d_contiguous();
      sym_shape_meta.assume_channels_last_3d();
      sym_shape_meta.assume_non_overlapping_and_dense();
      break;
    }
    default:
      break;
  }
}
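
// Worked example for the Contiguous case above (illustrative): sizes [2, 3, 4]
// produce strides [12, 4, 1]; the .max(1) guard keeps the computation
// well-formed for zero-sized (or unbacked) dims, e.g. sizes [2, 0, 4] produce
// strides [4, 4, 1].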

namespace impl {

namespace {
AutogradMetaFactory* meta_factory = nullptr;
} // namespace

void SetAutogradMetaFactory(AutogradMetaFactory* factory) {
  meta_factory = factory;
}
AutogradMetaFactory* GetAutogradMetaFactory() {
  TORCH_CHECK(
      meta_factory,
      "Support for autograd has not been loaded; have you linked against libtorch.so?")
  return meta_factory;
}

} // namespace impl

} // namespace c10