1 // Copyright (c) Meta Platforms, Inc. and affiliates.
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include "test_utils.h"
7
8 #include <ATen/core/ivalue.h>
9 #include <gtest/gtest.h>
10 #include <torch/csrc/jit/ir/irparser.h>
11 #include <torch/csrc/jit/runtime/graph_executor.h>
12 #include <torch/csrc/jit/runtime/graph_iterator.h>
13 #include <torch/csrc/jit/runtime/static/impl.h>
14 #include <torch/csrc/jit/runtime/static/memory_planner.h>
15 #include <torch/csrc/jit/runtime/static/passes.h>
16
17 #ifndef AT_PER_OPERATOR_HEADERS
18 #include <ATen/Functions.h>
19 #else
20 #include <ATen/ops/allclose.h>
21 #endif
22
23 #include <memory>
24 #include <unordered_map>
25
26 using namespace torch::jit;
27 using namespace torch;
28 using c10::IValue;
29
30 namespace torch {
31 namespace jit {
32 namespace test {
33
34 namespace {
35
36 class GraphExecutorWrapper {
37 public:
38 GraphExecutorWrapper() = default;
39
GraphExecutorWrapper(const std::shared_ptr<Graph> & graph)40 explicit GraphExecutorWrapper(const std::shared_ptr<Graph>& graph)
41 : graph_exec_(graph, "") {}
42
operator ()(const std::vector<c10::IValue> & args)43 c10::IValue operator()(const std::vector<c10::IValue>& args) {
44 Stack stack(args);
45 graph_exec_.run(stack);
46
47 if (stack.size() == 1) {
48 return stack[0];
49 }
50 return c10::ivalue::Tuple::create(stack);
51 }
52
53 private:
54 GraphExecutor graph_exec_;
55 };
56
57 // Test scripts passed to testStaticRuntime can either be IR or JIT.
58 // The logic for running the script and producing a corresponding StaticModule
59 // is a bit different for each case. This logic is encapsulated within concrete
60 // implementations of this class, and testStaticRuntime is only aware of this
61 // interface.
class StaticRuntimeTestContext {
 public:
  virtual ~StaticRuntimeTestContext() = default;

  // Runs the script with the reference executor (JIT module or
  // GraphExecutor, depending on the concrete context) and returns the
  // result that StaticRuntime's output will be compared against.
  virtual IValue getExpected(const std::vector<IValue>& args) = 0;
  // Builds a StaticModule for the script under the given options.
  virtual StaticModule makeStaticModule(
      const StaticModuleOptions& opt) const = 0;
};
70
71 class ModuleStaticRuntimeTestContext : public StaticRuntimeTestContext {
72 public:
ModuleStaticRuntimeTestContext(const std::string & source_jit)73 explicit ModuleStaticRuntimeTestContext(const std::string& source_jit)
74 : module_("module") {
75 module_.define(source_jit);
76 }
77
getExpected(const std::vector<IValue> & args)78 IValue getExpected(const std::vector<IValue>& args) override {
79 return module_.forward(args);
80 }
81
makeStaticModule(const StaticModuleOptions & opt) const82 StaticModule makeStaticModule(const StaticModuleOptions& opt) const override {
83 return torch::jit::StaticModule(
84 module_, /* is_frozen */ false, opt, /* sample_inputs */ {});
85 }
86
87 private:
88 Module module_;
89 };
90
91 class GraphStaticRuntimeContext : public StaticRuntimeTestContext {
92 public:
GraphStaticRuntimeContext(const std::string & source_ir)93 explicit GraphStaticRuntimeContext(const std::string& source_ir) {
94 graph_ = std::make_shared<Graph>();
95 std::unordered_map<std::string, Value*> vmap;
96 parseIR(source_ir, graph_.get(), vmap);
97
98 graph_exec_ = GraphExecutorWrapper(graph_);
99 }
100
getExpected(const std::vector<IValue> & args)101 IValue getExpected(const std::vector<IValue>& args) override {
102 return graph_exec_(args);
103 }
104
makeStaticModule(const StaticModuleOptions & opt) const105 StaticModule makeStaticModule(const StaticModuleOptions& opt) const override {
106 return StaticModule(graph_, opt, /* sample_inputs */ {});
107 }
108
109 private:
110 std::shared_ptr<Graph> graph_;
111 GraphExecutorWrapper graph_exec_;
112 };
113
makeTestContext(const std::string & source)114 std::unique_ptr<StaticRuntimeTestContext> makeTestContext(
115 const std::string& source) {
116 try {
117 return std::make_unique<ModuleStaticRuntimeTestContext>(source);
118 // Could not parse as TorchScript, assume it's IR
119 } catch (const std::runtime_error&) {
120 return std::make_unique<GraphStaticRuntimeContext>(source);
121 }
122 }
123
compareTensorLists(const std::vector<IValue> & l,const std::vector<IValue> & r,const bool use_allclose,const bool use_equalnan)124 void compareTensorLists(
125 const std::vector<IValue>& l, /* expects */
126 const std::vector<IValue>& r, /* values */
127 const bool use_allclose,
128 const bool use_equalnan) {
129 EXPECT_TRUE(l.size() == r.size());
130 for (auto i : c10::irange(l.size())) {
131 ASSERT_TRUE(l[i].isTensor());
132 ASSERT_TRUE(r[i].isTensor());
133 VLOG(2) << "expect " << i << ": \n" << l[i] << std::endl;
134 VLOG(2) << "output " << i << ": \n" << r[i] << std::endl;
135 if (!l[i].toTensor().defined()) {
136 EXPECT_TRUE(!r[i].toTensor().defined());
137 } else {
138 if (use_allclose) {
139 EXPECT_TRUE(at::allclose(
140 l[i].toTensor(),
141 r[i].toTensor(),
142 /*rtol*/ 1e-05,
143 /*atol*/ 1e-08,
144 use_equalnan));
145 } else {
146 EXPECT_TRUE(l[i].toTensor().equal(r[i].toTensor()));
147 }
148 }
149 }
150 }
151
152 } // namespace
153
compareResults(const IValue & expect,const IValue & actual,const bool use_allclose,const bool use_equalnan)154 void compareResults(
155 const IValue& expect,
156 const IValue& actual,
157 const bool use_allclose,
158 const bool use_equalnan) {
159 if (expect.isTensor()) {
160 VLOG(2) << "expect " << expect.toTensor() << std::endl;
161 VLOG(2) << "output " << actual.toTensor() << std::endl;
162 EXPECT_TRUE(actual.isTensor());
163 if (use_allclose) {
164 EXPECT_TRUE(at::allclose(
165 expect.toTensor(),
166 actual.toTensor(),
167 /*rtol*/ 1e-05,
168 /*atol*/ 1e-08,
169 use_equalnan));
170 } else {
171 EXPECT_TRUE(expect.toTensor().equal(actual.toTensor()));
172 }
173 return;
174 } else if (expect.isTuple()) {
175 EXPECT_TRUE(actual.isTuple());
176 auto lhs = expect.toTupleRef().elements();
177 auto rhs = actual.toTupleRef().elements();
178 ASSERT_TRUE(lhs.size() == rhs.size());
179 for (size_t i = 0; i < lhs.size(); i++) {
180 compareResults(lhs[i], rhs[i]);
181 }
182 } else if (expect.isList()) {
183 EXPECT_TRUE(actual.isList());
184 auto lhs = expect.toList();
185 auto rhs = actual.toList();
186 ASSERT_TRUE(lhs.size() == rhs.size());
187 for (size_t i = 0; i < lhs.size(); i++) {
188 compareResults(lhs[i], rhs[i]);
189 }
190 } else if (expect.isGenericDict()) {
191 EXPECT_TRUE(actual.isGenericDict());
192 auto lhs = expect.toGenericDict();
193 auto rhs = actual.toGenericDict();
194 EXPECT_TRUE(lhs.size() == rhs.size());
195 for (auto& lh : lhs) {
196 auto f = rhs.find(lh.key());
197 ASSERT_FALSE(f == rhs.end());
198 compareResults(lh.value(), f->value());
199 }
200 } else {
201 // fall back to the default comparison impl in IValue
202 EXPECT_TRUE(expect == actual);
203 }
204 }
205
// Extracts the single tensor held by `ival`, which may be a bare
// tensor, a one-element tensor list, or a one-element tuple. Throws
// for any other IValue kind.
at::Tensor getTensor(const at::IValue& ival) {
  if (ival.isTensor()) {
    return ival.toTensor();
  }
  if (ival.isTensorList()) {
    const auto tensors = ival.toTensorVector();
    TORCH_CHECK(tensors.size() == 1);
    return tensors[0];
  }
  if (ival.isTuple()) {
    // Keep the tuple alive while reading its elements.
    const auto tuple = ival.toTuple();
    const auto elements = tuple->elements();
    TORCH_CHECK(elements.size() == 1);
    return elements[0].toTensor();
  }
  CAFFE_THROW("Unknown input IValue");
}
222
// Finds a node by qualified kind string (e.g. "aten::add") in the
// StaticModule's graph; returns nullptr if no such node exists.
Node* getNodeWithKind(const StaticModule& smodule, const std::string& kind) {
  return smodule.findNodeWithKindForTesting(kind);
}
226
getNodeWithKind(std::shared_ptr<Graph> & graph,const std::string & kind)227 Node* getNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind) {
228 const auto symbol = c10::Symbol::fromQualString(kind);
229 DepthFirstGraphNodeIterator it(graph);
230 for (auto* node = it.next(); node != nullptr; node = it.next()) {
231 if (node->kind() == symbol) {
232 return node;
233 }
234 }
235 return nullptr;
236 }
237
// True iff the StaticModule's graph contains a node of the given kind.
bool hasNodeWithKind(const StaticModule& smodule, const std::string& kind) {
  return getNodeWithKind(smodule, kind) != nullptr;
}
241
// True iff `graph` contains a node of the given kind.
bool hasNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind) {
  return getNodeWithKind(graph, kind) != nullptr;
}
245
getGraphFromScript(const std::string & jit_script)246 std::shared_ptr<Graph> getGraphFromScript(const std::string& jit_script) {
247 script::Module module("module");
248 module.define(jit_script);
249
250 Method method = module.get_method("forward");
251 return module.get_method("forward").graph();
252 }
253
getGraphFromIR(const std::string & ir)254 std::shared_ptr<Graph> getGraphFromIR(const std::string& ir) {
255 auto graph = std::make_shared<Graph>();
256 std::unordered_map<std::string, Value*> vmap;
257 parseIR(ir, graph.get(), vmap);
258 return graph;
259 }
260
compareResultsWithJIT(StaticRuntime & runtime,const std::shared_ptr<Graph> & graph,const std::vector<c10::IValue> & args,const bool use_allclose,const bool use_equalnan)261 void compareResultsWithJIT(
262 StaticRuntime& runtime,
263 const std::shared_ptr<Graph>& graph,
264 const std::vector<c10::IValue>& args,
265 const bool use_allclose,
266 const bool use_equalnan) {
267 GraphExecutorWrapper graph_exec(graph);
268 auto expected = graph_exec(args);
269 auto actual = runtime(args, {});
270 runtime.check_for_memory_leak();
271 compareResults(expected, actual, use_allclose, use_equalnan);
272 }
273
testStaticRuntime(const std::string & source,const std::vector<IValue> & args,const std::vector<IValue> & args2,const bool use_allclose,const bool use_equalnan,const bool check_resize)274 void testStaticRuntime(
275 const std::string& source,
276 const std::vector<IValue>& args,
277 const std::vector<IValue>& args2,
278 const bool use_allclose,
279 const bool use_equalnan,
280 const bool check_resize) {
281 auto test_context = makeTestContext(source);
282
283 std::vector<IValue> args_tensors, args_copy;
284 for (const auto& ival : args) {
285 if (ival.isTensor()) {
286 args_tensors.emplace_back(ival);
287 const at::Tensor& t = ival.toTensor();
288 args_copy.emplace_back(t.clone());
289 }
290 }
291
292 auto expect = test_context->getExpected(args);
293
294 for (bool enable_out_variant : {true, false}) {
295 for (bool manage_output_tensors : {true, false}) {
296 for (bool enable_tensorexpr_fusion : {true, false}) {
297 if (!enable_out_variant && manage_output_tensors) {
298 continue;
299 }
300 // run static runtime three times
301 // 1st run: collect allocation profiles (args)
302 // 2nd run: exercise memory planner and resizing with args2
303 // 3rd run: run with args again
304 StaticModuleOptions opts;
305 opts.enable_out_variant = enable_out_variant;
306 opts.optimize_memory = enable_out_variant;
307 opts.manage_output_tensors = manage_output_tensors;
308 opts.enable_tensorexpr_fusion = enable_tensorexpr_fusion;
309
310 auto smodule = test_context->makeStaticModule(opts);
311 StaticRuntime runtime(smodule);
312 auto actual = runtime(args, {});
313 if (actual.isTensor()) {
314 EXPECT_GE(smodule.num_nodes(), 2)
315 << "If we only have one node, the output of the op we are testing is "
316 << "not being managed by the memory planner! A failure here "
317 << "can typically be fixed by clone()ing the output of the test script.";
318 }
319 runtime.check_for_memory_leak();
320 // first run
321 VLOG(2) << "enable_out_variant: " << enable_out_variant;
322 VLOG(2) << "manage_output_tensors: " << manage_output_tensors;
323 VLOG(2) << "enable_tensorexpr_fusion: " << enable_tensorexpr_fusion;
324 VLOG(2) << "args: " << args;
325 VLOG(2) << "args2: " << args2;
326 VLOG(2) << "expect: " << expect;
327 VLOG(2) << "actual: " << actual;
328 compareResults(expect, actual, use_allclose, use_equalnan);
329 VLOG(2) << "first run comparison done";
330 if (manage_output_tensors) {
331 actual = IValue();
332 runtime.deallocateOutputTensors();
333 runtime.checkOutputTensorMemoryLeaks();
334 }
335
336 if (!args2.empty()) {
337 auto* memory_planner = runtime.get_memory_planner();
338 size_t managed_bytes =
339 memory_planner ? memory_planner->total_managed() : 0;
340
341 // Run static runtime again with inputs of a different shape.
342 expect = test_context->getExpected(args2);
343 actual = runtime(args2, {});
344 runtime.check_for_memory_leak();
345 VLOG(2) << "comparing with args2";
346 compareResults(expect, actual, use_allclose, use_equalnan);
347 VLOG(2) << "second run comparison done";
348 if (manage_output_tensors) {
349 actual = IValue();
350 runtime.deallocateOutputTensors();
351 runtime.checkOutputTensorMemoryLeaks();
352 }
353
354 size_t new_managed_bytes =
355 memory_planner ? memory_planner->total_managed() : 0;
356 if (check_resize && new_managed_bytes >= 0) {
357 EXPECT_GE(new_managed_bytes, managed_bytes);
358 }
359
360 // Run static runtime again with an input of the shape observed during
361 // the profile run.
362 expect = test_context->getExpected(args);
363 actual = runtime(args, {});
364 runtime.check_for_memory_leak();
365 // third run
366 VLOG(2) << "comparing third run";
367 compareResults(expect, actual, use_allclose, use_equalnan);
368 VLOG(2) << "third run comparison done";
369 if (manage_output_tensors) {
370 actual = IValue();
371 runtime.deallocateOutputTensors();
372 runtime.checkOutputTensorMemoryLeaks();
373 }
374 } else {
375 // run static runtime again to exercise the memory planner
376 // and allocate managed tensors.
377 actual = runtime(args, {});
378 runtime.check_for_memory_leak();
379 VLOG(2) << "comparing second run with same args";
380 compareResults(expect, actual, use_allclose, use_equalnan);
381 VLOG(2) << "second run comparison done";
382 if (manage_output_tensors) {
383 actual = IValue();
384 runtime.deallocateOutputTensors();
385 runtime.checkOutputTensorMemoryLeaks();
386 }
387 // third run to use the allocated managed tensors.
388 actual = runtime(args, {});
389 runtime.check_for_memory_leak();
390 if (manage_output_tensors) {
391 actual = IValue();
392 runtime.deallocateOutputTensors();
393 runtime.checkOutputTensorMemoryLeaks();
394 }
395 }
396 }
397 }
398 }
399
400 // make sure inputs were not modified
401 VLOG(2) << "Printing out input tensors";
402 compareTensorLists(args_tensors, args_copy, use_allclose, use_equalnan);
403 }
404
// True iff the StaticModule's graph contains a node whose kind matches
// `name` (same lookup as getNodeWithKind, but takes a C string).
bool hasProcessedNodeWithName(
    torch::jit::StaticModule& smodule,
    const char* name) {
  return smodule.findNodeWithKindForTesting(name) != nullptr;
}
410
411 } // namespace test
412 } // namespace jit
413 } // namespace torch
414