1 #pragma once 2 3 #include <ATen/Tensor.h> 4 #include <c10/core/SymIntArrayRef.h> 5 #include <c10/core/TensorImpl.h> 6 7 #include <torch/csrc/lazy/core/tensor.h> 8 9 namespace torch { 10 namespace lazy { 11 12 // Tensor implementation class used to be fed to the at::Tensor. 13 // Its scope is just to handle an LazyTensor. 14 class TORCH_API LTCTensorImpl final : public c10::TensorImpl { 15 public: 16 explicit LTCTensorImpl(const LazyTensorPtr& tensor); 17 explicit LTCTensorImpl(const LazyTensor& tensor); 18 explicit LTCTensorImpl(LazyTensor&& tensor); 19 tensor()20 LazyTensorPtr tensor() { 21 return tensor_; 22 } 23 24 void set_tensor(const LazyTensorPtr& lazy_tensor); 25 force_refresh_sizes()26 void force_refresh_sizes() { 27 generation_ = 0; 28 } 29 30 c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( 31 const c10::VariableVersion& version_counter, 32 bool allow_tensor_metadata_change) const override; 33 34 c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach( 35 c10::VariableVersion&& version_counter, 36 bool allow_tensor_metadata_change) const override; 37 38 void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) override; 39 40 at::IntArrayRef sizes_custom() const override; 41 at::IntArrayRef strides_custom() const override; 42 int64_t numel_custom() const override; 43 int64_t storage_offset_custom() const override; 44 int64_t dim_custom() const override; 45 bool is_contiguous_custom(at::MemoryFormat memory_format) const override; 46 bool is_strides_like_custom(at::MemoryFormat memory_format) const override; 47 bool is_non_overlapping_and_dense_custom() const override; 48 49 c10::SymIntArrayRef sym_sizes_custom() const override; 50 c10::SymIntArrayRef sym_strides_custom() const override; 51 c10::SymInt sym_numel_custom() const override; 52 53 private: 54 void setup_size_properties(); 55 56 LazyTensorPtr tensor_; 57 mutable std::optional<std::vector<c10::SymInt>> sym_sizes_; 58 size_t generation_{0}; 59 }; 60 61 } // namespace lazy 62 } // namespace torch 63