#pragma once

#include <c10/core/SymNodeImpl.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/backend/backend_device.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/util.h>

namespace torch {
namespace lazy {

class TORCH_API SymNodeImpl : public c10::SymNodeImpl {
 public:
  SymNodeImpl(NodePtr ptr) : node_(std::move(ptr)) {}
  NodePtr node_;
};

class LazyTensor;
using LazyTensorPtr = c10::intrusive_ptr<LazyTensor>;

class TORCH_API LazyTensor : public c10::intrusive_ptr_target {
 public:
  // This is the core lazy tensor data structure where all the tensor data is
  // held. The lazy tensor is nothing more than a shared pointer to a Data
  // object.
  struct Data {
    Data(BackendDataPtr handle, BackendDevice device)
        : handle(std::move(handle)),
          device(std::move(device)),
          unique_id(GetNextTensorId()) {}
    Data(Value ir_value, BackendDevice device)
        : ir_value(std::move(ir_value)),
          device(std::move(device)),
          unique_id(GetNextTensorId()) {}
    Data(at::Tensor tensor_data, BackendDevice device)
        : tensor_data(std::move(tensor_data)),
          device(std::move(device)),
          unique_id(GetNextTensorId()) {}
    // TODO(alanwaketan): Remove this ctor. This is a
    // temporary ctor to ease XLA LTC migration. It depends on
    // XLA's Functionalization integration.
    Data(BackendDevice device)
        : device(std::move(device)), unique_id(GetNextTensorId()) {}

    virtual ~Data();

    BackendDataPtr handle;
    Value ir_value;
    std::optional<at::Tensor> tensor_data;
    const BackendDevice device;
    const int64_t unique_id = 0;
    size_t generation = 1;
  };

  static LazyTensorPtr Create(
      const at::Tensor& tensor,
      const BackendDevice& device);
  static LazyTensorPtr Create(Value ir_value, const BackendDevice& device);
  static LazyTensorPtr Create(BackendDataPtr handle);
  static LazyTensorPtr Create(std::shared_ptr<Data> data);

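  // Usage sketch (hedged illustration only; the tensor and device names below
  // are assumptions, not part of this header):
  //
  //   at::Tensor cpu_tensor = at::ones({2, 2});
  //   torch::lazy::BackendDevice device; // e.g. obtained from the backend
  //   LazyTensorPtr lazy = LazyTensor::Create(cpu_tensor, device);
  //   // No backend computation has run yet; the value is tracked lazily.
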
  // The default ctor previously created a null LazyTensor (one with no 'data'
  // obj). Creating a null LazyTensor is no longer possible, since the same can
  // be achieved by creating a null LazyTensorPtr, and it is way too confusing
  // to have to check both lazy_tensor_ptr && *lazy_tensor_ptr. Everywhere that
  // used to rely on a LazyTensor obj with a null Data can now rely on a null
  // LazyTensorPtr instead (see the sketch below).
  LazyTensor() = delete;
  LazyTensor(const LazyTensor&) = default;
  LazyTensor(LazyTensor&&) noexcept = default;

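  // A minimal sketch of the null check, assuming TryGetLtcTensor (declared
  // later in this header) and an arbitrary at::Tensor `t`:
  //
  //   LazyTensorPtr maybe_lazy = TryGetLtcTensor(t);
  //   if (!maybe_lazy) {
  //     // `t` is not a lazy tensor; take the eager path instead.
  //   }
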
  ~LazyTensor() override = default;

  size_t generation() const {
    return data()->generation;
  }

  // Override it to use your own Shape.
  virtual int64_t size(int64_t dim) const;

  // Override it to use your own graph executor.
  virtual at::Tensor ToTensor(bool detached);

  void ShallowCopyTo(LazyTensorPtr dest) const;

  // Assigns the tensor value to the lazy tensor.
  void SetTensor(at::Tensor tensor);

  void UpdateFromTensor(at::Tensor tensor, bool sync);
  void UpdateFromTensorOut(at::Tensor tensor);
  void UpdateFromTensorOut(const LazyTensorPtr& tensor);

  const std::shared_ptr<Data>& data() const;

  // Override it to use your own type conversion.
  virtual at::ScalarType dtype() const;

  MaybeRef<Shape> shape() const;

  const BackendDevice& GetDevice() const;
  int64_t GetUniqueId() const;

  // Fetches the data behind the tensor. If the tensor has a graph defining
  // its current value, executes the graph and fetches the data result.
  BackendDataPtr GetDataHandle();

  // Fetches the current value of the data, which can be missing (nullptr)
  // in case the tensor has a graph defining its current value.
  BackendDataPtr CurrentDataHandle() const;

  void SetDataHandle(BackendDataPtr handle);
  void SetDataHandle(BackendDataPtr handle, bool sync);

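  // Rough contrast of the two accessors above (illustration only; `lazy` is
  // an assumed LazyTensorPtr):
  //
  //   BackendDataPtr forced = lazy->GetDataHandle();    // may execute the
  //                                                      // pending graph
  //   BackendDataPtr maybe = lazy->CurrentDataHandle();  // never executes;
  //                                                      // nullptr while the
  //                                                      // value is only
  //                                                      // defined by a graph
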
  // Retrieves the current IR Node, or nullptr in case no active IR Node is
  // available.
  Value CurrentIrValue() const;

  // Retrieves the IR Node representing this LazyTensor. One will be created if
  // missing. Note that although this is a const API, it actually changes the
  // internal state of the object.
  Value GetIrValue() const;

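  // Sketch of the difference (illustration only; `lazy` is an assumed
  // LazyTensorPtr):
  //
  //   Value maybe_ir = lazy->CurrentIrValue(); // can be a null Value
  //   Value ir = lazy->GetIrValue();           // creates an IR node from the
  //                                            // device data or tensor data
  //                                            // if one is not present yet
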
  void SetIrValue(Value ir_value);
  void SetInPlaceIrValue(Value ir_value);

  std::optional<at::Tensor> CurrentTensorData() const;

  std::vector<LazyTensorPtr> MakeOutputTensors(NodePtr node) const;

  LazyTensorPtr CopyTensorToDevice(const BackendDevice& device);

  // Applies the queue of operations in preparation for using the data.
  // Override it to use your own graph executor.
  virtual void ApplyPendingGraph();

  // Override it to set extra information.
  virtual void AssignIrValue(Value ir_value) const;

 protected:
  explicit LazyTensor(std::shared_ptr<Data> data);

  void SetTensorData(at::Tensor tensor_data);

  // We build a graph accumulating operations, but at a given point we
  // need to force a rendering, otherwise the graph can grow without control.
  // Think:
  //   for i in range(0, 100000):
  //     a = a + b
  void TryLimitGraphSize();

  // Override it to instantiate your own data.
  virtual Value GetIrValueForTensor(
      const at::Tensor& tensor,
      const BackendDevice& device) const;

  Value CreateTensorNode(BackendDataPtr data, bool read_only) const;

 private:
  LazyTensor(const at::Tensor& tensor, const BackendDevice& device);
  LazyTensor(Value ir_value, const BackendDevice& device);
  explicit LazyTensor(BackendDataPtr handle);

  static int64_t GetNextTensorId();

  std::shared_ptr<Data> data_;
};

// Utils to convert at::Tensor to LazyTensor, and vice versa.

// Section 0: c10::TensorList ==> lazy::TensorList
// Note: GetTensorList is not totally parallel to GetLtcTensor; a TensorList
// skips the LazyTensor wrappers, assuming that the list of underlying IR
// nodes is actually more useful for downstream computations. TBD.
TORCH_API torch::lazy::Value GetTensorList(at::ITensorListRef tensors);

// Section 1: at::Tensor => LazyTensor.
// Extracts the LazyTensor out of an at::Tensor. Returns a null LazyTensor
// if the tensor is not a lazy tensor.
TORCH_API LazyTensorPtr TryGetLtcTensor(const at::Tensor& tensor);

// Extracts the LazyTensor out of an at::Tensor. Throws an exception
// if the tensor is not a lazy tensor.
TORCH_API LazyTensorPtr GetLtcTensor(const at::Tensor& tensor);

// Same as above, applied to a list of tensors.
TORCH_API std::vector<LazyTensorPtr> GetLtcTensors(
    c10::ArrayRef<at::Tensor> tensors);

// If tensor is a lazy tensor type, returns the LazyTensor embedded within it,
// otherwise creates a new lazy tensor type with tensor as data.
TORCH_API LazyTensorPtr GetOrCreateLtcTensor(
    const std::optional<at::Tensor>& tensor,
    const BackendDevice& device);

TORCH_API LazyTensorPtr GetLtcTensorOrCreateForWrappedNumber(
    const at::Tensor& tensor,
    const BackendDevice& device);

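// Rough guide to the Section 1 helpers above (hedged sketch; `t` and `dev`
// are assumed to exist and are not defined here):
//
//   LazyTensorPtr a = TryGetLtcTensor(t);            // nullptr if t is eager
//   LazyTensorPtr b = GetLtcTensor(t);               // throws if t is eager
//   LazyTensorPtr c = GetOrCreateLtcTensor(t, dev);  // wraps t if needed
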
// Section 2: LazyTensor => at::Tensor.
// Creates an ATen tensor from a LazyTensor.
TORCH_API at::Tensor CreateAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor);
TORCH_API at::Tensor CreateAtenFromLtcTensor(LazyTensor&& ltc_tensor);

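// Illustration of a round trip (assumed names; not part of this API):
//
//   LazyTensorPtr lazy = GetLtcTensor(lazy_aten_tensor);
//   at::Tensor rewrapped = CreateAtenFromLtcTensor(lazy);
//   // `rewrapped` is an at::Tensor whose TensorImpl holds the LazyTensor.
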
// Note [Lazy Tensor Functionalization]
// The functionalization pass is implemented by wrapping all TensorImpl
// objects in C++ with an extra FunctionalTensorWrapper object that knows
// how to perform functionalization.
//
// Certain functions in the aten API serve as entry/exit points for
// functionalization, where we need to perform the wrapping/unwrapping:
// - aten::to.device
// - aten::empty

// Given a non-lazy tensor, this function creates a lazy tensor on the
// specified (lazy) device. The functionalize_output flag determines whether
// we should wrap the output in a "functional wrapper".
//
// How do you know whether to pass true/false for functionalize_output?
//
// Case 1: nonlazy -> lazy
//   If you're implementing a function that takes in nonlazy tensors and
//   returns lazy tensors, then you should think of that function as an
//   "entrypoint" to functionalization, and use functionalize_output=true.
//   Examples include:
//   - factory functions (the LTC kernel for at::empty)
//   - CPU -> Lazy device conversions (the LTC kernel for at::to_device)
//
// Case 2: lazy -> lazy
//   If you're implementing a function that takes in lazy tensors and returns
//   lazy tensors, **but** requires creating lazy tensors internally, then you
//   can assume that the current function is running inside of some outer
//   context where functionalization is already running, which will take care
//   of doing the wrapping for you, and use functionalize_output=false.
//   Examples include:
//   - CPU fallback (takes in lazy tensors, converts them to cpu, calls the
//     kernel, and converts the results back to lazy tensors).
TORCH_API at::Tensor to_lazy_tensor(
    const at::Tensor& self,
    const c10::TensorOptions& options,
    at::Device device,
    bool non_blocking,
    bool functionalize_output);

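// Hedged usage sketch of the two cases above (the argument names are
// assumptions; the real call sites live in the LTC kernels, not here):
//
//   // Case 1: entry point, e.g. a factory/to-device-style LTC kernel.
//   at::Tensor entry_out = to_lazy_tensor(
//       cpu_self, options, lazy_device, /*non_blocking=*/false,
//       /*functionalize_output=*/true);
//
//   // Case 2: already running inside functionalization (e.g. CPU fallback).
//   at::Tensor inner_out = to_lazy_tensor(
//       cpu_tmp, options, lazy_device, /*non_blocking=*/false,
//       /*functionalize_output=*/false);
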
template <size_t... Indices>
auto TupleAtenFromLtcTensorsImpl(
    const std::vector<LazyTensorPtr>& tensors,
    std::index_sequence<Indices...>) {
  return std::make_tuple(CreateAtenFromLtcTensor(tensors[Indices])...);
}

template <size_t N>
auto TupleAtenFromLtcTensors(const std::vector<LazyTensorPtr>& tensors) {
  return TupleAtenFromLtcTensorsImpl(tensors, std::make_index_sequence<N>{});
}

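// Example (illustration only): unpack the first two outputs of an op that
// produced a std::vector<LazyTensorPtr>, e.g. via MakeOutputTensors above.
//
//   std::vector<LazyTensorPtr> outs = /* ... */;
//   auto [values, indices] = TupleAtenFromLtcTensors<2>(outs);
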
} // namespace lazy
} // namespace torch