#pragma once

#include <c10/core/SymNodeImpl.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/util.h>

namespace torch {
namespace lazy {

class TORCH_API SymNodeImpl : public c10::SymNodeImpl {
 public:
  SymNodeImpl(NodePtr ptr) : node_(std::move(ptr)) {}
  NodePtr node_;
};

class LazyTensor;
using LazyTensorPtr = c10::intrusive_ptr<LazyTensor>;

class TORCH_API LazyTensor : public c10::intrusive_ptr_target {
 public:
  // This is the core lazy tensor data structure where all the tensor data is
  // held. The lazy tensor is nothing more than a shared pointer to a Data
  // object.
  struct Data {
    Data(BackendDataPtr handle, BackendDevice device)
        : handle(std::move(handle)),
          device(std::move(device)),
          unique_id(GetNextTensorId()) {}
    Data(Value ir_value, BackendDevice device)
        : ir_value(std::move(ir_value)),
          device(std::move(device)),
          unique_id(GetNextTensorId()) {}
    Data(at::Tensor tensor_data, BackendDevice device)
        : tensor_data(std::move(tensor_data)),
          device(std::move(device)),
          unique_id(GetNextTensorId()) {}
    // TODO(alanwaketan): Remove this ctor. This is a
    // temporary ctor to ease XLA LTC migration. It depends on
    // XLA's Functionalization integration.
    Data(BackendDevice device)
        : device(std::move(device)), unique_id(GetNextTensorId()) {}

    virtual ~Data();

    BackendDataPtr handle;
    Value ir_value;
    std::optional<at::Tensor> tensor_data;
    const BackendDevice device;
    const int64_t unique_id = 0;
    size_t generation = 1;
  };

  static LazyTensorPtr Create(
      const at::Tensor& tensor,
      const BackendDevice& device);
  static LazyTensorPtr Create(Value ir_value, const BackendDevice& device);
  static LazyTensorPtr Create(BackendDataPtr handle);
  static LazyTensorPtr Create(std::shared_ptr<Data> data);
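
  // Example (illustrative sketch only, not part of the API): constructing a
  // lazy tensor from an eager tensor on a backend device `dev` obtained
  // elsewhere (e.g. from another lazy tensor's GetDevice()), then forcing a
  // round-trip back to an eager tensor. `dev` is a hypothetical variable.
  //
  //   at::Tensor cpu = at::ones({2, 2});
  //   LazyTensorPtr lt = LazyTensor::Create(cpu, dev);
  //   at::Tensor eager = lt->ToTensor(/*detached=*/true);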

  // The default ctor previously created a null LazyTensor (one with no 'data'
  // obj). Creating a null LazyTensor is no longer possible, since the same can
  // be achieved by creating a null LazyTensorPtr, and it is too confusing to
  // have to check both lazy_tensor_ptr && *lazy_tensor_ptr. Everywhere that
  // used to rely on a LazyTensor obj with a null Data can now rely on a null
  // LazyTensorPtr instead.
  LazyTensor() = delete;
  LazyTensor(const LazyTensor&) = default;
  LazyTensor(LazyTensor&&) noexcept = default;

  ~LazyTensor() override = default;

  size_t generation() const {
    return data()->generation;
  }

  // Override it to use your own Shape.
  virtual int64_t size(int64_t dim) const;

  // Override it to use your own graph executor.
  virtual at::Tensor ToTensor(bool detached);

  void ShallowCopyTo(LazyTensorPtr dest) const;

  // Assigns the tensor value to the lazy tensor.
  void SetTensor(at::Tensor tensor);

  void UpdateFromTensor(at::Tensor tensor, bool sync);
  void UpdateFromTensorOut(at::Tensor tensor);
  void UpdateFromTensorOut(const LazyTensorPtr& tensor);

  const std::shared_ptr<Data>& data() const;

  // Override it to use your own type conversion.
  virtual at::ScalarType dtype() const;

  MaybeRef<Shape> shape() const;

  const BackendDevice& GetDevice() const;
  int64_t GetUniqueId() const;

  // Fetches the data behind the tensor. If the tensor has a graph defining
  // its current value, executes the graph and fetches the data result.
  BackendDataPtr GetDataHandle();

  // Fetches the current value of the data, which can be missing (nullptr)
  // in case the tensor has a graph defining its current value.
  BackendDataPtr CurrentDataHandle() const;
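
  // Example (illustrative sketch): CurrentDataHandle() may return nullptr
  // while the tensor still has a pending IR graph, whereas GetDataHandle()
  // forces execution of that graph first. `lt` is a hypothetical
  // LazyTensorPtr.
  //
  //   if (!lt->CurrentDataHandle()) {
  //     BackendDataPtr data = lt->GetDataHandle();  // runs the pending graph
  //   }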

  void SetDataHandle(BackendDataPtr handle);
  void SetDataHandle(BackendDataPtr handle, bool sync);

  // Retrieves the current IR Node, or nullptr in case no active IR Node is
  // available.
  Value CurrentIrValue() const;

  // Retrieves the IR Node representing this LazyTensor. One will be created if
  // missing. Note that although this is a const API, it actually changes the
  // internal state of the object.
  Value GetIrValue() const;

  void SetIrValue(Value ir_value);
  void SetInPlaceIrValue(Value ir_value);

  std::optional<at::Tensor> CurrentTensorData() const;

  std::vector<LazyTensorPtr> MakeOutputTensors(NodePtr node) const;

  LazyTensorPtr CopyTensorToDevice(const BackendDevice& device);

  // Applies the queue of operations in preparation for using the data.
  // Override it to use your own graph executor.
  virtual void ApplyPendingGraph();

  // Override it to set extra information.
  virtual void AssignIrValue(Value ir_value) const;

 protected:
  explicit LazyTensor(std::shared_ptr<Data> data);

  void SetTensorData(at::Tensor tensor_data);

  // We build a graph accumulating operations, but at a given point we
  // need to force a rendering, otherwise the graph can grow without control.
  // Think:
  //   for i in range(0, 100000):
  //     a = a + b
  void TryLimitGraphSize();

  // Override it to instantiate your own data.
  virtual Value GetIrValueForTensor(
      const at::Tensor& tensor,
      const BackendDevice& device) const;

  Value CreateTensorNode(BackendDataPtr data, bool read_only) const;

 private:
  LazyTensor(const at::Tensor& tensor, const BackendDevice& device);
  LazyTensor(Value ir_value, const BackendDevice& device);
  explicit LazyTensor(BackendDataPtr handle);

  static int64_t GetNextTensorId();

  std::shared_ptr<Data> data_;
};

// Utils to convert at::Tensor to LazyTensor, and vice versa.

// Section 0: c10::Tensorlist ==> lazy::TensorList
// Note: GetTensorList is not totally parallel to GetLtcTensor; a TensorList
// skips the LazyTensor wrappers, assuming that the list of underlying IR
// nodes is actually more useful for downstream computations. TBD.
TORCH_API torch::lazy::Value GetTensorList(at::ITensorListRef tensors);
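
// Example (illustrative sketch): packing two tensors that are assumed to
// already live on the lazy device into a single IR value. `a` and `b` are
// hypothetical at::Tensors.
//
//   std::vector<at::Tensor> tensors = {a, b};
//   torch::lazy::Value list_value = GetTensorList(tensors);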

// Section 1: at::Tensor => LazyTensor.
// Extracts the LazyTensor out of an at::Tensor. Returns a null LazyTensor
// if the tensor is not a lazy tensor.
TORCH_API LazyTensorPtr TryGetLtcTensor(const at::Tensor& tensor);

// Extracts the LazyTensor out of an at::Tensor. Throws an exception
// if the tensor is not a lazy tensor.
TORCH_API LazyTensorPtr GetLtcTensor(const at::Tensor& tensor);

// Same as above, applied to a list of tensors.
TORCH_API std::vector<LazyTensorPtr> GetLtcTensors(
    c10::ArrayRef<at::Tensor> tensors);
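
// Example (illustrative sketch; `t` is a hypothetical tensor on the lazy
// device):
//
//   if (LazyTensorPtr maybe_lt = TryGetLtcTensor(t)) {
//     // `t` is a lazy tensor and `maybe_lt` is non-null.
//   }
//   LazyTensorPtr lt = GetLtcTensor(t);  // throws if `t` is not lazy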

// If tensor is a lazy tensor type, returns the LazyTensor embedded within it,
// otherwise creates a new lazy tensor type with tensor as data.
TORCH_API LazyTensorPtr GetOrCreateLtcTensor(
    const std::optional<at::Tensor>& tensor,
    const BackendDevice& device);

TORCH_API LazyTensorPtr GetLtcTensorOrCreateForWrappedNumber(
    const at::Tensor& tensor,
    const BackendDevice& device);

// Section 2: LazyTensor => at::Tensor.
// Creates an ATen tensor from a LazyTensor.
TORCH_API at::Tensor CreateAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor);
TORCH_API at::Tensor CreateAtenFromLtcTensor(LazyTensor&& ltc_tensor);
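
// Example (illustrative sketch): unwrap a lazy at::Tensor into a LazyTensor
// and wrap it back. `t` is a hypothetical at::Tensor on the lazy device.
//
//   LazyTensorPtr lt = GetLtcTensor(t);
//   at::Tensor wrapped = CreateAtenFromLtcTensor(lt);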

// Note [Lazy Tensor Functionalization]
// The functionalization pass is implemented by wrapping all TensorImpl
// objects in C++ with an extra FunctionalTensorWrapper object that knows
// how to perform functionalization.
//
// Certain functions in the aten API serve as entry/exit points for
// functionalization, where we need to perform the wrapping/unwrapping:
// - aten::to.device
// - aten::empty

// Given a non-lazy tensor, this function creates a lazy tensor on the
// specified (lazy) device. The functionalize_output flag determines whether
// or not we should wrap the output in a "functional wrapper".
//
// How do you know whether to pass true/false for functionalize_output?
//
// Case 1: nonlazy -> lazy
//   If you're implementing a function that takes in nonlazy tensors and
//   returns lazy tensors, then you should think of that function as an
//   "entrypoint" to functionalization, and use functionalize_output=true.
//   Examples include:
//   - factory functions (the LTC kernel for at::empty)
//   - CPU -> Lazy device conversions (the LTC kernel for at::to_device)
//
// Case 2: lazy -> lazy
//   If you're implementing a function that takes in lazy tensors and returns
//   lazy tensors, **but** requires creating lazy tensors internally, then you
//   can assume that the current function is running inside of some outer
//   context where functionalization is already running, which will take care
//   of doing the wrapping for you, and use functionalize_output=false.
//   Examples include:
//   - CPU fallback (takes in lazy tensors, converts to cpu, calls kernel,
//     converts returns back to lazy tensors).
TORCH_API at::Tensor to_lazy_tensor(
    const at::Tensor& self,
    const c10::TensorOptions& options,
    at::Device device,
    bool non_blocking,
    bool functionalize_output);
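
// Example (illustrative sketch of Case 1 above; the device index is an
// assumption and `cpu_t` is a hypothetical CPU tensor):
//
//   at::Tensor lazy_t = to_lazy_tensor(
//       cpu_t,
//       cpu_t.options(),
//       at::Device(at::kLazy, 0),
//       /*non_blocking=*/false,
//       /*functionalize_output=*/true);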

template <size_t... Indices>
auto TupleAtenFromLtcTensorsImpl(
    const std::vector<LazyTensorPtr>& tensors,
    std::index_sequence<Indices...>) {
  return std::make_tuple(CreateAtenFromLtcTensor(tensors[Indices])...);
}

template <size_t N>
auto TupleAtenFromLtcTensors(const std::vector<LazyTensorPtr>& tensors) {
  return TupleAtenFromLtcTensorsImpl(tensors, std::make_index_sequence<N>{});
}
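
// Example (illustrative sketch): unpack the first two elements of a vector of
// lazy tensors into a tuple of at::Tensors. `outs` is a hypothetical vector
// with at least two elements (e.g. the result of MakeOutputTensors).
//
//   auto [t0, t1] = TupleAtenFromLtcTensors<2>(outs);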

} // namespace lazy
} // namespace torch