xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>

#include <cmath>

#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/clear_undefinedness.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
8 namespace torch {
9 namespace jit {
10 
createStack(std::vector<at::Tensor> && list)11 Stack createStack(std::vector<at::Tensor>&& list) {
12   return Stack(
13       std::make_move_iterator(list.begin()),
14       std::make_move_iterator(list.end()));
15 }
16 
assertAllClose(const tensor_list & a,const tensor_list & b)17 void assertAllClose(const tensor_list& a, const tensor_list& b) {
18   ASSERT_EQ(a.size(), b.size());
19   for (size_t i = 0; i < a.size(); ++i) {
20     ASSERT_TRUE(a[i].is_same_size(b[i]));
21     ASSERT_TRUE(a[i].allclose(b[i]));
22   }
23 }
24 
run(InterpreterState & interp,const std::vector<at::Tensor> & inputs)25 std::vector<at::Tensor> run(
26     InterpreterState& interp,
27     const std::vector<at::Tensor>& inputs) {
28   std::vector<IValue> stack(inputs.begin(), inputs.end());
29   interp.run(stack);
30   return fmap(stack, [](const IValue& i) { return i.toTensor(); });
31 }
32 
unpackReturnTuple(Stack & stack)33 static void unpackReturnTuple(Stack& stack) {
34   auto tuple = pop(stack).toTuple();
35   stack.insert(stack.end(), tuple->elements().begin(), tuple->elements().end());
36 }
37 
// Executes both halves of a Gradient specification through the interpreter:
// the forward graph `f` on `tensors_in`, then the backward graph `df` fed
// with `tensor_grads_in` plus the captured inputs/outputs listed in the spec.
// Returns {forward outputs truncated to f_real_outputs, gradient outputs}.
std::pair<tensor_list, tensor_list> runGradient(
    Gradient& grad_spec,
    tensor_list& tensors_in,
    tensor_list& tensor_grads_in) {
  // Converts an interpreter stack of IValues back into a list of tensors.
  static const auto as_tensorlist = [](const Stack& stack) {
    return fmap(stack, [](const IValue& i) { return i.toTensor(); });
  };
  // Strip Undefined-ness annotations from df before running it
  // (see passes/clear_undefinedness.h).
  ClearUndefinedness(grad_spec.df);
  Code f_code{grad_spec.f, ""}, df_code{grad_spec.df, ""};
  InterpreterState f_interpreter{f_code}, df_interpreter{df_code};

  // Forward pass: f consumes the inputs and leaves its outputs on the stack.
  auto f_stack = fmap<IValue>(tensors_in);
  f_interpreter.run(f_stack);

  // Backward inputs, in order: incoming gradients, then the forward inputs
  // and forward outputs captured by the Gradient spec's offset lists.
  Stack df_stack;
  df_stack.insert(
      df_stack.end(), tensor_grads_in.begin(), tensor_grads_in.end());
  for (auto offset : grad_spec.df_input_captured_inputs)
    df_stack.push_back(tensors_in[offset]);
  for (auto offset : grad_spec.df_input_captured_outputs)
    df_stack.push_back(f_stack[offset]);
  df_interpreter.run(df_stack);
  // df returns its results as a single tuple; flatten it back onto the stack.
  unpackReturnTuple(df_stack);
  // Outputs of f needs to be sliced: positions past f_real_outputs are
  // intermediates captured only for df's benefit.
  f_stack.erase(f_stack.begin() + grad_spec.f_real_outputs, f_stack.end());
  return std::make_pair(as_tensorlist(f_stack), as_tensorlist(df_stack));
}
65 
build_lstm()66 std::shared_ptr<Graph> build_lstm() {
67   const auto graph_string = R"IR(
68     graph(%0 : Tensor,
69           %1 : Tensor,
70           %2 : Tensor,
71           %3 : Tensor,
72           %4 : Tensor):
73       %5 : Tensor = aten::mm(%0, %3)
74       %6 : Tensor = aten::mm(%1, %4)
75       %7 : int = prim::Constant[value=1]()
76       %8 : Tensor = aten::add(%5, %6, %7)
77       %9 : Tensor, %10 : Tensor, %11 : Tensor, %12 : Tensor = prim::ConstantChunk[chunks=4, dim=1](%8)
78       %13 : Tensor = aten::sigmoid(%9)
79       %14 : Tensor = aten::sigmoid(%12)
80       %15 : Tensor = aten::tanh(%11)
81       %16 : Tensor = aten::sigmoid(%10)
82       %17 : Tensor = aten::mul(%16, %2)
83       %18 : Tensor = aten::mul(%13, %15)
84       %19 : int = prim::Constant[value=1]()
85       %20 : Tensor = aten::add(%17, %18, %19)
86       %21 : Tensor = aten::tanh(%20)
87       %22 : Tensor = aten::mul(%14, %21)
88       return (%22, %20))IR";
89   auto g = std::make_shared<Graph>();
90   torch::jit::parseIR(graph_string, g.get());
91   g->lint();
92 
93   return g;
94 }
95 
build_mobile_export_analysis_graph()96 std::shared_ptr<Graph> build_mobile_export_analysis_graph() {
97   // We use following two schemas for this graph:
98   //   1. slice.Tensor(Tensor(a) self, int dim=0, int? start=None,
99   //                   int? end=None, int step=1) -> Tensor(a)
100   //   2. slice.str(str string, int? start=None, int? end=None,
101   //                  int step=1) -> str
102   // %3 and %4 use slice.Tensor while %5 use slice.str.
103   // Since we can see %3 and %4 have the same last argument that is never used
104   // (same as default value of schema), we know we can ignore that last arg. For
105   // %5, we see that last three args are same as schema default, hence
106   // unnecessary.
107 
108   const auto graph_string = R"IR(
109     graph(%0 : Tensor):
110       %1 : int = prim::Constant[value=1]()
111       %2 : int = prim::Constant[value=2]()
112       %20 : int = prim::Constant[value=0]()
113       %21 : int = prim::Constant[value=9223372036854775807]()
114       %22 : str = prim::Constant[value="value"]()
115       %3 : Tensor  = aten::slice(%0, %1, %20, %2, %1)
116       %4 : Tensor = aten::slice(%0, %2, %20, %21, %1)
117       %5 : str = aten::slice(%22, %20, %21, %2)
118       return (%3, %4, %5))IR";
119 
120   auto g = std::make_shared<Graph>();
121   torch::jit::parseIR(graph_string, g.get());
122   g->lint();
123   return g;
124 }
125 
build_mobile_export_with_out()126 std::shared_ptr<Graph> build_mobile_export_with_out() {
127   const auto graph_string = R"IR(
128     graph(%x.1 : Tensor,
129           %y.1 : Tensor):
130       %8 : NoneType = prim::Constant()
131       %6 : int = prim::Constant[value=1]()
132       %7 : Tensor = aten::add(%x.1, %y.1, %6, %y.1)
133       return (%8))IR";
134 
135   auto g = std::make_shared<Graph>();
136   torch::jit::parseIR(graph_string, g.get());
137   g->lint();
138   return g;
139 }
140 
build_mobile_export_analysis_graph_nested()141 std::shared_ptr<Graph> build_mobile_export_analysis_graph_nested() {
142   // this is pretty much same test as build_mobile_export_analysis_graph(),
143   // but some aten::slice operators are hidden under block statement to check
144   // if we are correctly recursing all the nodes in graph.
145   const auto graph_string = R"IR(
146     graph(%0 : Tensor):
147       %1 : int = prim::Constant[value=1]()
148       %2 : int = prim::Constant[value=2]()
149       %20 : int = prim::Constant[value=0]()
150       %21 : int = prim::Constant[value=9223372036854775807]()
151       %22 : str = prim::Constant[value="value"]()
152       %3 : Tensor  = aten::slice(%0, %1, %20, %2, %1)
153       %23 : bool = aten::Bool(%3)
154       %c : Tensor = prim::If(%23)
155         block0():
156           %4 : Tensor = aten::slice(%0, %2, %20, %21, %1)
157           %5 : str = aten::slice(%22, %20, %21, %2)
158           %c.1 : Tensor = aten::slice(%0, %1, %20, %2, %1)
159           -> (%c.1)
160         block1():
161           -> (%3)
162       return (%3, %3))IR";
163   auto g = std::make_shared<Graph>();
164   torch::jit::parseIR(graph_string, g.get());
165   g->lint();
166   return g;
167 }
168 
build_mobile_export_analysis_graph_with_vararg()169 std::shared_ptr<Graph> build_mobile_export_analysis_graph_with_vararg() {
170   const auto graph_string = R"IR(
171     graph(%0 : Tensor):
172       %1 : int = prim::Constant[value=1]()
173       %2 : int = prim::Constant[value=2]()
174       %3 : int = prim::Constant[value=3]()
175       %4 : int[]  = prim::tolist(%1, %2)
176       %5 : int[] = prim::tolist(%1, %2, %3)
177       return (%4, %5))IR";
178 
179   auto g = std::make_shared<Graph>();
180   torch::jit::parseIR(graph_string, g.get());
181   g->lint();
182   return g;
183 }
184 
build_mobile_export_analysis_graph_non_const()185 std::shared_ptr<Graph> build_mobile_export_analysis_graph_non_const() {
186   const auto graph_string = R"IR(
187       graph(%input.1 : Tensor):
188         %7 : int = prim::Constant[value=1]() # <string>:3:58
189         %9 : int = prim::Constant[value=0]() # <string>:3:66
190         %8 : int[] = prim::ListConstruct(%7, %7)
191         %10 : int[] = prim::ListConstruct(%9, %9)
192         %11 : int[] = prim::ListConstruct(%7, %7)
193         %12 : Tensor = aten::conv2d(%input.1, %input.1, %input.1, %8, %10, %11, %7)
194         return (%12))IR";
195 
196   auto g = std::make_shared<Graph>();
197   torch::jit::parseIR(graph_string, g.get());
198   g->lint();
199   return g;
200 }
201 
// Identity pass-through for a tensor. Paired with t_def (which transposes);
// presumably used so test expressions can mark which operand is the
// "defined"/transposed one — confirm against callers.
at::Tensor t_use(at::Tensor x) {
  return x;
}
// Returns the transpose of `x`; counterpart of the identity helper t_use.
at::Tensor t_def(at::Tensor x) {
  return x.t();
}
208 
checkRtol(const at::Tensor & diff,const std::vector<at::Tensor> inputs)209 bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs) {
210   double maxValue = 0.0;
211   for (auto& tensor : inputs) {
212     maxValue = fmax(tensor.abs().max().item<float>(), maxValue);
213   }
214   return diff.abs().max().item<float>() < 2e-6 * maxValue;
215 }
216 
// Returns true if a and b are elementwise equal within the relative
// tolerance enforced by checkRtol (scaled by the magnitudes of a and b).
bool almostEqual(const at::Tensor& a, const at::Tensor& b) {
  return checkRtol(a - b, {a, b});
}
220 
// Returns true only when a and b are identical: the maximum absolute
// elementwise difference is exactly zero (no tolerance).
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) {
  return (a - b).abs().max().item<float>() == 0.f;
}
224 
exactlyEqual(const std::vector<at::Tensor> & a,const std::vector<at::Tensor> & b)225 bool exactlyEqual(
226     const std::vector<at::Tensor>& a,
227     const std::vector<at::Tensor>& b) {
228   if (a.size() != b.size()) {
229     return false;
230   }
231   for (size_t i = 0; i < a.size(); ++i) {
232     if (!exactlyEqual(a[i], b[i])) {
233       return false;
234     }
235   }
236   return true;
237 }
238 
lstm(at::Tensor input,at::Tensor hx,at::Tensor cx,at::Tensor w_ih,at::Tensor w_hh)239 std::pair<at::Tensor, at::Tensor> lstm(
240     at::Tensor input,
241     at::Tensor hx,
242     at::Tensor cx,
243     at::Tensor w_ih,
244     at::Tensor w_hh) {
245   auto gates = input.mm(t_use(w_ih)) + hx.mm(t_use(w_hh));
246 
247   auto chunked_gates = gates.chunk(4, 1);
248   auto ingate = chunked_gates[0];
249   auto forgetgate = chunked_gates[1];
250   auto cellgate = chunked_gates[2];
251   auto outgate = chunked_gates[3];
252 
253   ingate = ingate.sigmoid();
254   outgate = outgate.sigmoid();
255   cellgate = cellgate.tanh();
256   forgetgate = forgetgate.sigmoid();
257 
258   auto cy = (forgetgate * cx) + (ingate * cellgate);
259   auto hy = outgate * cy.tanh();
260 
261   return {hy, cy};
262 }
263 
// Alias-analysis kind used when registering the test operator below:
// derive aliasing information from the operator's schema.
inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
  return c10::AliasAnalysisKind::FROM_SCHEMA;
}
267 
namespace {
// File-local operator registration; runs at static-initialization time
// when this translation unit is loaded.
RegisterOperators reg({
    // This operator is intended to be used in JIT analysis and transformation
    // pass unit tests in which Values with type Tensor are often required. It
    // should not be used in situations in which the graph is actually executed
    // because it always produces empty Tensors.
    Operator(
        "prim::MakeTestTensor() -> Tensor",
        // Pushes a default-constructed (undefined) tensor onto the stack.
        [](Stack& stack) { push(stack, at::Tensor()); },
        aliasAnalysisFromSchema()),
});
} // namespace
280 
runGraph(std::shared_ptr<Graph> graph,const std::vector<at::Tensor> & inputs)281 std::vector<at::Tensor> runGraph(
282     std::shared_ptr<Graph> graph,
283     const std::vector<at::Tensor>& inputs) {
284   std::vector<IValue> stack = fmap<IValue>(inputs);
285   Code code(graph, "test");
286   InterpreterState(code).run(stack);
287   TORCH_INTERNAL_ASSERT(!stack.empty());
288   // Graph outputs that are handled below:
289   //   * A list of Tensors.
290   //   * 1 Tensor.
291   if (stack.front().isTensorList()) {
292     return stack.front().toTensorVector();
293   }
294   TORCH_INTERNAL_ASSERT(stack.front().isTensor());
295   return {stack.front().toTensor()};
296 }
297 
298 } // namespace jit
299 } // namespace torch
300