xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/ts_backend/dynamic_ir.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/symbol.h>
4 
5 #include <functional>
6 #include <memory>
7 #include <set>
8 #include <string>
9 #include <unordered_map>
10 #include <unordered_set>
11 #include <utility>
12 #include <vector>
13 
14 #include <c10/core/ScalarType.h>
15 #include <c10/util/Flags.h>
16 #include <torch/csrc/lazy/core/dynamic_ir.h>
17 #include <torch/csrc/lazy/core/hash.h>
18 #include <torch/csrc/lazy/core/ir.h>
19 #include <torch/csrc/lazy/core/ir_metadata.h>
20 #include <torch/csrc/lazy/ts_backend/ts_node.h>
21 
// Declares (does not define) the boolean command-line flag
// `ltc_enable_dynamic_shapes` using c10's flag macros (c10/util/Flags.h);
// the matching C10_DEFINE_bool lives in some translation unit outside this
// header. NOTE(review): presumably this flag gates whether the dynamic
// DimensionNode types declared below are used during tracing — confirm
// against the .cpp that defines and reads it.
22 C10_DECLARE_bool(ltc_enable_dynamic_shapes);
23 
24 namespace torch {
25 namespace lazy {
26 
/**
 * The goal of "dynamic" Nodes is to patch a hole in our tracing.
 * Previously, if a user queried the sizes of a Tensor (e.g. `x.size()` or
 * `x.shape`), the result leaked out of our tracing system, because it is
 * returned as a concrete `torch.Size` or a plain int. To prevent this from
 * happening, we introduce DimensionNode, a new type of Node that abstracts
 * the operation of getting the dimensions of a Tensor.
 *
 * Consider the following example:
 * ```
 * numel = x.shape[0] * x.shape[1]
 * ```
 *
 * Here, each `x.shape[i]` will be a SizeNode (subclass of DimensionNode),
 * and the multiplication of the two SizeNodes will be represented by
 * a SizeMul (also a subclass of DimensionNode). Through this, we can
 * prevent `numel` from being represented as a Python int and thus
 * burned into the Graph as a constant.
 */
46 
47 // Represents the result of calling `size` on a Tensor
48 class TORCH_API SizeNode : public TsNode, public DimensionNode {
49  public:
50   SizeNode(Value input, size_t dim);
51   int64_t getStaticValue() const override;
52   bool isSymbolic() const override;
53   std::string ToString() const override;
54   size_t dim_ = 0;
55   torch::lazy::TSOpVector Lower(
56       std::shared_ptr<torch::jit::GraphFunction> function,
57       TSLoweringContext* loctx) const override;
58 };
59 
60 class TORCH_API SizeAdd : public TsNode, public DimensionNode {
61  public:
62   SizeAdd(Value a, Value b);
63   int64_t getStaticValue() const override;
64   bool isSymbolic() const override;
65   std::string ToString() const override;
66 };
67 
68 class TORCH_API SizeMul : public TsNode, public DimensionNode {
69  public:
70   SizeMul(Value a, Value b);
71   int64_t getStaticValue() const override;
72   bool isSymbolic() const override;
73   std::string ToString() const override;
74 };
75 
76 class TORCH_API SizeDiv : public TsNode, public DimensionNode {
77  public:
78   SizeDiv(Value a, Value b);
79   int64_t getStaticValue() const override;
80   bool isSymbolic() const override;
81   std::string ToString() const override;
82 };
83 
84 } // namespace lazy
85 } // namespace torch
86