// Copyright (c) Meta Platforms, Inc. and affiliates.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "test_utils.h"

#include <ATen/core/ivalue.h>
#include <gtest/gtest.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/memory_planner.h>
#include <torch/csrc/jit/runtime/static/passes.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/allclose.h>
#endif

#include <memory>
#include <unordered_map>

using namespace torch::jit;
using namespace torch;
using c10::IValue;

namespace torch {
namespace jit {
namespace test {

namespace {

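// Thin wrapper that makes a GraphExecutor callable like a function object:
// it runs the graph on a stack built from the supplied arguments, returns a
// lone output directly, and packs multiple outputs into a tuple.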
class GraphExecutorWrapper {
 public:
  GraphExecutorWrapper() = default;

  explicit GraphExecutorWrapper(const std::shared_ptr<Graph>& graph)
      : graph_exec_(graph, "") {}

  c10::IValue operator()(const std::vector<c10::IValue>& args) {
    Stack stack(args);
    graph_exec_.run(stack);

    if (stack.size() == 1) {
      return stack[0];
    }
    return c10::ivalue::Tuple::create(stack);
  }

 private:
  GraphExecutor graph_exec_;
};

// Test scripts passed to testStaticRuntime can either be IR or JIT.
// The logic for running the script and producing a corresponding StaticModule
// is a bit different for each case. This logic is encapsulated within concrete
// implementations of this class, and testStaticRuntime is only aware of this
// interface.
class StaticRuntimeTestContext {
 public:
  virtual ~StaticRuntimeTestContext() = default;

  virtual IValue getExpected(const std::vector<IValue>& args) = 0;
  virtual StaticModule makeStaticModule(
      const StaticModuleOptions& opt) const = 0;
};
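
// For illustration only: a rough sketch of the two kinds of sources this
// machinery accepts (the scripts below are hypothetical, not lifted from any
// particular test):
//
//   // TorchScript method source -> ModuleStaticRuntimeTestContext
//   const auto src_jit = R"JIT(
//     def forward(self, a, b):
//         return (a + b).clone()
//   )JIT";
//
//   // Graph IR source -> GraphStaticRuntimeContext (via parseIR)
//   const auto src_ir = R"IR(
//     graph(%a: Tensor, %b: Tensor):
//       %c: Tensor = aten::mul(%a, %b)
//       return (%c)
//   )IR";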

class ModuleStaticRuntimeTestContext : public StaticRuntimeTestContext {
 public:
  explicit ModuleStaticRuntimeTestContext(const std::string& source_jit)
      : module_("module") {
    module_.define(source_jit);
  }

  IValue getExpected(const std::vector<IValue>& args) override {
    return module_.forward(args);
  }

  StaticModule makeStaticModule(const StaticModuleOptions& opt) const override {
    return torch::jit::StaticModule(
        module_, /* is_frozen */ false, opt, /* sample_inputs */ {});
  }

 private:
  Module module_;
};

class GraphStaticRuntimeContext : public StaticRuntimeTestContext {
 public:
  explicit GraphStaticRuntimeContext(const std::string& source_ir) {
    graph_ = std::make_shared<Graph>();
    std::unordered_map<std::string, Value*> vmap;
    parseIR(source_ir, graph_.get(), vmap);

    graph_exec_ = GraphExecutorWrapper(graph_);
  }

  IValue getExpected(const std::vector<IValue>& args) override {
    return graph_exec_(args);
  }

  StaticModule makeStaticModule(const StaticModuleOptions& opt) const override {
    return StaticModule(graph_, opt, /* sample_inputs */ {});
  }

 private:
  std::shared_ptr<Graph> graph_;
  GraphExecutorWrapper graph_exec_;
};

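// Scripts are first handed to Module::define as TorchScript; if that throws,
// the source is assumed to be graph IR and parsed instead.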
std::unique_ptr<StaticRuntimeTestContext> makeTestContext(
    const std::string& source) {
  try {
    return std::make_unique<ModuleStaticRuntimeTestContext>(source);
  } catch (const std::runtime_error&) {
    // Could not parse as TorchScript, assume it's IR
    return std::make_unique<GraphStaticRuntimeContext>(source);
  }
}

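// Elementwise comparison of two lists of tensors. `l` holds the expected
// values and `r` the observed ones; tensors are compared with allclose or
// exact equality depending on `use_allclose`, and an undefined expected
// tensor must be matched by an undefined output.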
void compareTensorLists(
    const std::vector<IValue>& l, /* expects */
    const std::vector<IValue>& r, /* values */
    const bool use_allclose,
    const bool use_equalnan) {
  // Abort early on a size mismatch so the loop below cannot index past the
  // end of the shorter list.
  ASSERT_EQ(l.size(), r.size());
  for (auto i : c10::irange(l.size())) {
    ASSERT_TRUE(l[i].isTensor());
    ASSERT_TRUE(r[i].isTensor());
    VLOG(2) << "expect " << i << ": \n" << l[i] << std::endl;
    VLOG(2) << "output " << i << ": \n" << r[i] << std::endl;
    if (!l[i].toTensor().defined()) {
      EXPECT_TRUE(!r[i].toTensor().defined());
    } else {
      if (use_allclose) {
        EXPECT_TRUE(at::allclose(
            l[i].toTensor(),
            r[i].toTensor(),
            /*rtol*/ 1e-05,
            /*atol*/ 1e-08,
            use_equalnan));
      } else {
        EXPECT_TRUE(l[i].toTensor().equal(r[i].toTensor()));
      }
    }
  }
}

} // namespace

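// Recursively compare two IValues: tensors (optionally with allclose),
// tuples, lists, and dicts are compared element by element; anything else
// falls back to IValue::operator==.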
void compareResults(
    const IValue& expect,
    const IValue& actual,
    const bool use_allclose,
    const bool use_equalnan) {
  if (expect.isTensor()) {
    VLOG(2) << "expect " << expect.toTensor() << std::endl;
    VLOG(2) << "output " << actual.toTensor() << std::endl;
    EXPECT_TRUE(actual.isTensor());
    if (use_allclose) {
      EXPECT_TRUE(at::allclose(
          expect.toTensor(),
          actual.toTensor(),
          /*rtol*/ 1e-05,
          /*atol*/ 1e-08,
          use_equalnan));
    } else {
      EXPECT_TRUE(expect.toTensor().equal(actual.toTensor()));
    }
    return;
  } else if (expect.isTuple()) {
    EXPECT_TRUE(actual.isTuple());
    auto lhs = expect.toTupleRef().elements();
    auto rhs = actual.toTupleRef().elements();
    ASSERT_TRUE(lhs.size() == rhs.size());
    for (size_t i = 0; i < lhs.size(); i++) {
      // Forward the comparison flags so nested elements honor
      // allclose/equalnan as well.
      compareResults(lhs[i], rhs[i], use_allclose, use_equalnan);
    }
  } else if (expect.isList()) {
    EXPECT_TRUE(actual.isList());
    auto lhs = expect.toList();
    auto rhs = actual.toList();
    ASSERT_TRUE(lhs.size() == rhs.size());
    for (size_t i = 0; i < lhs.size(); i++) {
      compareResults(lhs[i], rhs[i], use_allclose, use_equalnan);
    }
  } else if (expect.isGenericDict()) {
    EXPECT_TRUE(actual.isGenericDict());
    auto lhs = expect.toGenericDict();
    auto rhs = actual.toGenericDict();
    EXPECT_TRUE(lhs.size() == rhs.size());
    for (auto& lh : lhs) {
      auto f = rhs.find(lh.key());
      ASSERT_FALSE(f == rhs.end());
      compareResults(lh.value(), f->value(), use_allclose, use_equalnan);
    }
  } else {
    // fall back to the default comparison impl in IValue
    EXPECT_TRUE(expect == actual);
  }
}

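// Extract the single tensor from an IValue that is either a tensor, a
// one-element tensor list, or a one-element tuple; throws otherwise.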
at::Tensor getTensor(const at::IValue& ival) {
  if (ival.isTensor()) {
    return ival.toTensor();
  } else if (ival.isTensorList()) {
    auto tensor_vec = ival.toTensorVector();
    TORCH_CHECK(tensor_vec.size() == 1);
    return tensor_vec[0];
  } else if (ival.isTuple()) {
    auto tuple = ival.toTuple();
    auto ivalue_vec = tuple->elements();
    TORCH_CHECK(ivalue_vec.size() == 1);
    return ivalue_vec[0].toTensor();
  } else {
    CAFFE_THROW("Unknown input IValue");
  }
}

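// Look up a node by its qualified kind string (e.g. "aten::add") in a
// StaticModule or a Graph; the has* variants below simply check the lookup
// result against nullptr.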
Node* getNodeWithKind(const StaticModule& smodule, const std::string& kind) {
  return smodule.findNodeWithKindForTesting(kind);
}

Node* getNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind) {
  const auto symbol = c10::Symbol::fromQualString(kind);
  DepthFirstGraphNodeIterator it(graph);
  for (auto* node = it.next(); node != nullptr; node = it.next()) {
    if (node->kind() == symbol) {
      return node;
    }
  }
  return nullptr;
}

bool hasNodeWithKind(const StaticModule& smodule, const std::string& kind) {
  return getNodeWithKind(smodule, kind) != nullptr;
}

bool hasNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind) {
  return getNodeWithKind(graph, kind) != nullptr;
}

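// Build a Graph directly from TorchScript or IR text (this helper and
// getGraphFromIR below), for tests that need to inspect the graph rather
// than run it.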
std::shared_ptr<Graph> getGraphFromScript(const std::string& jit_script) {
  script::Module module("module");
  module.define(jit_script);

  Method method = module.get_method("forward");
  return method.graph();
}

std::shared_ptr<Graph> getGraphFromIR(const std::string& ir) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);
  return graph;
}

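// Run `args` through both a reference GraphExecutor and the given
// StaticRuntime instance, check the runtime for memory leaks, and compare
// the two results.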
void compareResultsWithJIT(
    StaticRuntime& runtime,
    const std::shared_ptr<Graph>& graph,
    const std::vector<c10::IValue>& args,
    const bool use_allclose,
    const bool use_equalnan) {
  GraphExecutorWrapper graph_exec(graph);
  auto expected = graph_exec(args);
  auto actual = runtime(args, {});
  runtime.check_for_memory_leak();
  compareResults(expected, actual, use_allclose, use_equalnan);
}

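// End-to-end check of Static Runtime against reference JIT execution.
// The source is run once through the JIT (via the test context) to produce
// expected outputs, then a StaticModule is exercised under every combination
// of enable_out_variant, manage_output_tensors, and enable_tensorexpr_fusion
// (skipping managed output tensors without out variants), with three runs per
// configuration: a profile run on args, a run on args2 when provided to force
// resizing, and a final run on args. Results are compared against the JIT
// reference and memory leaks are checked along the way, and the input tensors
// are verified to be unmodified at the end.
//
// A minimal call sketch (hypothetical script and inputs; the trailing
// parameters take the defaults declared in test_utils.h):
//
//   const auto script = R"JIT(
//     def forward(self, a, b):
//         return (a + b).clone()
//   )JIT";
//   testStaticRuntime(
//       script,
//       {at::randn({2, 2}), at::randn({2, 2})},
//       {at::randn({8, 8}), at::randn({8, 8})});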
void testStaticRuntime(
    const std::string& source,
    const std::vector<IValue>& args,
    const std::vector<IValue>& args2,
    const bool use_allclose,
    const bool use_equalnan,
    const bool check_resize) {
  auto test_context = makeTestContext(source);

  std::vector<IValue> args_tensors, args_copy;
  for (const auto& ival : args) {
    if (ival.isTensor()) {
      args_tensors.emplace_back(ival);
      const at::Tensor& t = ival.toTensor();
      args_copy.emplace_back(t.clone());
    }
  }

  auto expect = test_context->getExpected(args);

  for (bool enable_out_variant : {true, false}) {
    for (bool manage_output_tensors : {true, false}) {
      for (bool enable_tensorexpr_fusion : {true, false}) {
        if (!enable_out_variant && manage_output_tensors) {
          continue;
        }
        // run static runtime three times
        // 1st run: collect allocation profiles (args)
        // 2nd run: exercise memory planner and resizing with args2
        // 3rd run: run with args again
        StaticModuleOptions opts;
        opts.enable_out_variant = enable_out_variant;
        opts.optimize_memory = enable_out_variant;
        opts.manage_output_tensors = manage_output_tensors;
        opts.enable_tensorexpr_fusion = enable_tensorexpr_fusion;

        auto smodule = test_context->makeStaticModule(opts);
        StaticRuntime runtime(smodule);
        auto actual = runtime(args, {});
        if (actual.isTensor()) {
          EXPECT_GE(smodule.num_nodes(), 2)
              << "If we only have one node, the output of the op we are testing is "
              << "not being managed by the memory planner! A failure here "
              << "can typically be fixed by clone()ing the output of the test script.";
        }
        runtime.check_for_memory_leak();
        // first run
        VLOG(2) << "enable_out_variant: " << enable_out_variant;
        VLOG(2) << "manage_output_tensors: " << manage_output_tensors;
        VLOG(2) << "enable_tensorexpr_fusion: " << enable_tensorexpr_fusion;
        VLOG(2) << "args: " << args;
        VLOG(2) << "args2: " << args2;
        VLOG(2) << "expect: " << expect;
        VLOG(2) << "actual: " << actual;
        compareResults(expect, actual, use_allclose, use_equalnan);
        VLOG(2) << "first run comparison done";
        if (manage_output_tensors) {
          actual = IValue();
          runtime.deallocateOutputTensors();
          runtime.checkOutputTensorMemoryLeaks();
        }

        if (!args2.empty()) {
          auto* memory_planner = runtime.get_memory_planner();
          size_t managed_bytes =
              memory_planner ? memory_planner->total_managed() : 0;

          // Run static runtime again with inputs of a different shape.
          expect = test_context->getExpected(args2);
          actual = runtime(args2, {});
          runtime.check_for_memory_leak();
          VLOG(2) << "comparing with args2";
          compareResults(expect, actual, use_allclose, use_equalnan);
          VLOG(2) << "second run comparison done";
          if (manage_output_tensors) {
            actual = IValue();
            runtime.deallocateOutputTensors();
            runtime.checkOutputTensorMemoryLeaks();
          }

          size_t new_managed_bytes =
              memory_planner ? memory_planner->total_managed() : 0;
          // `new_managed_bytes` is unsigned, so only the check_resize flag
          // gates this expectation.
          if (check_resize) {
            EXPECT_GE(new_managed_bytes, managed_bytes);
          }

          // Run static runtime again with an input of the shape observed during
          // the profile run.
          expect = test_context->getExpected(args);
          actual = runtime(args, {});
          runtime.check_for_memory_leak();
          // third run
          VLOG(2) << "comparing third run";
          compareResults(expect, actual, use_allclose, use_equalnan);
          VLOG(2) << "third run comparison done";
          if (manage_output_tensors) {
            actual = IValue();
            runtime.deallocateOutputTensors();
            runtime.checkOutputTensorMemoryLeaks();
          }
        } else {
          // run static runtime again to exercise the memory planner
          // and allocate managed tensors.
          actual = runtime(args, {});
          runtime.check_for_memory_leak();
          VLOG(2) << "comparing second run with same args";
          compareResults(expect, actual, use_allclose, use_equalnan);
          VLOG(2) << "second run comparison done";
          if (manage_output_tensors) {
            actual = IValue();
            runtime.deallocateOutputTensors();
            runtime.checkOutputTensorMemoryLeaks();
          }
          // third run to use the allocated managed tensors.
          actual = runtime(args, {});
          runtime.check_for_memory_leak();
          if (manage_output_tensors) {
            actual = IValue();
            runtime.deallocateOutputTensors();
            runtime.checkOutputTensorMemoryLeaks();
          }
        }
      }
    }
  }

  // make sure inputs were not modified
  VLOG(2) << "Printing out input tensors";
  compareTensorLists(args_tensors, args_copy, use_allclose, use_equalnan);
}

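// True if the StaticModule contains a processed node with the given kind
// name; equivalent to hasNodeWithKind for a StaticModule.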
bool hasProcessedNodeWithName(
    torch::jit::StaticModule& smodule,
    const char* name) {
  return smodule.findNodeWithKindForTesting(name) != nullptr;
}

} // namespace test
} // namespace jit
} // namespace torch