1 #pragma once
2
3 #include <c10/core/ScalarType.h>
4 #include <torch/csrc/lazy/backend/backend_interface.h>
5 #include <torch/csrc/lazy/core/config.h>
6 #include <torch/csrc/lazy/core/ir.h>
7 #include <torch/csrc/lazy/core/tensor.h>
8 #include <torch/csrc/lazy/core/trie.h>
9 #include <optional>
10 #include <vector>
11
// This file is part of the backend interface, so ops shouldn't be added or
// removed without due process. The exception is the view ops, which will be
// removed soon pending functionalization.
15
16 namespace torch {
17 namespace lazy {
18
19 template <typename T, typename... Args>
ReuseNode(Args &&...args)20 NodePtr ReuseNode(Args&&... args) {
21 if (FLAGS_torch_lazy_reuse_ir) {
22 return LookupNodeFromTrieCache<T>(std::forward<Args>(args)...);
23 }
24 return nullptr;
25 }
26
27 // Caching an IR node into TrieCache
CacheNode(NodePtr node)28 static inline void CacheNode(NodePtr node) {
29 if (FLAGS_torch_lazy_reuse_ir) {
30 TrieCache::Get()->Insert(std::move(node));
31 }
32 }
33
34 template <typename T, typename... Args>
MakeNode(Args &&...args)35 NodePtr MakeNode(Args&&... args) {
36 return std::make_shared<T>(std::forward<Args>(args)...);
37 }
38
39 // op is passed in for a more efficient node casting, see the implementation of
40 // NodeCast
41 template <typename T, typename... Args>
ReuseOrMakeNode(Args &&...args)42 NodePtr ReuseOrMakeNode(Args&&... args) {
43 NodePtr node = ReuseNode<T>(std::forward<Args>(args)...);
44 if (!node) {
45 node = MakeNode<T>(std::forward<Args>(args)...);
46 CacheNode(node);
47 }
48 return node;
49 }
50
// Abstract factory for the core IR nodes a lazy-tensor backend must be able
// to construct. The free Make* helpers in this header forward to the
// IrBuilder instance returned by getIrBuilder().
//
// NOTE(review): some methods declare default arguments. Default arguments on
// virtual functions are taken from the static type at the call site, so any
// defaults an override declares are ignored when called through IrBuilder.
struct IrBuilder {
  // Builds an IR node wrapping the backend data `data`.
  virtual NodePtr MakeDeviceData(
      const std::shared_ptr<BackendData>& data) const = 0;
  // Builds a scalar-constant node for `value` with scalar type `type`.
  virtual NodePtr MakeScalar(
      const at::Scalar& value,
      const at::ScalarType& type) const = 0;
  // Builds an expand node for `input0` with target `size`.
  // `is_scalar_expand` presumably marks expansion of a scalar input — TODO
  // confirm against backend implementations.
  virtual NodePtr MakeExpand(
      const Value& input0,
      const std::vector<int64_t>& size,
      const bool& is_scalar_expand) const = 0;
  // Builds a cast of `input0` to `dtype`. `stype` defaults to std::nullopt;
  // presumably it is the source scalar type — verify against backends.
  virtual NodePtr MakeCast(
      const Value& input0,
      const at::ScalarType& dtype,
      const std::optional<at::ScalarType>& stype = std::nullopt) const = 0;
  // Builds a node over the list of `inputs`.
  virtual NodePtr MakeTensorList(const OpList& inputs) const = 0;
  // Builds a generic node of kind `op` over `operands`, with output `shape`,
  // `num_outputs` outputs (default 1) and hash seed `hash_seed`.
  // NOTE(review): the literal 0x5a2d296e9 is wider than 32 bits; the
  // uint32_t cast truncates it to 0xa2d296e9. Changing it would change node
  // hashes, so it is left as-is.
  virtual NodePtr MakeGeneric(
      const OpKind& op,
      const OpList& operands,
      const Shape& shape,
      const size_t& num_outputs = 1,
      const hash_t& hash_seed = static_cast<uint32_t>(0x5a2d296e9)) const = 0;

  // dynamic ir nodes
  // Builds a node for the size of dimension `dim` of `input`.
  virtual NodePtr MakeSizeNode(const Value& input, size_t dim) const = 0;
  // Size arithmetic on two size-valued nodes; by name, presumably a + b,
  // a * b and a / b respectively.
  virtual NodePtr MakeSizeAdd(const Value& a, const Value& b) const = 0;
  virtual NodePtr MakeSizeMul(const Value& a, const Value& b) const = 0;
  virtual NodePtr MakeSizeDiv(const Value& a, const Value& b) const = 0;

  // Virtual so implementations can be destroyed through an IrBuilder pointer.
  virtual ~IrBuilder() = default;
};
81
MakeDeviceData(const std::shared_ptr<BackendData> & data)82 static inline NodePtr MakeDeviceData(const std::shared_ptr<BackendData>& data) {
83 return getIrBuilder()->MakeDeviceData(data);
84 }
MakeScalar(const at::Scalar & value,const at::ScalarType & type)85 static inline NodePtr MakeScalar(
86 const at::Scalar& value,
87 const at::ScalarType& type) {
88 return getIrBuilder()->MakeScalar(value, type);
89 }
MakeExpand(const Value & input0,const std::vector<int64_t> & size,const bool & is_scalar_expand)90 static inline NodePtr MakeExpand(
91 const Value& input0,
92 const std::vector<int64_t>& size,
93 const bool& is_scalar_expand) {
94 return getIrBuilder()->MakeExpand(input0, size, is_scalar_expand);
95 }
96 static inline NodePtr MakeCast(
97 const Value& input0,
98 const at::ScalarType& dtype,
99 const std::optional<at::ScalarType>& stype = std::nullopt) {
100 return getIrBuilder()->MakeCast(input0, dtype, stype);
101 }
MakeTensorList(const OpList & inputs)102 static inline NodePtr MakeTensorList(const OpList& inputs) {
103 return getIrBuilder()->MakeTensorList(inputs);
104 }
105 static inline NodePtr MakeGeneric(
106 const OpKind& op,
107 const OpList& operands,
108 const Shape& shape,
109 const size_t& num_outputs = 1,
110 const hash_t& hash_seed = static_cast<uint32_t>(0x5a2d296e9)) {
111 return getIrBuilder()->MakeGeneric(
112 op, operands, shape, num_outputs, hash_seed);
113 }
114
115 // dynamic ir nodes
MakeSizeNode(const Value & input,size_t dim)116 static inline NodePtr MakeSizeNode(const Value& input, size_t dim) {
117 return getIrBuilder()->MakeSizeNode(input, dim);
118 }
MakeSizeAdd(const Value & a,const Value & b)119 static inline NodePtr MakeSizeAdd(const Value& a, const Value& b) {
120 return getIrBuilder()->MakeSizeAdd(a, b);
121 }
MakeSizeMul(const Value & a,const Value & b)122 static inline NodePtr MakeSizeMul(const Value& a, const Value& b) {
123 return getIrBuilder()->MakeSizeAdd(a, b);
124 }
MakeSizeDiv(const Value & a,const Value & b)125 static inline NodePtr MakeSizeDiv(const Value& a, const Value& b) {
126 return getIrBuilder()->MakeSizeDiv(a, b);
127 }
128
GetSymIntValue(c10::SymInt a)129 inline Value GetSymIntValue(c10::SymInt a) {
130 if (auto ma = a.maybe_as_int()) {
131 return Value(MakeScalar(*ma, at::kLong), 0);
132 } else {
133 return Value(
134 dynamic_cast<torch::lazy::SymNodeImpl*>(a.toSymNodeImplUnowned())
135 ->node_,
136 0);
137 }
138 }
139
140 // TODO: this should return Value
GetSymIntArrayRefValue(c10::SymIntArrayRef arr)141 inline std::vector<int64_t> GetSymIntArrayRefValue(c10::SymIntArrayRef arr) {
142 std::vector<int64_t> r;
143 for (const auto& a : arr) {
144 r.emplace_back(a.guard_int(__FILE__, __LINE__));
145 }
146 return r;
147 }
148
149 } // namespace lazy
150 } // namespace torch
151