// torch/csrc/lazy/core/ir_builder.h
#pragma once

#include <c10/core/ScalarType.h>
#include <torch/csrc/lazy/backend/backend_interface.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/ir.h>
#include <torch/csrc/lazy/core/tensor.h>
#include <torch/csrc/lazy/core/trie.h>
#include <optional>
#include <vector>

// This file is part of the backend interface, so ops shouldn't be added or
// removed without due process. The exception is the view ops, which will be
// removed soon pending functionalization.

namespace torch {
namespace lazy {

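// Returns a previously cached node of type T from the TrieCache when IR
// reuse (FLAGS_torch_lazy_reuse_ir) is enabled and an equivalent node
// exists; otherwise returns nullptr.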
template <typename T, typename... Args>
NodePtr ReuseNode(Args&&... args) {
  if (FLAGS_torch_lazy_reuse_ir) {
    return LookupNodeFromTrieCache<T>(std::forward<Args>(args)...);
  }
  return nullptr;
}

// Caches an IR node into the TrieCache.
static inline void CacheNode(NodePtr node) {
  if (FLAGS_torch_lazy_reuse_ir) {
    TrieCache::Get()->Insert(std::move(node));
  }
}

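// Unconditionally constructs a new IR node of type T, forwarding the
// arguments to its constructor.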
template <typename T, typename... Args>
NodePtr MakeNode(Args&&... args) {
  return std::make_shared<T>(std::forward<Args>(args)...);
}

// op is passed in for a more efficient node casting, see the implementation
// of NodeCast
template <typename T, typename... Args>
NodePtr ReuseOrMakeNode(Args&&... args) {
  NodePtr node = ReuseNode<T>(std::forward<Args>(args)...);
  if (!node) {
    node = MakeNode<T>(std::forward<Args>(args)...);
    CacheNode(node);
  }
  return node;
}
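
// Illustrative sketch only: a backend op constructor would typically funnel
// through ReuseOrMakeNode so that structurally identical nodes are shared
// when FLAGS_torch_lazy_reuse_ir is set. The `Add` node type and its
// constructor below are hypothetical, not defined in this header:
//
//   NodePtr MakeAdd(const Value& a, const Value& b) {
//     return ReuseOrMakeNode<Add>(a, b);
//   }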

struct IrBuilder {
  virtual NodePtr MakeDeviceData(
      const std::shared_ptr<BackendData>& data) const = 0;
  virtual NodePtr MakeScalar(
      const at::Scalar& value,
      const at::ScalarType& type) const = 0;
  virtual NodePtr MakeExpand(
      const Value& input0,
      const std::vector<int64_t>& size,
      const bool& is_scalar_expand) const = 0;
  virtual NodePtr MakeCast(
      const Value& input0,
      const at::ScalarType& dtype,
      const std::optional<at::ScalarType>& stype = std::nullopt) const = 0;
  virtual NodePtr MakeTensorList(const OpList& inputs) const = 0;
  virtual NodePtr MakeGeneric(
      const OpKind& op,
      const OpList& operands,
      const Shape& shape,
      const size_t& num_outputs = 1,
      const hash_t& hash_seed = static_cast<uint32_t>(0x5a2d296e9)) const = 0;

  // dynamic ir nodes
  virtual NodePtr MakeSizeNode(const Value& input, size_t dim) const = 0;
  virtual NodePtr MakeSizeAdd(const Value& a, const Value& b) const = 0;
  virtual NodePtr MakeSizeMul(const Value& a, const Value& b) const = 0;
  virtual NodePtr MakeSizeDiv(const Value& a, const Value& b) const = 0;

  virtual ~IrBuilder() = default;
};
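
// Convenience wrappers over the active backend's IrBuilder: getIrBuilder()
// returns the IrBuilder registered by the current backend, and each helper
// below forwards to the corresponding virtual method.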
static inline NodePtr MakeDeviceData(const std::shared_ptr<BackendData>& data) {
  return getIrBuilder()->MakeDeviceData(data);
}
static inline NodePtr MakeScalar(
    const at::Scalar& value,
    const at::ScalarType& type) {
  return getIrBuilder()->MakeScalar(value, type);
}
static inline NodePtr MakeExpand(
    const Value& input0,
    const std::vector<int64_t>& size,
    const bool& is_scalar_expand) {
  return getIrBuilder()->MakeExpand(input0, size, is_scalar_expand);
}
static inline NodePtr MakeCast(
    const Value& input0,
    const at::ScalarType& dtype,
    const std::optional<at::ScalarType>& stype = std::nullopt) {
  return getIrBuilder()->MakeCast(input0, dtype, stype);
}
static inline NodePtr MakeTensorList(const OpList& inputs) {
  return getIrBuilder()->MakeTensorList(inputs);
}
static inline NodePtr MakeGeneric(
    const OpKind& op,
    const OpList& operands,
    const Shape& shape,
    const size_t& num_outputs = 1,
    const hash_t& hash_seed = static_cast<uint32_t>(0x5a2d296e9)) {
  return getIrBuilder()->MakeGeneric(
      op, operands, shape, num_outputs, hash_seed);
}

// dynamic ir nodes
static inline NodePtr MakeSizeNode(const Value& input, size_t dim) {
  return getIrBuilder()->MakeSizeNode(input, dim);
}
static inline NodePtr MakeSizeAdd(const Value& a, const Value& b) {
  return getIrBuilder()->MakeSizeAdd(a, b);
}
static inline NodePtr MakeSizeMul(const Value& a, const Value& b) {
  return getIrBuilder()->MakeSizeMul(a, b);
}
static inline NodePtr MakeSizeDiv(const Value& a, const Value& b) {
  return getIrBuilder()->MakeSizeDiv(a, b);
}

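// Converts a c10::SymInt to an IR Value: a constant Scalar node when the
// integer is statically known, otherwise the IR node backing the symbolic
// SymNodeImpl.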
inline Value GetSymIntValue(c10::SymInt a) {
  if (auto ma = a.maybe_as_int()) {
    return Value(MakeScalar(*ma, at::kLong), 0);
  } else {
    return Value(
        dynamic_cast<torch::lazy::SymNodeImpl*>(a.toSymNodeImplUnowned())
            ->node_,
        0);
  }
}

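// Extracts the concrete integer values from a SymIntArrayRef; guard_int()
// installs a guard and specializes each symbolic entry to its current value.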
// TODO: this should return Value
inline std::vector<int64_t> GetSymIntArrayRefValue(c10::SymIntArrayRef arr) {
  std::vector<int64_t> r;
  for (const auto& a : arr) {
    r.emplace_back(a.guard_int(__FILE__, __LINE__));
  }
  return r;
}

} // namespace lazy
} // namespace torch