#pragma once

#include <ATen/core/symbol.h>

#include <functional>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <c10/core/ScalarType.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Flags.h>
#include <torch/csrc/lazy/core/hash.h>
#include <torch/csrc/lazy/core/ir_metadata.h>
#include <torch/csrc/lazy/core/shape.h>

C10_DECLARE_bool(ltc_enable_dynamic_shapes);

namespace torch {
namespace lazy {

static const hash_t kHashSeed(static_cast<uint32_t>(0x5a2d296e9));

class Node;
struct Output;
struct Value;

using NodePtr = std::shared_ptr<Node>;

// The Kind of operation a Node can be associated to.
struct TORCH_API OpKind {
  OpKind() = default;
  explicit OpKind(c10::Symbol op) : op(op) {}

  bool operator==(const OpKind& rhs) const {
    return op == rhs.op;
  }
  bool operator!=(const OpKind& rhs) const {
    return !operator==(rhs);
  }
  bool operator<(const OpKind& rhs) const {
    return c10::unique_t(op) < c10::unique_t(rhs.op);
  }

  hash_t hash() const;

  std::string ToString() const {
    return op.toQualString();
  }

  // Retrieves an existing operation object, or creates a new one. Operations
  // that are specific to lazy tensors should live within the 'lazy_tensors::'
  // namespace.
  static OpKind Get(const std::string& name);

  c10::Symbol op;
};

inline std::ostream& operator<<(std::ostream& stream, const OpKind& op) {
  stream << op.ToString();
  return stream;
}
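
// Illustrative usage sketch (added commentary, not part of the original
// header): OpKind instances compare by their underlying c10::Symbol, and
// OpKind::Get interns an operation by its qualified name. The operation
// names used below are examples only.
//
//   OpKind add_kind = OpKind::Get("aten::add");
//   OpKind custom_kind = OpKind::Get("lazy_tensors::example_op");
//   if (add_kind != custom_kind) {
//     std::cout << add_kind << std::endl;  // streams the qualified op name
//   }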

using OpList = c10::ArrayRef<Value>;

hash_t OperandHashes(
    const OpList& operands,
    const hash_t& seed,
    bool bakeInSizes);

// A node in the graph. Nodes for operations which require extra data to be
// stored for lowering should inherit from this class and add an operation
// specific member there. For example, a constant might create a new
// NodeConstant class (inheriting from Node) with an extra lazy_tensors::Literal
// field, or a tensor value might create a new NodeTensor with a computation
// client data handle in it.
class TORCH_API Node {
 public:
  static bool enableDynamicShape();

  // Creates a new node with the given op name. The op is a unique identifier
  // for the operation. The num_outputs tells how many outputs a given operation
  // generates.
  //
  // A non-leaf node's node_hash does not always contain shape information,
  // so we pass in the hash value rather than a function.
  Node(OpKind op, size_t num_outputs);

  // Construct node with operands and shapes
  Node(
      OpKind op,
      OpList operands,
      std::vector<Shape>&& shapes,
      size_t num_outputs = 1);

  // Construct node with operands and shape generated from a function
  Node(
      OpKind op,
      OpList operands,
      const std::function<Shape()>& shape_fn,
      size_t num_outputs = 1);

  // Construct node with operands and no shape
  Node(OpKind op, OpList operands, size_t num_outputs = 1);

  // Construct node with shape and no operands
  Node(OpKind op, Shape shape, size_t num_outputs = 1);

  virtual ~Node();

  const OpKind& op() const {
    return op_;
  }

  size_t num_outputs() const {
    return num_outputs_;
  }

  // Retrieves the full shape of the IR Node.
  virtual c10::ArrayRef<Shape> shapes() const;

  virtual const Shape& shape(size_t output_index = 0) const;

  // Add the shape computed by the shape_fn
  void addComputedShape(const std::function<Shape()>& shape_fn);

  // Compute the shape using the provided shape_fn if not previously cached
  Shape computeShape(const std::function<Shape()>& shape_fn);

  virtual const std::vector<Output>& operands() const;

  virtual const Output& operand(size_t i) const;

  // Gets operand at index i if index is valid, or kNullOutput otherwise.
  virtual const Output& nullable_operand(size_t i) const;

  // Returns the hash of the dag used to look up the compiled graph
  virtual hash_t hash() const = 0;

  // Returns the hash of the dag used for shape caching
  virtual hash_t shapeHash() const = 0;

  const MetaData& metadata() const {
    return metadata_;
  }

  UserMetaData* user_metadata() const {
    return user_metadata_.get();
  }

  std::shared_ptr<UserMetaData> SetUserMetadata(
      std::shared_ptr<UserMetaData> user_meta) {
    std::swap(user_metadata_, user_meta);
    return user_meta;
  }

  virtual std::string ToString() const;

 private:
  // The ID of the operation captured by this node.
  OpKind op_;
  size_t num_outputs_ = 1;

  // The IR specific metadata attached to the IR node.
  MetaData metadata_;
  // The IR framework user can attach a user defined metadata object deriving
  // from UserMetaData.
  std::shared_ptr<UserMetaData> user_metadata_;

 protected:
  // Adds the output at the given index of node as an operand.
  void AddOperand(NodePtr node, size_t index = 0);

  std::vector<Shape> shapes_;
  // A node holds a real reference to its operands.
  std::vector<NodePtr> operands_;
  // Outputs do not hold references on the nodes, and neither do the uses, since
  // otherwise we get into circular reference counting.
  std::vector<Output> operands_as_outputs_;
};

inline std::ostream& operator<<(std::ostream& stream, const Node& node) {
  stream << node.ToString();
  return stream;
}
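
// Illustrative sketch (added commentary, not part of the original header):
// per the class comment above, an operation that needs extra lowering data
// derives from Node and stores that data as a member. The ExampleConstant
// class, its op name, and its members below are hypothetical.
//
//   class ExampleConstant : public Node {
//    public:
//     ExampleConstant(at::Scalar value, Shape shape)
//         : Node(OpKind::Get("lazy_tensors::example_constant"),
//                std::move(shape)),
//           value_(value) {}
//
//     hash_t hash() const override;       // pure virtual in Node
//     hash_t shapeHash() const override;  // pure virtual in Node
//
//    private:
//     at::Scalar value_;  // operation-specific payload used at lowering time
//   };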

// Note: Keep this version of NodeCast for smooth PyTorch/XLA migration, and
// clean up once the migration is done.
template <typename T>
const T* NodeCast(const Node* node, OpKind op) {
  if (op != node->op()) {
    return nullptr;
  }
#ifdef NDEBUG
  return static_cast<const T*>(node);
#else
  return &dynamic_cast<const T&>(*node);
#endif
}

template <typename T>
const T* NodeCast(const Node* node) {
  if (T::ClassOpKind() != node->op()) {
    return nullptr;
  }
  // TODO: Some IR classes share the same opkind, such as Mean and MeanDim, so
  // static_cast is not safe here. Unless each class has a unique opkind, we
  // have to use dynamic_cast here.
  return dynamic_cast<const T*>(node);
}
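
// Illustrative usage sketch (added commentary, not part of the original
// header): given a node subclass exposing a static ClassOpKind(), NodeCast
// downcasts a generic Node pointer or returns nullptr on a kind mismatch.
// ExampleConstant is the hypothetical class sketched above.
//
//   const Node* raw = value.node.get();
//   if (const auto* constant = NodeCast<ExampleConstant>(raw)) {
//     // constant-specific accessors can be used here
//   }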

// Represents a specific output produced by a node. Since the output of a node
// can be composed of multiple outputs, the node+index coordinates fully qualify
// each single output.
struct TORCH_API Output {
  struct Hasher {
    size_t operator()(const Output& output) const;
  };

  Output() = default;
  explicit Output(const Node* node, size_t index = 0)
      : node(node), index(index) {}

  hash_t hash() const;
  hash_t shapeHash() const;

  bool operator==(const Output& rhs) const {
    return node == rhs.node && index == rhs.index;
  }

  // Compares the operands of a to-be-constructed node and a to-be-reused node.
  bool operator==(const Value& rhs) const;

  bool operator!=(const Output& rhs) const {
    return !operator==(rhs);
  }

  const Shape& shape() const {
    return node->shape(index);
  }

  std::string ToString() const;

  // The node providing the output.
  const Node* node{nullptr};
  // The index in the node's output this output refers to.
  size_t index{0};
};

inline std::ostream& operator<<(std::ostream& stream, const Output& output) {
  stream << output.ToString();
  return stream;
}

template <typename T>
using OutputMap = std::unordered_map<Output, T, Output::Hasher>;

// Represents an input/operand for a Node object.
struct TORCH_API Value {
  Value() = default;
  /* implicit */ Value(NodePtr&& node, size_t index = 0)
      : node(std::move(node)), index(index) {}
  /* implicit */ Value(const NodePtr& node, size_t index = 0)
      : node(node), index(index) {}

  hash_t hash() const;
  hash_t shapeHash() const;

  operator bool() const {
    return node != nullptr;
  }

  operator Output() const {
    return Output(node.get(), index);
  }

  const Shape& shape() const {
    return node->shape(index);
  }

  Node* operator->() const {
    return node.get();
  }

  NodePtr node;
  size_t index = 0;
};
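
// Illustrative usage sketch (added commentary, not part of the original
// header): a Value pairs a NodePtr with an output index and converts
// implicitly to Output where an operand is expected. MakeExampleNode is a
// hypothetical factory.
//
//   NodePtr producer = MakeExampleNode();
//   Value first(producer);                 // refers to output 0
//   Value second(producer, /*index=*/1);   // refers to output 1
//   if (first) {                           // true while node != nullptr
//     const Shape& s = first.shape();      // shape of the referenced output
//   }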

} // namespace lazy
} // namespace torch

namespace c10 {
// Explicit template instantiation to make ArrayRef<Value> work
template class at::ArrayRef<torch::lazy::Value>;
} // namespace c10