#pragma once

#include <ATen/core/symbol.h>

#include <functional>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <c10/core/ScalarType.h>
#include <c10/util/Flags.h>
#include <torch/csrc/lazy/core/dynamic_ir.h>
#include <torch/csrc/lazy/core/hash.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/ir_metadata.h>
#include <torch/csrc/lazy/ts_backend/ts_node.h>

// Declares (does not define) the boolean command-line flag
// `ltc_enable_dynamic_shapes` via c10/util/Flags.h, so that any translation
// unit including this header can read it. By c10 convention the matching
// C10_DEFINE_bool lives in a single .cpp file elsewhere.
// NOTE(review): judging by the name, this presumably gates dynamic-shape
// support in lazy tracing — confirm at the definition site.
C10_DECLARE_bool(ltc_enable_dynamic_shapes);

namespace torch {
namespace lazy {

/**
 * The goal of "dynamic" Nodes is to patch a hole in our tracing.
 * Previously, if a user queried the sizes of a Tensor, the result would
 * leak out of our tracing system, because the query returns a concrete
 * torch.Size or int. To prevent this, we introduce DimensionNode, a new
 * type of Node that abstracts the operation of getting the dimensions of
 * a Tensor.
 *
 * Consider the following example:
 * ```
 * numel = x.shape[0] * x.shape[1]
 * ```
 *
 * Here, each `x.shape[i]` will be a SizeNode (a subclass of DimensionNode),
 * and the multiplication of the two SizeNodes will be represented by a
 * SizeMul (also a subclass of DimensionNode). This prevents `numel` from
 * being represented as a Python int and thus burned into the Graph.
45 */ 46 47 // Represents the result of calling `size` on a Tensor 48 class TORCH_API SizeNode : public TsNode, public DimensionNode { 49 public: 50 SizeNode(Value input, size_t dim); 51 int64_t getStaticValue() const override; 52 bool isSymbolic() const override; 53 std::string ToString() const override; 54 size_t dim_ = 0; 55 torch::lazy::TSOpVector Lower( 56 std::shared_ptr<torch::jit::GraphFunction> function, 57 TSLoweringContext* loctx) const override; 58 }; 59 60 class TORCH_API SizeAdd : public TsNode, public DimensionNode { 61 public: 62 SizeAdd(Value a, Value b); 63 int64_t getStaticValue() const override; 64 bool isSymbolic() const override; 65 std::string ToString() const override; 66 }; 67 68 class TORCH_API SizeMul : public TsNode, public DimensionNode { 69 public: 70 SizeMul(Value a, Value b); 71 int64_t getStaticValue() const override; 72 bool isSymbolic() const override; 73 std::string ToString() const override; 74 }; 75 76 class TORCH_API SizeDiv : public TsNode, public DimensionNode { 77 public: 78 SizeDiv(Value a, Value b); 79 int64_t getStaticValue() const override; 80 bool isSymbolic() const override; 81 std::string ToString() const override; 82 }; 83 84 } // namespace lazy 85 } // namespace torch 86