xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/tensor_impl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/Tensor.h>
4 #include <c10/core/SymIntArrayRef.h>
5 #include <c10/core/TensorImpl.h>
6 
7 #include <torch/csrc/lazy/core/tensor.h>
8 
9 namespace torch {
10 namespace lazy {
11 
12 // Tensor implementation class used to be fed to the at::Tensor.
13 // Its scope is just to handle an LazyTensor.
14 class TORCH_API LTCTensorImpl final : public c10::TensorImpl {
15  public:
16   explicit LTCTensorImpl(const LazyTensorPtr& tensor);
17   explicit LTCTensorImpl(const LazyTensor& tensor);
18   explicit LTCTensorImpl(LazyTensor&& tensor);
19 
tensor()20   LazyTensorPtr tensor() {
21     return tensor_;
22   }
23 
24   void set_tensor(const LazyTensorPtr& lazy_tensor);
25 
force_refresh_sizes()26   void force_refresh_sizes() {
27     generation_ = 0;
28   }
29 
30   c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
31       const c10::VariableVersion& version_counter,
32       bool allow_tensor_metadata_change) const override;
33 
34   c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
35       c10::VariableVersion&& version_counter,
36       bool allow_tensor_metadata_change) const override;
37 
38   void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override;
39 
40   at::IntArrayRef sizes_custom() const override;
41   at::IntArrayRef strides_custom() const override;
42   int64_t numel_custom() const override;
43   int64_t storage_offset_custom() const override;
44   int64_t dim_custom() const override;
45   bool is_contiguous_custom(at::MemoryFormat memory_format) const override;
46   bool is_strides_like_custom(at::MemoryFormat memory_format) const override;
47   bool is_non_overlapping_and_dense_custom() const override;
48 
49   c10::SymIntArrayRef sym_sizes_custom() const override;
50   c10::SymIntArrayRef sym_strides_custom() const override;
51   c10::SymInt sym_numel_custom() const override;
52 
53  private:
54   void setup_size_properties();
55 
56   LazyTensorPtr tensor_;
57   mutable std::optional<std::vector<c10::SymInt>> sym_sizes_;
58   size_t generation_{0};
59 };
60 
61 } // namespace lazy
62 } // namespace torch
63