1 #include <gmock/gmock.h>
2 #include <gtest/gtest.h>
3
4 #include <ATen/ATen.h>
5 #include <ATen/Parallel.h>
6 #include <ATen/core/interned_strings.h>
7 #include <ATen/core/ivalue.h>
8 #include <ATen/core/jit_type_base.h>
9 #include <c10/macros/Macros.h>
10 #include <test/cpp/jit/test_utils.h>
11 #include <torch/csrc/jit/passes/remove_mutation.h>
12 #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
13 #include <torch/csrc/jit/tensorexpr/kernel.h>
14
15 #include <torch/csrc/autograd/engine.h>
16 #include <torch/csrc/autograd/generated/variable_factories.h>
17 #include <torch/csrc/autograd/profiler.h>
18 #include <torch/csrc/autograd/variable.h>
19 #include <torch/csrc/jit/api/function_impl.h>
20 #include <torch/csrc/jit/api/module.h>
21 #include <torch/csrc/jit/codegen/fuser/interface.h>
22 #include <torch/csrc/jit/frontend/ir_emitter.h>
23 #include <torch/csrc/jit/frontend/tracer.h>
24 #include <torch/csrc/jit/ir/alias_analysis.h>
25 #include <torch/csrc/jit/ir/attributes.h>
26 #include <torch/csrc/jit/ir/irparser.h>
27 #include <torch/csrc/jit/ir/scope.h>
28 #include <torch/csrc/jit/ir/type_hashing.h>
29 #include <torch/csrc/jit/jit_log.h>
30 #include <torch/csrc/jit/passes/bailout_graph.h>
31 #include <torch/csrc/jit/passes/canonicalize.h>
32 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
33 #include <torch/csrc/jit/passes/constant_propagation.h>
34 #include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
35 #include <torch/csrc/jit/passes/dead_code_elimination.h>
36 #include <torch/csrc/jit/passes/graph_fuser.h>
37 #include <torch/csrc/jit/passes/guard_elimination.h>
38 #include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
39 #include <torch/csrc/jit/passes/insert_guards.h>
40 #include <torch/csrc/jit/passes/liveness.h>
41 #include <torch/csrc/jit/passes/loop_unrolling.h>
42 #include <torch/csrc/jit/passes/lower_grad_of.h>
43 #include <torch/csrc/jit/passes/lower_tuples.h>
44 #include <torch/csrc/jit/passes/pass_manager.h>
45 #include <torch/csrc/jit/passes/requires_grad_analysis.h>
46 #include <torch/csrc/jit/passes/restore_mutation.h>
47 #include <torch/csrc/jit/passes/shape_analysis.h>
48 #include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
49 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
50 #include <torch/csrc/jit/runtime/argument_spec.h>
51 #include <torch/csrc/jit/runtime/autodiff.h>
52 #include <torch/csrc/jit/runtime/custom_operator.h>
53 #include <torch/csrc/jit/runtime/decomposition_registry.h>
54 #include <torch/csrc/jit/runtime/graph_executor.h>
55 #include <torch/csrc/jit/runtime/interpreter.h>
56 #include <torch/csrc/jit/runtime/jit_trace.h>
57 #include <torch/csrc/jit/runtime/profiling_record.h>
58 #include <torch/csrc/jit/runtime/symbolic_script.h>
59 #include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
60 #include <torch/csrc/jit/serialization/import.h>
61 #include <torch/csrc/jit/testing/file_check.h>
62 #include <torch/jit.h>
63 #include <torch/script.h>
64
65 #include <onnx/onnx_pb.h>
66
67 #include <c10/util/Exception.h>
68 #include <c10/util/ThreadLocalDebugInfo.h>
69
70 #include <torch/csrc/jit/passes/freeze_module.h>
71 #include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
72 #include <algorithm>
73 #include <cstddef>
74 #include <functional>
75 #include <iostream>
76 #include <memory>
77 #include <set>
78 #include <stdexcept>
79 #include <string>
80 #include <tuple>
81 #include <unordered_map>
82 #include <unordered_set>
83 #include <utility>
84 #include <vector>
85
86 namespace torch {
87 namespace jit {
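// FROM_SCHEMA makes alias analysis take aliasing/mutation information from the
// operator's schema annotations; the custom operators registered in the tests
// below use this helper.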
88 inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
89 return c10::AliasAnalysisKind::FROM_SCHEMA;
90 }
91
92 template <typename T>
93 std::ostream& operator<<(std::ostream& out, const std::vector<T>& list) {
94 size_t i = 0;
95 out << "{";
96 for (auto&& e : list) {
97 if (i++ > 0)
98 out << ", ";
99 out << e;
100 }
101 out << "}";
102 return out;
103 }
104
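// Symbols are interned, namespace-qualified strings (e.g. "prim::Return",
// "aten::mm"): equal strings map to the same integer id, and newly created
// symbols in a namespace receive consecutive ids.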
105 TEST(InternedStringsTest, Basic) {
106 ASSERT_EQ(prim::Param, Symbol::prim("Param"));
107 ASSERT_EQ(prim::Return, Symbol::prim("Return"));
108 ASSERT_EQ(prim::Return.toUnqualString(), std::string("Return"));
109 ASSERT_EQ(prim::Return.toQualString(), std::string("prim::Return"));
110 Symbol newsym = Symbol::aten("__NEW_SYMBOL");
111 size_t symstart = newsym;
112 ASSERT_EQ(newsym.toQualString(), std::string("aten::__NEW_SYMBOL"));
113 // TODO: This test is a bit too close to the implementation details.
114 ASSERT_EQ(Symbol::aten("What"), symstart + 1);
115 ASSERT_EQ(Symbol::aten("What2"), symstart + 2);
116 ASSERT_EQ(Symbol::aten("What"), symstart + 1);
117 ASSERT_EQ(Symbol::aten("What2"), symstart + 2);
118 ASSERT_EQ(Symbol(symstart + 2).toUnqualString(), std::string("What2"));
119 }
120
121 TEST(FromQualStringTest, Basic) {
122 ASSERT_EQ(Symbol::fromQualString("prim::Param"), Symbol::prim("Param"));
123 ASSERT_EQ(Symbol::fromQualString("aten::mm"), Symbol::aten("mm"));
124 ASSERT_EQ(Symbol::fromQualString("onnx::LSTM"), Symbol::onnx("LSTM"));
125 ASSERT_EQ(Symbol::fromQualString("attr::value"), Symbol::attr("value"));
126 ASSERT_EQ(Symbol::fromQualString("scope::"), Symbol::scope(""));
127 ASSERT_EQ(Symbol::fromQualString("::").toUnqualString(), std::string(""));
128 ASSERT_EQ(
129 Symbol::fromQualString("::").ns().toQualString(),
130 std::string("namespaces::"));
131 ASSERT_EQ(
132 Symbol::fromQualString("new_ns::param").toUnqualString(),
133 std::string("param"));
134 ASSERT_EQ(
135 Symbol::fromQualString("new_ns::param").ns().toUnqualString(),
136 std::string("new_ns"));
137 ASSERT_EQ(
138 Symbol::fromQualString("new_ns::param").ns(),
139 Symbol::fromQualString("namespaces::new_ns"));
140
141 auto bad_inputs = {"scope", ":", ""};
142 for (auto input : bad_inputs) {
143 try {
144 Symbol::fromQualString(input);
145 ASSERT_TRUE(0);
146 } catch (const std::exception& c) {
147 }
148 }
149 }
150
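// Pattern shared by the next two tests: run the op eagerly (forward and
// backward), rebuild it as a JIT graph, differentiate() that graph, execute it
// through the interpreter via runGradient, and compare with assertAllClose.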
151 TEST(THNNConvTest, Basic) {
152 std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
153 std::vector<int64_t> kernel_size = {3, 5};
154 std::vector<int64_t> stride = {1, 2};
155 std::vector<int64_t> padding = {2, 1};
156 constexpr int out_channels = 5;
157
158 // make inputs
159 at::Tensor input = torch::randn(input_size);
160 at::Tensor weight = torch::randn(
161 {out_channels, input_size[1], kernel_size[0], kernel_size[1]});
162 at::Tensor bias = torch::randn({out_channels});
163
164 // run forward eagerly
165 at::Tensor output = at::_slow_conv2d_forward(
166 input, weight, kernel_size, bias, stride, padding);
167
168 // make grad_outputs
169 at::Tensor grad_output =
170 torch::randn_like(output, at::MemoryFormat::Preserve);
171
172 // run backward eagerly
173 auto [grad_input, grad_weight, grad_bias] = at::_slow_conv2d_backward(
174 grad_output,
175 input,
176 weight,
177 kernel_size,
178 stride,
179 padding,
180 {true, true, true});
181
182 // make JIT graph
183 auto graph = std::make_shared<Graph>();
184 auto ksz_val = graph->insertConstant(kernel_size);
185 auto kst_val = graph->insertConstant(stride);
186 auto pad_val = graph->insertConstant(padding);
187
188 auto inputg = graph->addInput("self");
189 auto weightg = graph->addInput("weight");
190 auto biasg = graph->addInput("bias");
191
192 Value* conv = graph->insert(
193 aten::_slow_conv2d_forward,
194 {inputg, weightg, ksz_val, biasg, kst_val, pad_val});
195 auto outputs = conv->node()->outputs();
196 for (auto output : outputs) {
197 graph->registerOutput(output);
198 }
199 LowerAllTuples(graph);
200 graph->lint();
201
202 // differentiate JIT graph
203 EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
204 ConstantPropagation(graph);
205 auto grad_spec = differentiate(graph);
206 LowerGradOf(*grad_spec.df);
207
208 // prepare JIT inputs / gradients
209 tensor_list tensors_in;
210 tensors_in.push_back(input);
211 tensors_in.push_back(weight);
212 tensors_in.push_back(bias);
213
214 tensor_list tensor_grads_in;
215 tensor_grads_in.push_back(grad_output);
216
217 // Get outputs from the interpreter
218 auto [tensors_out, tensor_grads_out] =
219 runGradient(grad_spec, tensors_in, tensor_grads_in);
220
221 // prepare expected structs
222 tensor_list expected_tensors_out, expected_tensor_grads_out;
223 expected_tensors_out.push_back(output);
224 expected_tensor_grads_out.push_back(grad_input);
225 expected_tensor_grads_out.push_back(grad_weight);
226 expected_tensor_grads_out.push_back(grad_bias);
227
228 // Compare results
229 assertAllClose(tensors_out, expected_tensors_out);
230 assertAllClose(tensor_grads_out, expected_tensor_grads_out);
231 }
232
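// Same pattern as THNNConvTest; in addition, running_mean/running_var are
// mutated in place, so the eager and JIT runs get separate clones and the
// JIT-side clones are appended to the outputs before comparison.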
233 TEST(ATenNativeBatchNormTest, Basic) {
234 // aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor
235 // running_mean, Tensor running_var, bool training, float momentum, float eps)
236 // -> (Tensor, Tensor, Tensor)
237 std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
238 bool training = true;
239 float momentum = 0.9;
240 float eps = 1e-5;
241
242 // make inputs
243 at::Tensor input = torch::randn(input_size);
244 at::Tensor weight = torch::randn({input_size[1]});
245 at::Tensor bias = torch::randn({input_size[1]});
246 at::Tensor running_mean = torch::randn({input_size[1]});
247 at::Tensor running_var = torch::randn({input_size[1]});
248
249 // running_mean and running_var are updated in place, so clone them and pass separate copies to the eager and JIT runs
250 at::Tensor running_mean_eager = running_mean.clone();
251 at::Tensor running_var_eager = running_var.clone();
252 at::Tensor running_mean_jit = running_mean.clone();
253 at::Tensor running_var_jit = running_var.clone();
254
255 // run forward eagerly
256 auto [output, savemean, saveinvstd] = at::native_batch_norm(
257 input,
258 weight,
259 bias,
260 running_mean_eager,
261 running_var_eager,
262 training,
263 momentum,
264 eps);
265
266 // make grad_outputs
267 at::Tensor grad_output =
268 torch::randn_like(output, at::MemoryFormat::Preserve);
269 at::Tensor grad_savemean =
270 torch::zeros_like(savemean, at::MemoryFormat::Preserve);
271 at::Tensor grad_saveinvstd =
272 torch::zeros_like(saveinvstd, at::MemoryFormat::Preserve);
273
274 // run backward eagerly
275 // aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor
276 // weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor
277 // save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor,
278 // Tensor, Tensor)
279 auto [grad_input, grad_weight, grad_bias] = at::native_batch_norm_backward(
280 grad_output,
281 input,
282 weight,
283 running_mean_eager,
284 running_var_eager,
285 savemean,
286 saveinvstd,
287 training,
288 eps,
289 {true, true, true});
290
291 // make JIT graph
292 auto graph = std::make_shared<Graph>();
293 auto training_val = graph->insertConstant(IValue(training));
294 auto momentum_val = graph->insertConstant(IValue(momentum));
295 auto eps_val = graph->insertConstant(IValue(eps));
296
297 auto inputg = graph->addInput("self");
298 auto weightg = graph->addInput("weight");
299 auto biasg = graph->addInput("bias");
300 auto running_meang = graph->addInput("running_mean");
301 auto running_varg = graph->addInput("running_var");
302
303 Value* bn = graph->insert(
304 aten::native_batch_norm,
305 {inputg,
306 weightg,
307 biasg,
308 running_meang,
309 running_varg,
310 training_val,
311 momentum_val,
312 eps_val});
313 auto outputs = bn->node()->outputs();
314 for (auto output : outputs) {
315 graph->registerOutput(output);
316 }
317 LowerAllTuples(graph);
318 graph->lint();
319
320 // differentiate JIT graph
321 EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
322 ConstantPropagation(graph);
323 auto grad_spec = differentiate(graph);
324 LowerGradOf(*grad_spec.df);
325
326 // prepare JIT inputs / gradients
327 tensor_list tensors_in;
328 tensors_in.push_back(input);
329 tensors_in.push_back(weight);
330 tensors_in.push_back(bias);
331 tensors_in.push_back(running_mean_jit);
332 tensors_in.push_back(running_var_jit);
333
334 tensor_list tensor_grads_in;
335 tensor_grads_in.push_back(grad_output);
336 tensor_grads_in.push_back(grad_savemean);
337 tensor_grads_in.push_back(grad_saveinvstd);
338
339 // Get outputs from the interpreter
340 auto [tensors_out, tensor_grads_out] =
341 runGradient(grad_spec, tensors_in, tensor_grads_in);
342
343 // prepare expected structs
344 tensor_list expected_tensors_out, expected_tensor_grads_out;
345 expected_tensors_out.push_back(output);
346 expected_tensors_out.push_back(savemean);
347 expected_tensors_out.push_back(saveinvstd);
348 expected_tensors_out.push_back(running_mean_eager);
349 expected_tensors_out.push_back(running_var_eager);
350 expected_tensor_grads_out.push_back(grad_input);
351 expected_tensor_grads_out.push_back(grad_weight);
352 expected_tensor_grads_out.push_back(grad_bias);
353
354 tensors_out.push_back(running_mean_jit);
355 tensors_out.push_back(running_var_jit);
356
357 // Compare results
358 assertAllClose(tensors_out, expected_tensors_out);
359 assertAllClose(tensor_grads_out, expected_tensor_grads_out);
360 }
361
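// CustomFuseGraph merges every node accepted by the predicate into a
// prim::FusionGroup node holding a subgraph; here both aten::mul nodes should
// land in a single group.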
362 TEST(CustomFusionTest, Basic) {
363 #if defined(FBCODE_CAFFE2)
364 return;
365 #endif
366
367 auto graph_string = R"IR(
368 graph(%0 : Float(2, 3, 4),
369 %1 : Float(2, 3, 4)):
370 %2 : Tensor = aten::mul(%0, %1)
371 %3 : Tensor = aten::mul(%2, %0)
372 return (%3))IR";
373 auto g = std::make_shared<Graph>();
374 torch::jit::parseIR(graph_string, g.get());
375
376 torch::jit::overrideCanFuseOnCPU(true);
377 CustomFuseGraph(
378 g,
379 [](Node* n) { return n->kind() != prim::Param; },
380 Symbol::fromQualString("prim::FusionGroup"));
381 torch::jit::overrideCanFuseOnCPU(false);
382
383 const auto& nodes = g->nodes();
384 auto fusion_group =
385 std::find_if(nodes.begin(), nodes.end(), [](const Node* node) {
386 return node->kind() == Symbol::fromQualString("prim::FusionGroup");
387 });
388 AT_ASSERT(fusion_group != nodes.end());
389
390 auto subgraph = fusion_group->g(attr::Subgraph);
391 auto hits = 0;
392 // two multiplications
393 for (const auto& n : subgraph->nodes()) {
394 (void)n;
395 hits++;
396 }
397 AT_ASSERT(hits == 2);
398 }
399
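// Fusion candidates inside nested blocks (the arms of the prim::If) should be
// found as well.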
400 TEST(CustomFusionTest, NestedBlocks) {
401 #if defined(FBCODE_CAFFE2)
402 return;
403 #endif
404
405 auto graph_string = R"IR(
406 graph(%0 : Float(2, 3, 4),
407 %1 : Float(2, 3, 4),
408 %2 : Float(2, 3, 4)):
409 %3 : int = prim::Constant[value=1]()
410 %4 : Tensor = prim::If(%2)
411 block0():
412 %5 : Tensor = aten::mul(%0, %2)
413 %6 : Tensor = aten::mul(%5, %1)
414 -> (%6)
415 block1():
416 %7 : Tensor = aten::add(%0, %2, %3)
417 %8 : Tensor = aten::add(%7, %1, %3)
418 -> (%8)
419 %9 : Tensor = aten::add(%4, %2, %3)
420 return (%4))IR";
421 auto g = std::make_shared<Graph>();
422 torch::jit::parseIR(graph_string, g.get());
423
424 CustomFuseGraph(
425 g,
426 [](Node* n) { return n->kind() == aten::mul; },
427 Symbol::fromQualString("prim::FusionGroup"));
428
429 // Could be done in more efficient ways, but this is only a test.
430 std::function<bool(const Block*, Symbol)> dfs = [&](const Block* b,
431 Symbol s) {
432 for (auto node : b->nodes()) {
433 if (node->kind() == s)
434 return true;
435 for (auto nested_b : node->blocks())
436 if (dfs(nested_b, s))
437 return true;
438 }
439 return false;
440 };
441
442 AT_ASSERT(dfs(g->block(), Symbol::fromQualString("prim::FusionGroup")));
443 }
444
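// TorchScript snippets exercising if/else and while loops; ControlFlowTest
// compiles them and runs them through the interpreter.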
445 static const auto cf_examples = R"JIT(
446 def if_test(a, b):
447 # FIXME: use 0 instead of a.
448 # c = 0
449 c = a
450 if bool(a < b):
451 c = b
452 else:
453 c = a
454 return c
455 def if_one(a, b):
456 c = b
457 if bool(a < b):
458 c = a
459 return c
460 def while_test(a, i):
461 while bool(i < 3):
462 a *= a
463 i += 1
464 return a
465 )JIT";
466
467 TEST(ControlFlowTest, Basic) {
468 auto cu = compile(cf_examples);
469
470 auto run = [&](const std::string& name, std::vector<IValue> stack) {
471 auto graph = toGraphFunction(cu->get_function(name)).graph();
472 Code code(graph, "");
473 InterpreterState interp(code);
474 interp.run(stack);
475 return stack;
476 };
477
478 auto L = [](int64_t l) { return IValue(scalar_to_tensor(at::Scalar(l))); };
479 auto V = [](IValue t) { return std::move(t).toTensor().item<int64_t>(); };
480 auto run_binary = [&](const std::string& name, int64_t a, int64_t b) {
481 return V(run(name, {L(a), L(b)})[0]);
482 };
483 ASSERT_EQ(2, run_binary("if_test", 1, 2));
484 ASSERT_EQ(3, run_binary("if_test", 3, 2));
485 ASSERT_EQ(2, run_binary("if_one", 2, 3));
486 ASSERT_EQ(2, run_binary("if_one", 3, 2));
487 ASSERT_EQ(256, run_binary("while_test", 2, 0));
488 }
489
490 #if !(C10_ASAN_ENABLED || C10_UBSAN_ENABLED)
491 // This test fails vptr UBSAN checks
492
493 TEST(ProtoTest, Basic) {
494 ::ONNX_NAMESPACE::ModelProto proto;
495 proto.set_producer_name("foo");
496 }
497 #endif
498
499 // test a few features that are not directly used in schemas yet
500 TEST(SchemaParserTest, NestedArrays) {
501 // nested arrays
502 auto s = parseSchema("at::what(int[][4] foo) -> ()");
503 ASSERT_TRUE(s.arguments().at(0).N() == 4);
504 ASSERT_TRUE(IntType::get()->isSubtypeOf(*s.arguments()
505 .at(0)
506 .type()
507 ->expectRef<ListType>()
508 .getElementType()
509 ->expectRef<ListType>()
510 .getElementType()));
511 auto s2 = parseSchema("at::what(int[][] foo) -> ()");
512 ASSERT_TRUE(IntType::get()->isSubtypeOf(*s2.arguments()
513 .at(0)
514 .type()
515 ->expectRef<ListType>()
516 .getElementType()
517 ->expectRef<ListType>()
518 .getElementType()));
519 }
520
521 TEST(SchemaParserTest, OutVariant) {
522 auto schema_with_out = parseSchema(
523 "at::foo(Tensor self, *, Tensor(a!) f, Tensor(b!) l) -> (Tensor(a!) f, Tensor(b!) l)");
524 ASSERT_TRUE(schema_with_out.arguments().at(1).is_out());
525 ASSERT_TRUE(schema_with_out.arguments().at(2).is_out());
526
527 auto schema_without_out =
528 parseSchema("at::foo(Tensor self, *, int scalar) -> (int)");
529
530 for (const auto& arg : schema_without_out.arguments()) {
531 ASSERT_TRUE(!arg.is_out());
532 }
533
534 auto schema_with_is_write = parseSchema(
535 "aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))");
536
537 for (const auto& arg : schema_with_is_write.arguments()) {
538 ASSERT_TRUE(!arg.is_out());
539 }
540 }
541
542 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
543 TEST(SchemaParserTest, NamedReturns) {
544 // named returns
545 parseSchema("at::what(Tensor! i_will_be_written_to) -> ()");
546 auto s3 =
547 parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)");
548 ASSERT_TRUE(s3.returns().at(0).name() == "the_return");
549 ASSERT_TRUE(s3.returns().at(1).name() == "the_return2");
550 }
551
552 TEST(SchemaParserTest, Futures) {
553 // futures
554 auto s4 = parseSchema("at::what(Future(int) foo) -> ()");
555 ASSERT_TRUE(IntType::get()->isSubtypeOf(
556 *s4.arguments().at(0).type()->expectRef<FutureType>().getElementType()));
557 }
558
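// Alias annotation syntax exercised below: "Tensor(a)" places the value in
// alias set `a`, "!" marks it as written to, "b -> b|c" gives distinct
// before/after sets, and annotations on "[]" apply to the list itself rather
// than to its elements.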
559 TEST(SchemaParserTest, AnnotatedAliasSets) {
560 // test tensor with annotated alias sets
561 parseSchema("at::what(Tensor(a) foo) -> (Tensor(a))");
562 }
563
564 TEST(SchemaParserTest, TensorListAnnotatedAliasSets) {
565 const auto s = parseSchema(
566 "at::foo(Tensor(a!) self, Tensor(b!)[] out)"
567 " -> ()");
568 const AliasInfo* selfAliasInfo = s.arguments().at(0).alias_info();
569 const AliasInfo* outAliasInfo = s.arguments().at(1).alias_info();
570 ASSERT_TRUE(
571 selfAliasInfo->beforeSets() ==
572 std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
573 ASSERT_TRUE(selfAliasInfo->isWrite());
574
575 ASSERT_TRUE(outAliasInfo->isWrite());
576 ASSERT_TRUE(outAliasInfo->beforeSets().empty());
577 ASSERT_EQ(outAliasInfo->containedTypes().size(), 1);
578
579 auto containedType = outAliasInfo->containedTypes()[0];
580
581 ASSERT_TRUE(containedType.isWrite());
582 ASSERT_TRUE(
583 containedType.beforeSets() ==
584 std::unordered_set<Symbol>{Symbol::fromQualString("alias::b")});
585 }
586
587 TEST(SchemaParserTest, AnnotatedAliasWithoutBeforeSet) {
588 EXPECT_THAT(
589 []() { parseSchema("at::foo(Tensor(!) self) -> Tensor"); },
590 ::testing::Throws<std::runtime_error>(::testing::Property(
591 &std::runtime_error::what,
592 ::testing::HasSubstr("expected ident but found '!' here"))));
593 }
594
595 TEST(SchemaParserTest, BeforeAfterSets) {
596 const auto s = parseSchema(
597 "at::what(Tensor(b|c)[](a!) list, Tensor(c) element)"
598 " -> (Tensor(b|c)[](a!))");
599
600 // The list itself is annotated with `a`
601 const AliasInfo* aliasInfo = s.arguments().at(0).alias_info();
602 ASSERT_NE(aliasInfo, nullptr);
603 ASSERT_TRUE(
604 aliasInfo->beforeSets() ==
605 std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
606 ASSERT_TRUE(aliasInfo->isWrite());
607
608 // Check the contained types
609 ASSERT_TRUE(!aliasInfo->containedTypes().empty());
610 const auto& containedAliasInfo = aliasInfo->containedTypes()[0];
611 const auto expected = std::unordered_set<Symbol>{
612 Symbol::fromQualString("alias::b"),
613 Symbol::fromQualString("alias::c"),
614 };
615 ASSERT_TRUE(containedAliasInfo.beforeSets() == expected);
616 ASSERT_TRUE(containedAliasInfo.afterSets() == expected);
617 ASSERT_FALSE(containedAliasInfo.isWrite());
618 }
619
620 TEST(SchemaParserTest, BeforeAfterSets2) {
621 const auto s = parseSchema(
622 "at::what(Tensor(b -> b|c)[](a!) list, Tensor(c) element)"
623 " -> (Tensor(b|c)[](a!))");
624
625 // The list itself is annotated with `a`
626 const AliasInfo* aliasInfo = s.arguments().at(0).alias_info();
627 ASSERT_NE(aliasInfo, nullptr);
628 ASSERT_EQ(
629 aliasInfo->beforeSets(),
630 std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
631 ASSERT_EQ(
632 aliasInfo->afterSets(),
633 std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
634 ASSERT_TRUE(aliasInfo->isWrite());
635 ASSERT_EQ(aliasInfo->containedTypes().size(), 1);
636
637 // Check the contained types
638 ASSERT_TRUE(!aliasInfo->containedTypes().empty());
639 const auto& containedAliasInfo = aliasInfo->containedTypes()[0];
640 const auto expectedBefore = std::unordered_set<Symbol>{
641 Symbol::fromQualString("alias::b"),
642 };
643 const auto expectedAfter = std::unordered_set<Symbol>{
644 Symbol::fromQualString("alias::b"), Symbol::fromQualString("alias::c")};
645 ASSERT_TRUE(containedAliasInfo.beforeSets() == expectedBefore);
646 ASSERT_TRUE(containedAliasInfo.afterSets() == expectedAfter);
647 ASSERT_FALSE(containedAliasInfo.isWrite());
648 }
649
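// Node::isBefore/isAfter are answered from a topological index that has to
// stay consistent across insertions, deletions, and nested blocks.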
650 TEST(TopologicalIndexTest, Basic) {
651 Graph graph;
652 auto node1 = graph.create(prim::AutogradZero);
653 auto node2 = graph.create(prim::AutogradZero);
654 auto node3 = graph.create(prim::AutogradZero);
655 auto node4 = graph.create(prim::AutogradZero);
656
657 graph.appendNode(node4);
658 graph.prependNode(node1);
659 node2->insertAfter(node1);
660 node3->insertBefore(node4);
661
662 // nodes should be in numerical order
663 ASSERT_TRUE(node1->isBefore(node2));
664 ASSERT_TRUE(node1->isBefore(node3));
665 ASSERT_TRUE(node1->isBefore(node4));
666 ASSERT_TRUE(node2->isAfter(node1));
667 ASSERT_TRUE(node2->isBefore(node3));
668 ASSERT_TRUE(node2->isBefore(node4));
669 ASSERT_FALSE(node3->isBefore(node1));
670 ASSERT_FALSE(node3->isBefore(node2));
671 ASSERT_FALSE(node3->isAfter(node4));
672
673 // Build up a block structure
674 // node3
675 // /\ ...
676 // A B block1
677 // \ ...
678 // C block2
679 auto block1 = node3->addBlock();
680 auto A = graph.create(prim::AutogradZero);
681 block1->appendNode(A);
682 auto B = graph.create(prim::AutogradZero);
683 block1->appendNode(B);
684 auto block2 = B->addBlock();
685 auto C = graph.create(prim::AutogradZero);
686 block2->appendNode(C);
687
688 // Check isBefore/isAfter across different block levels
689 ASSERT_TRUE(node1->isBefore(A));
690 ASSERT_TRUE(A->isBefore(B));
691 ASSERT_TRUE(A->isBefore(C));
692
693 // make sure things don't blow up on deletions
694 node2->destroy();
695 auto node2p = graph.create(prim::AutogradZero);
696 node2p->insertAfter(node1);
697 ASSERT_TRUE(node1->isBefore(node2p));
698 ASSERT_TRUE(node2p->isBefore(node3));
699 }
700
701 TEST(TopologicalIndexTest, Reindex) {
702 // Induce reindexing to test that path
703 Graph graph;
704 std::map<size_t, Node*> nodes;
705
706 auto anchor = graph.create(prim::AutogradZero);
707 graph.appendNode(anchor);
708 // Inserting to the same place a lot will trigger reindexing
709 for (auto i = 0; i < 100; ++i) {
710 auto n = graph.create(prim::AutogradZero);
711 n->insertAfter(anchor);
712 nodes[i] = n;
713 }
714
715 // Nodes should be in reverse order
716 for (auto i = 0; i < 100; ++i) {
717 for (auto j = i + 1; j < 100; ++j) {
718 ASSERT_TRUE(nodes[i]->isAfter(nodes[j]));
719 }
720 }
721 }
722
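// Helpers for the RecordFunction tests: RECORD_FUNCTION emits an observer
// event named "test" (with the tensor as input) for the enclosing scope, once
// from eager code and once through a scripted module.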
723 at::Tensor invokeTestRecordFunction(at::Tensor& t) {
724 RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));
725
726 auto t2 = t.pow(2);
727 return t2;
728 }
729
730 static const auto invokeTestRecordFunction_JIT = R"JIT(
731 def foo(self, t):
732 t2 = t.pow(2)
733 return t2
734
735 def forward(self, t):
736 return self.foo(t)
737 )JIT";
738
739 at::Tensor invokeTestRecordFunctionJIT(at::Tensor& t) {
740 RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));
741
742 auto module = std::make_shared<script::Module>(
743 "RecordFunctionTestModule", std::make_shared<script::CompilationUnit>());
744 module->define(invokeTestRecordFunction_JIT);
745 return module->forward({t}).toTensor();
746 }
747
748 using TracedTestValues =
749 std::vector<std::tuple<std::string, std::vector<std::vector<int64_t>>>>;
750
751 void checkTracedInputs(const TracedTestValues& inputs) {
752 bool found_test = false;
753 bool found_pow = false;
754 bool found_mul = false;
755 for (const auto& input : inputs) {
756 const auto& fn = std::get<0>(input);
757 const auto& sizes = std::get<1>(input);
758
759 if (fn == "test") {
760 found_test = true;
761 TORCH_CHECK(sizes.size() == 1);
762 TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
763 } else if (fn == "aten::pow") {
764 found_pow = true;
765 TORCH_CHECK(sizes.size() == 2);
766 TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
767 TORCH_CHECK(sizes[1].empty());
768 } else if (fn == "aten::mul") {
769 found_mul = true;
770 TORCH_CHECK(sizes.size() > 1);
771 TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
772 }
773 }
774 TORCH_CHECK(found_test);
775 TORCH_CHECK(found_pow);
776 TORCH_CHECK(found_mul);
777 }
778
779 void checkTracedOutputs(const TracedTestValues& outputs) {
780 bool found_test = false;
781 bool found_pow = false;
782 bool found_mul = false;
783 for (const auto& output : outputs) {
784 const auto& fn = std::get<0>(output);
785 const auto& sizes = std::get<1>(output);
786
787 if (fn == "test") {
788 found_test = true;
789 TORCH_CHECK(sizes.empty());
790 } else if (fn == "aten::pow") {
791 found_pow = true;
792 TORCH_CHECK(sizes.size() == 1);
793 TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
794 } else if (fn == "aten::mul") {
795 found_mul = true;
796 TORCH_CHECK(sizes.size() == 1);
797 TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
798 }
799 }
800 TORCH_CHECK(found_test);
801 TORCH_CHECK(found_pow);
802 TORCH_CHECK(found_mul);
803 }
804
805 static bool bad_scope = false;
806 template <RecordScope scope, size_t* cnt>
807 std::unique_ptr<at::ObserverContext> checkScopeCallback(
808 const at::RecordFunction& fn) {
809 if (fn.scope() == scope) {
810 ++(*cnt);
811 } else {
812 bad_scope = true;
813 }
814 return nullptr;
815 }
816
817 template <RecordScope scope, size_t* cnt>
818 void pushScopedCallback() {
819 at::addGlobalCallback(
820 at::RecordFunctionCallback(checkScopeCallback<scope, cnt>)
821 .scopes({scope}));
822 }
823
824 // These cannot be function-local because that would prohibit them
825 // from being used as template arguments prior to C++17.
826 static size_t fun_cnt;
827 static size_t ts_fun_cnt;
828 static size_t user_scope_cnt;
829
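// Register one scoped callback per RecordScope and check that each macro
// (RECORD_FUNCTION, RECORD_TORCHSCRIPT_FUNCTION, RECORD_USER_SCOPE) fires only
// the callback registered for its scope.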
830 void checkScopeCallbacks() {
831 static bool found_function_scope;
832 static bool found_method_scope;
833 static bool found_user_scope;
834 found_function_scope = false;
835 found_method_scope = false;
836 found_user_scope = false;
837 at::addGlobalCallback(at::RecordFunctionCallback(
838 [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
839 if (fn.scope() == at::RecordScope::FUNCTION &&
840 std::string(fn.name()) == "test_function") {
841 found_function_scope = true;
842 }
843 if (fn.scope() == at::RecordScope::TORCHSCRIPT_FUNCTION &&
844 std::string(fn.name()) == "test_method") {
845 found_method_scope = true;
846 }
847 if (fn.scope() == at::RecordScope::USER_SCOPE &&
848 std::string(fn.name()) == "test_user_scope") {
849 found_user_scope = true;
850 }
851 return nullptr;
852 }));
853
854 bad_scope = false;
855 fun_cnt = 0;
856 pushScopedCallback<at::RecordScope::FUNCTION, &fun_cnt>();
857 ts_fun_cnt = 0;
858 pushScopedCallback<at::RecordScope::TORCHSCRIPT_FUNCTION, &ts_fun_cnt>();
859 user_scope_cnt = 0;
860 pushScopedCallback<at::RecordScope::USER_SCOPE, &user_scope_cnt>();
861
862 TORCH_CHECK(at::hasCallbacks());
863
864 {
865 RECORD_TORCHSCRIPT_FUNCTION("test_method", {});
866 { RECORD_FUNCTION("test_function", {}); }
867 { RECORD_USER_SCOPE("test_user_scope"); }
868 }
869
870 TORCH_CHECK(!bad_scope);
871 TORCH_CHECK(fun_cnt == 1);
872 TORCH_CHECK(ts_fun_cnt == 1);
873 TORCH_CHECK(user_scope_cnt == 1);
874
875 TORCH_CHECK(found_function_scope);
876 TORCH_CHECK(found_method_scope);
877 TORCH_CHECK(found_user_scope);
878 }
879
880 static TracedTestValues traced_inputs;
881 static TracedTestValues traced_outputs;
882 static std::unordered_set<std::string> ts_input_names;
883 static std::unordered_set<std::string> ts_output_names;
884
885 std::unique_ptr<at::ObserverContext> tracedInputsCallback(
886 const RecordFunction& fn) {
887 if (fn.scope() == RecordScope::FUNCTION) {
888 auto inputs = fn.inputs();
889 std::vector<std::vector<int64_t>> sizes;
890 for (const auto& input : inputs) {
891 if (input.isTensor()) {
892 sizes.push_back(input.toTensor().sizes().vec());
893 } else if (input.isScalar()) {
894 // NOLINTNEXTLINE(modernize-use-emplace)
895 sizes.push_back(std::vector<int64_t>());
896 }
897 }
898 traced_inputs.push_back(std::make_tuple(fn.name(), sizes));
899 } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
900 ts_input_names.insert(fn.name());
901 }
902 return nullptr;
903 }
904
905 void tracedOutputsCallback(const RecordFunction& fn, ObserverContext* ctx_ptr) {
906 if (fn.scope() == RecordScope::FUNCTION) {
907 auto outputs = fn.outputs();
908 std::vector<std::vector<int64_t>> sizes;
909 for (const auto& output : outputs) {
910 if (output.isTensor()) {
911 sizes.push_back(output.toTensor().sizes().vec());
912 } else if (output.isScalar()) {
913 sizes.emplace_back();
914 }
915 }
916 traced_outputs.push_back(std::make_tuple(fn.name(), sizes));
917 } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
918 ts_output_names.insert(fn.name());
919 }
920 }
921
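// The two callbacks above record input/output sizes for FUNCTION-scope events
// and function names for TORCHSCRIPT_FUNCTION-scope events; the test exercises
// both the eager and the scripted path.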
922 TEST(RecordFunctionTest, TracedTestInputsOutputs) {
923 // disabling the inlining of method calls
924 GraphOptimizerEnabledGuard opt_guard(false);
925
926 // [(fn, [[sizes], [sizes], ...]), ...]
927 addGlobalCallback(
928 RecordFunctionCallback(tracedInputsCallback, tracedOutputsCallback)
929 .needsInputs(true)
930 .needsOutputs(true));
931
932 TracedTestValues eager_inputs, eager_outputs, jit_inputs, jit_outputs;
933 {
934 auto t = torch::randn({1, 2, 3}, at::kCPU);
935 t.set_requires_grad(true);
936 auto t2 = invokeTestRecordFunction(t);
937 t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
938 eager_inputs = traced_inputs;
939 eager_outputs = traced_outputs;
940 traced_inputs.clear();
941 traced_outputs.clear();
942
943 TORCH_CHECK(ts_input_names.empty());
944 TORCH_CHECK(ts_output_names.empty());
945
946 t = torch::randn({1, 2, 3}, at::kCPU);
947 t.set_requires_grad(true);
948 t2 = invokeTestRecordFunctionJIT(t);
949 t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
950 jit_inputs = traced_inputs;
951 jit_outputs = traced_outputs;
952 traced_inputs.clear();
953 traced_outputs.clear();
954 }
955
956 TORCH_CHECK(ts_input_names.find("forward") != ts_input_names.end());
957 TORCH_CHECK(ts_input_names.find("foo") != ts_input_names.end());
958 TORCH_CHECK(ts_output_names.find("forward") != ts_output_names.end());
959 TORCH_CHECK(ts_output_names.find("foo") != ts_output_names.end());
960
961 checkTracedInputs(eager_inputs);
962 checkTracedOutputs(eager_outputs);
963 checkTracedInputs(jit_inputs);
964 checkTracedOutputs(jit_outputs);
965 at::clearCallbacks();
966 }
967
968 static int sampled_cb_ctr = 0;
969 std::unique_ptr<ObserverContext> sampledCallback(const RecordFunction& fn) {
970 if (std::string(fn.name()) == "test") {
971 ++sampled_cb_ctr;
972 }
973 return nullptr;
974 }
975
976 static int non_sampled_cb_ctr = 0;
977 std::unique_ptr<ObserverContext> nonSampledCallback(const RecordFunction& fn) {
978 if (std::string(fn.name()) == "test") {
979 ++non_sampled_cb_ctr;
980 }
981 return nullptr;
982 }
983
984 TEST(RecordFunctionTest, SampledCallbacks) {
985 // disabling the inlining of method calls
986 GraphOptimizerEnabledGuard opt_guard(false);
987
988 // test sampled callbacks
989 sampled_cb_ctr = 0;
990 auto setup_sampled_callback = [](double sampling_prob) {
991 return addGlobalCallback(
992 RecordFunctionCallback(sampledCallback).samplingProb(sampling_prob));
993 };
994
995 addGlobalCallback(RecordFunctionCallback(nonSampledCallback));
996
997 auto handle = setup_sampled_callback(0.5);
998
999 auto run_test_function = []() {
1000 auto t = torch::randn({1, 2, 3}, at::kCPU);
1001 for (auto k = 0; k < 1000; k++) {
1002 invokeTestRecordFunction(t);
1003 }
1004 };
1005
1006 run_test_function();
1007 TORCH_CHECK(non_sampled_cb_ctr == 1000);
1008 TORCH_CHECK(sampled_cb_ctr > 0 && sampled_cb_ctr < 1000);
1009
1010 sampled_cb_ctr = 0;
1011 removeCallback(handle);
1012 handle = setup_sampled_callback(0.0);
1013 run_test_function();
1014
1015 TORCH_CHECK(non_sampled_cb_ctr == 2000);
1016 TORCH_CHECK(sampled_cb_ctr == 0);
1017
1018 sampled_cb_ctr = 0;
1019 removeCallback(handle);
1020 // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
1021 handle = setup_sampled_callback(1.0);
1022 run_test_function();
1023
1024 TORCH_CHECK(non_sampled_cb_ctr == 3000);
1025 TORCH_CHECK(sampled_cb_ctr == 1000);
1026 clearCallbacks();
1027
1028 // test the scope of the callbacks
1029 checkScopeCallbacks();
1030 clearCallbacks();
1031 }
1032
1033 TEST(RecordFunctionTest, RecordFunctionGuard) {
1034 // disabling the inlining of method calls
1035 GraphOptimizerEnabledGuard opt_guard(false);
1036
1037 static std::vector<std::string> fn_names;
1038 static std::mutex guard_mtx;
1039
1040 // check record function guard
1041 addGlobalCallback(RecordFunctionCallback(
1042 [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
1043 std::lock_guard<std::mutex> lock(guard_mtx);
1044 // NOLINTNEXTLINE(modernize-use-emplace)
1045 fn_names.push_back(fn.name());
1046 return nullptr;
1047 }));
1048 {
1049 RecordFunctionGuard g1(false);
1050 {
1051 RECORD_USER_SCOPE("A");
1052 {
1053 RecordFunctionGuard g2(true);
1054 RECORD_USER_SCOPE("B");
1055 {
1056 DisableRecordFunctionGuard g3;
1057 RECORD_USER_SCOPE("C");
1058 }
1059 }
1060 { RECORD_USER_SCOPE("D"); }
1061 }
1062 }
1063 TORCH_CHECK(fn_names.size() == 1);
1064 TORCH_CHECK(fn_names[0] == "B");
1065 clearCallbacks();
1066 }
1067
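// Global and thread-local callbacks are tracked via handles; each
// add_remove_test_add_cb<id>() instantiation pushes its id so the test can
// tell exactly which callbacks fired.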
1068 static std::vector<size_t> ids;
1069
1070 template <size_t id>
1071 auto add_remove_test_add_cb() {
1072 return addGlobalCallback(RecordFunctionCallback(
1073 [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
1074 ids.push_back(id);
1075 return nullptr;
1076 }));
1077 }
1078
1079 TEST(RecordFunctionTest, Callbacks) {
1080 // disabling the inlining of method calls
1081 GraphOptimizerEnabledGuard opt_guard(false);
1082
1083 auto h1 = add_remove_test_add_cb<1>();
1084 add_remove_test_add_cb<2>();
1085 auto h3 = add_remove_test_add_cb<3>();
1086
1087 { RECORD_USER_SCOPE("test"); }
1088
1089 TORCH_CHECK(ids.size() == 3);
1090 TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
1091 TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
1092 TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());
1093
1094 ids.clear();
1095 removeCallback(h1);
1096
1097 { RECORD_USER_SCOPE("test"); }
1098
1099 TORCH_CHECK(ids.size() == 2);
1100 TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
1101 TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());
1102
1103 ids.clear();
1104 removeCallback(h3);
1105
1106 { RECORD_USER_SCOPE("test"); }
1107
1108 TORCH_CHECK(ids.size() == 1);
1109 TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
1110
1111 clearCallbacks();
1112
1113 // thread local / global callbacks
1114
1115 ids.clear();
1116 add_remove_test_add_cb<1>();
1117
1118 { RECORD_USER_SCOPE("test"); }
1119
1120 TORCH_CHECK(ids.size() == 1);
1121 TORCH_CHECK(ids[0] == 1);
1122 ids.clear();
1123
1124 auto th = std::thread([]() {
1125 addThreadLocalCallback(RecordFunctionCallback(
1126 [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
1127 ids.push_back(2);
1128 return nullptr;
1129 }));
1130
1131 { RECORD_USER_SCOPE("test_thread"); }
1132 });
1133 th.join();
1134 TORCH_CHECK(ids.size() == 2);
1135 TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
1136 TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
1137 ids.clear();
1138
1139 { RECORD_USER_SCOPE("test"); }
1140
1141 TORCH_CHECK(ids.size() == 1);
1142 TORCH_CHECK(ids[0] == 1);
1143 ids.clear();
1144
1145 clearCallbacks();
1146
1147 // START: thread local / global context check callbacks
1148 struct TestContext : public ObserverContext {
1149 int a{0};
1150 std::string b;
1151 };
1152 ids.clear();
1153 { // START: global test
1154 addGlobalCallback(RecordFunctionCallback(
1155 [](const RecordFunction&
1156 /* unused */) -> std::unique_ptr<at::ObserverContext> {
1157 auto ctx = std::make_unique<TestContext>();
1158 ctx->a = 123;
1159 ctx->b = "test_str";
1160 ids.push_back(1);
1161 return ctx;
1162 },
1163 [](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
1164 auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
1165 TORCH_CHECK(ctx != nullptr);
1166 TORCH_CHECK(ctx->a == 123);
1167 TORCH_CHECK(ctx->b == "test_str");
1168 }));
1169
1170 { RECORD_USER_SCOPE("test"); }
1171
1172 TORCH_CHECK(ids.size() == 1);
1173 TORCH_CHECK(ids[0] == 1);
1174 ids.clear();
1175 } // END: global test
1176 { // START: thread local test
1177 auto ctx_th = std::thread([]() {
1178 const std::string test_str = "test thread str";
1179 addThreadLocalCallback(RecordFunctionCallback(
1180 [](const RecordFunction&
1181 /* unused */) -> std::unique_ptr<at::ObserverContext> {
1182 auto ctx = std::make_unique<TestContext>();
1183 ctx->a = 234;
1184 ctx->b = "test_thread_str";
1185 ids.push_back(2);
1186 return ctx;
1187 },
1188 [](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
1189 auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
1190 TORCH_CHECK(ctx != nullptr);
1191 TORCH_CHECK(ctx->a == 234);
1192 TORCH_CHECK(ctx->b == "test_thread_str");
1193 }));
1194
1195 // Will call both global and thread local callbacks.
1196 { RECORD_USER_SCOPE("test_thread"); }
1197 });
1198 ctx_th.join();
1199 TORCH_CHECK(ids.size() == 2);
1200 TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
1201 TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
1202 ids.clear();
1203 } // END: thread local test
1204
1205 clearCallbacks();
1206 }
1207
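// disableCallback/reenableCallback toggle an existing handle without removing
// it from the registry.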
1208 TEST(RecordFunctionTest, ShouldRun) {
1209 // disabling the inlining of method calls
1210 GraphOptimizerEnabledGuard opt_guard(false);
1211
1212 static bool ran = false;
1213 auto handle = addGlobalCallback(RecordFunctionCallback(
1214 [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
1215 ran = true;
1216 return nullptr;
1217 }));
1218
1219 { RECORD_USER_SCOPE("test"); }
1220
1221 EXPECT_TRUE(ran) << "first run didn't happen";
1222 ran = false;
1223
1224 disableCallback(handle);
1225
1226 { RECORD_USER_SCOPE("test"); }
1227
1228 EXPECT_FALSE(ran) << "second run happened but shouldn't have";
1229 ran = false;
1230
1231 reenableCallback(handle);
1232
1233 { RECORD_USER_SCOPE("test"); }
1234
1235 EXPECT_TRUE(ran) << "run after re-enable didn't happen";
1236 ran = false;
1237
1238 clearCallbacks();
1239 }
1240
1241 TEST(RecordFunctionTest, Basic) {
1242 // disabling the inlining of method calls
1243 GraphOptimizerEnabledGuard opt_guard(false);
1244
1245 static std::string recorded_op;
1246 static bool has_ids = false;
1247
1248 // test propagation of TLS callbacks
1249 std::thread t([]() {
1250 RecordFunctionGuard enable_rec_fn;
1251 auto handle = addThreadLocalCallback(RecordFunctionCallback(
1252 [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
1253 recorded_op = fn.name();
1254 return nullptr;
1255 }));
1256 ThreadLocalState state;
1257 std::thread t_child([state]() {
1258 ThreadLocalStateGuard g_tls(state);
1259 RECORD_USER_SCOPE("test_in_thread");
1260 });
1261 t_child.join();
1262 EXPECT_EQ(recorded_op, "test_in_thread");
1263 removeCallback(handle);
1264 });
1265 t.join();
1266 clearCallbacks();
1267
1268 // test set ids
1269 addGlobalCallback(
1270 RecordFunctionCallback(
1271 [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
1272 has_ids = fn.handle() > 0;
1273 return nullptr;
1274 })
1275 .needsIds(true));
1276 { RECORD_USER_SCOPE("test"); }
1277 TORCH_CHECK(has_ids);
1278 clearCallbacks();
1279 has_ids = false;
1280 addGlobalCallback(RecordFunctionCallback(
1281 [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
1282 has_ids = fn.handle() > 0;
1283 return nullptr;
1284 }));
1285 { RECORD_USER_SCOPE("test"); }
1286 TORCH_CHECK(!has_ids);
1287 clearCallbacks();
1288 }
1289
1290 TEST(RecordFunctionTest, OperatorNameOverload) {
1291 static std::set<std::string> operator_names;
1292 at::addGlobalCallback(at::RecordFunctionCallback(
1293 [](const at::RecordFunction& fn)
1294 -> std::unique_ptr<at::ObserverContext> {
1295 std::optional<c10::OperatorName> op_name =
1296 fn.operator_name();
1297 if (op_name.has_value()) {
1298 operator_names.insert(c10::toString(*op_name));
1299 } else {
1300 operator_names.insert("No Operator Name");
1301 }
1302 return nullptr;
1303 })
1304 .scopes({at::RecordScope::FUNCTION}));
1305 auto t = torch::randn({1, 2, 3}, at::kCPU);
1306 t.set_requires_grad(false);
1307 auto t2 = t.pow(2);
1308
1309 at::clearCallbacks();
1310 EXPECT_TRUE(operator_names.count("No Operator Name") == 0)
1311 << "Expected that all traced operators had an associated OperatorName object";
1312 EXPECT_TRUE(operator_names.count("aten::randn") == 1)
1313 << "Expected aten::randn to have been called and recorded, but it was not";
1314 EXPECT_TRUE(operator_names.count("aten::pow.Tensor_Scalar") == 1)
1315 << "Expected aten::pow.Tensor_Scalar to have been called and recorded, but it was not";
1316 }
1317
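// DebugInfoGuard installs a DebugInfoBase entry in thread-local storage under
// a given DebugInfoKind; the test below checks that it propagates to
// at::launch tasks and to the autograd backward thread, and that different
// kinds nest independently.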
1318 class TestThreadLocalDebugInfo : public c10::DebugInfoBase {
1319 public:
1320 int getModelId() const {
1321 return model_id_;
1322 }
1323
1324 void setModelId(int model_id) {
1325 model_id_ = model_id;
1326 }
1327
1328 // NOLINTNEXTLINE(modernize-use-equals-default)
1329 virtual ~TestThreadLocalDebugInfo() override {}
1330
1331 private:
1332 int model_id_ = 0;
1333 };
1334
1335 void checkDebugInfo(c10::DebugInfoKind kind, int model_id) {
1336 auto* debug_info = c10::ThreadLocalDebugInfo::get(kind);
1337 TORCH_CHECK(debug_info != nullptr);
1338 auto* test_debug_info = dynamic_cast<TestThreadLocalDebugInfo*>(debug_info);
1339 TORCH_CHECK(test_debug_info != nullptr);
1340 TORCH_CHECK(test_debug_info->getModelId() == model_id);
1341 }
1342
1343 TEST(ThreadLocalDebugInfoTest, Basic) {
1344 static std::atomic<bool> done{false};
1345
1346 TORCH_CHECK(
1347 c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
1348 auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
1349 debug_info->setModelId(42);
1350 {
1351 c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
1352 checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1353 }
1354
1355 // check that thread local debug info is propagated through fork calls
1356 TORCH_CHECK(
1357 c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
1358 {
1359 c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
1360 at::launch([]() {
1361 checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1362 done = true;
1363 });
1364 }
1365 while (!done) {
1366 }
1367
1368 // check that thread local debug info is propagated through backward pass
1369 TORCH_CHECK(
1370 c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
1371 done = false;
1372 auto handle = addGlobalCallback(RecordFunctionCallback(
1373 [](const RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
1374 checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1375 done = true;
1376 return nullptr;
1377 }));
1378 {
1379 c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
1380 auto t = torch::randn({1, 2, 3}, at::kCPU);
1381 t.set_requires_grad(true);
1382 auto t2 = t.pow(2);
1383 t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
1384 }
1385 removeCallback(handle);
1386 TORCH_CHECK(done);
1387
1388 // check nested debug info
1389 TORCH_CHECK(
1390 c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
1391 {
1392 c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
1393 {
1394 checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1395 {
1396 auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
1397 debug_info->setModelId(314);
1398 c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO_2, debug_info);
1399 {
1400 checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1401 checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
1402 done = false;
1403 at::launch([]() {
1404 checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1405 checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
1406 done = true;
1407 });
1408 while (!done) {
1409 }
1410 }
1411 }
1412 }
1413 }
1414 }
1415
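// The SymInt tests wrap concrete integers in c10::SymInt and check that the
// *_symint overloads match their plain int64_t counterparts.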
1416 TEST(TestSymIntArrayRef, BasicConversion) {
1417 const size_t X = 2, Y = 4, Z = 5;
1418 std::vector<int64_t> tgt_size_v{2, 4, 5};
1419 std::vector<c10::SymInt> tgt_size({SymInt(X), SymInt(Y), SymInt(Z)});
1420 auto a = at::randn({1, 4, 1}, at::kCPU);
1421 auto b = a.expand_symint(tgt_size);
1422 auto c = a.expand(tgt_size_v);
1423 ASSERT_TRUE(torch::allclose(b, c));
1424 }
1425
1426 TEST(TestSymInt, NarrowCopyWithSymbolicInt) {
1427 static const size_t LENGTH = 5;
1428 auto a = at::randn({10}, at::kCPU);
1429 c10::SymInt si(LENGTH);
1430 auto b = a.narrow_copy_symint(0, 0, si);
1431 auto c = a.narrow(0, 0, LENGTH);
1432 ASSERT_TRUE(torch::allclose(b, c));
1433 }
1434
1435 TEST(TestSymInt, NarrowCopy) {
1436 static const size_t LENGTH = 5;
1437 auto a = at::randn({10}, at::kCPU);
1438 auto b = a.narrow_copy(0, 0, LENGTH);
1439 auto c = a.narrow(0, 0, LENGTH);
1440 ASSERT_TRUE(torch::allclose(b, c));
1441 }
1442
1443 TEST(TestSymInt, AddSymbolicInt) {
1444 c10::SymInt a(5);
1445 c10::SymInt b(3);
1446 ASSERT_TRUE((a + b).expect_int() == 8);
1447 }
1448
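// Replaces the profiled block with a prim::FallbackGraph (roughly mimicking
// what the executor does when it bails out of an optimized graph) and checks
// that execution still produces the same result, now routed through
// prim::CallFunction.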
1449 TEST(FallbackGraphsTest, Basic) {
1450 auto x = at::randn({1}, at::kCPU);
1451 auto y = at::randn({1}, at::kCPU);
1452 auto stack = createStack({x.clone(), y.clone()});
1453
1454 auto graph_string = R"IR(
1455 graph(%0 : Float(1),
1456 %1 : Float(1)):
1457 %2 : Tensor = aten::mul(%0, %1)
1458 %3 : Tensor = aten::mul(%2, %0)
1459 return (%3))IR";
1460 auto graph = std::make_shared<Graph>();
1461 torch::jit::parseIR(graph_string, graph.get());
1462
1463 {
1464 Code code(graph, "");
1465 InterpreterState interpreter{code};
1466 interpreter.run(stack);
1467 }
1468 at::Tensor et;
1469 pop(stack, et);
1470 float ef = et.item<float>();
1471 {
1472 EnableProfilingGuard epg;
1473 GraphFunction f("fallbackGraphs", graph, nullptr);
1474 for (size_t i = 0; i < getNumProfiledRuns() + 1; i++) {
1475 stack.emplace_back(x.clone());
1476 stack.emplace_back(y.clone());
1477 if (i == getNumProfiledRuns()) {
1478 // we will be modifying a profiled graph
1479 // before ProfilingGraphExecutor
1480 // optimizes it in the next iteration
1481 auto opt_graph = lastExecutedOptimizedGraph();
1482 // this is safe to do since we are done profiling
1483 ProfilingRecord::removeProfileCounter(opt_graph->block());
1484 replaceBlockWithFallbackGraph(opt_graph->block(), opt_graph->inputs());
1485 auto it = opt_graph->block()->nodes().begin();
1486 ASSERT_EQ(it->kind(), prim::FallbackGraph);
1487 auto fallback = *it++;
1488 ASSERT_EQ(it, opt_graph->block()->nodes().end());
1489 ASSERT_TRUE(fallback->hasAttribute(attr::Subgraph));
1490 testing::FileCheck()
1491 .check("Tensor = aten::mul")
1492 ->check("Tensor = aten::mul")
1493 ->run(*fallback->g(attr::Subgraph));
1494 }
1495 f.run(stack);
1496 at::Tensor at;
1497 pop(stack, at);
1498 float af = at.item<float>();
1499 ASSERT_EQ(af, ef);
1500 }
1501
1502 auto opt_graph = lastExecutedOptimizedGraph();
1503 testing::FileCheck()
1504 .check("(Tensor) = prim::CallFunction")
1505 ->run(*opt_graph);
1506 }
1507 }
1508
1509 // TODO this test wasn't running and is broken.
1510 // TEST(AutogradProfilerTest, Basic) {
1511 // constexpr int batch_size = 4;
1512 // constexpr int input_size = 256;
1513 // constexpr int seq_len = 32;
1514
1515 // int hidden_size = 2 * input_size;
1516 // auto input = torch::randn({seq_len, batch_size, input_size}, at::kCPU);
1517 // auto hx = torch::randn({batch_size, hidden_size}, at::kCPU);
1518 // auto cx = torch::randn({batch_size, hidden_size}, at::kCPU);
1519 // auto w_ih = t_def(torch::randn({4 * hidden_size, input_size}, at::kCPU));
1520 // auto w_hh = t_def(torch::randn({4 * hidden_size, hidden_size}, at::kCPU));
1521
1522 // std::stringstream ss;
1523 // {
1524 // RecordProfile guard(ss);
1525 // for (size_t i = 0; i < 100; ++i) {
1526 // std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
1527 // }
1528 // }
1529
1530 // std::string result = ss.str();
1531 // size_t count = 0;
1532 // for (size_t pos = 0; (pos = result.find("tanh", pos)) != std::string::npos;
1533 // count++, pos++) {
1534 // }
1535 // ASSERT_EQ(count, 200);
1536 // }
1537
1538 TEST(NoneSchemaMatchTest, Basic) {
1539 RegisterOperators reg({
1540 Operator(
1541 "prim::test_none() -> int?",
1542 [](Stack& stack) { push(stack, IValue()); },
1543 aliasAnalysisFromSchema()),
1544 Operator(
1545 "prim::is_none(int? a) -> bool",
1546 [](Stack& stack) {
1547 IValue a = pop(stack);
1548 if (a.isNone()) {
1549 push(stack, true);
1550 } else {
1551 push(stack, false);
1552 }
1553 },
1554 aliasAnalysisFromSchema()),
1555 });
1556
1557 // Constant propagation will run test_none and produce a None,
1558 // testing that its type is set appropriately and schema matching doesn't
1559 // fail when running is_none
1560
1561 auto r = std::make_shared<Graph>();
1562 auto& g = *r;
1563 auto opt_int = g.insert(Symbol::fromQualString("prim::test_none"), {});
1564 auto out_bool = g.insert(Symbol::fromQualString("prim::is_none"), {opt_int});
1565 g.registerOutput(out_bool);
1566 ConstantPropagation(r);
1567
1568 auto nodes = r->block()->nodes();
1569 // check that constant propagation ran without failure
1570 AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1);
1571 }
1572
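// RegisterPass hooks fakePass into the graph executor's custom pass list;
// PassManagementTest runs a trivial graph and checks that the pass executed
// (the counter is only asserted when the profiling executor is disabled).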
1573 static int testPassValue = 0;
1574 void fakePass(std::shared_ptr<Graph>& g) {
1575 testPassValue++;
1576 return;
1577 }
1578
1579 RegisterPass p(fakePass);
1580
1581 TEST(PassManagementTest, Basic) {
1582 std::shared_ptr<Graph> graph = std::make_shared<Graph>();
1583 parseIR(
1584 R"IR(
1585 graph(%a):
1586 return (%a))IR",
1587 &*graph);
1588
1589 std::vector<IValue> stack = {IValue(torch::randn({22}, at::kCPU))};
1590 auto run = [&](std::shared_ptr<Graph>& graph, std::vector<IValue> stack) {
1591 GraphExecutor executor(graph, "");
1592 executor.run(stack);
1593 return stack;
1594 };
1595 run(graph, stack);
1596 // we will not run fusion in simple mode
1597 if (!getExecutorMode()) {
1598 AT_ASSERT(testPassValue);
1599 }
1600 }
1601
1602 static void checkShape(TypePtr typ, std::vector<int64_t> expected) {
1603 auto ptp = typ->expect<TensorType>();
1604 ASSERT_EQ(ptp->sizes().concrete_sizes().value(), expected);
1605 }
1606
1607 static void checkShape(
1608 Node* n,
1609 std::vector<int64_t> expected,
1610 bool prev = true) {
1611 auto profile = (prev) ? n->inputs().at(0)->node() : n;
1612 checkShape(profile->output()->type(), expected);
1613 }
1614
1615 void count_(
1616 Block* block,
1617 const std::function<bool(Node* n)>& pred,
1618 size_t& count) {
1619 for (Node* n : block->nodes()) {
1620 if (pred(n)) {
1621 count++;
1622 }
1623
1624 for (Block* ib : n->blocks()) {
1625 count_(ib, pred, count);
1626 }
1627 }
1628 }
1629
1630 size_t countNodes(
1631 const std::shared_ptr<Graph>& graph,
1632 const std::function<bool(Node* n)>& pred) {
1633 size_t count = 0;
1634 count_(graph->block(), pred, count);
1635 return count;
1636 }
1637
1638 bool true_pred(Node* n) {
1639 return true;
1640 }
1641
1642 bool is_loop(Node* n) {
1643 return n->kind() == prim::Loop;
1644 }
1645
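// LoopsPeeler(pred, n) splits the first n iterations of every loop accepted by
// pred into a separate loop that runs ahead of the original one, so peeled
// graphs contain additional prim::Loop nodes but must compute the same result.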
1646 TEST(LoopPeelerTest, NoInductionVariableUse) {
1647 // do not use an induction variable explicitly
1648 static const auto str_func_def = R"JIT(
1649 def test_peel_n_times():
1650 sum = 0
1651 for i in range(10):
1652 sum += 2
1653 return sum
1654 )JIT";
1655
1656 auto cu = compile(str_func_def);
1657 auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
1658 auto stack = createStack({});
1659 // peeling loop once
1660 {
1661 LoopsPeeler peeler(true_pred, 1);
1662 auto copy = f.graph()->copy();
1663 peeler.run(copy);
1664 int num_loops =
1665 std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1666 ASSERT_EQ(num_loops, 2);
1667 Code code(copy, "");
1668 InterpreterState interpreter{code};
1669 interpreter.run(stack);
1670 ASSERT_EQ(stack.back().toInt(), 20);
1671 }
1672
1673 // test peeling more than one iteration
1674 {
1675 LoopsPeeler peeler(true_pred, 3);
1676 auto copy = f.graph()->copy();
1677 peeler.run(copy);
1678 int num_loops =
1679 std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1680 ASSERT_EQ(num_loops, 2);
1681 Code code(copy, "");
1682 InterpreterState interpreter{code};
1683 interpreter.run(stack);
1684 ASSERT_EQ(stack.back().toInt(), 20);
1685 }
1686 }
1687
1688 TEST(LoopPeelerTest, YesInductionVariableUse) {
1689 // uses the induction variable
1690 static const auto str_func_def = R"JIT(
1691 def test_peel_n_times():
1692 sum = 0
1693 for i in range(10):
1694 sum += i
1695 return sum
1696 )JIT";
1697
1698 auto cu = compile(str_func_def);
1699 auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
1700 auto stack = createStack({});
1701 // peeling loop once
1702 {
1703 LoopsPeeler peeler(true_pred, 1);
1704 auto copy = f.graph()->copy();
1705 peeler.run(copy);
1706 int num_loops =
1707 std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1708 ASSERT_EQ(num_loops, 2);
1709 Code code(copy, "");
1710 InterpreterState interpreter{code};
1711 interpreter.run(stack);
1712 ASSERT_EQ(stack.back().toInt(), 45);
1713 }
1714
1715 // test peeling more than one iteration
1716 {
1717 LoopsPeeler peeler(true_pred, 3);
1718 auto copy = f.graph()->copy();
1719 peeler.run(copy);
1720 int num_loops =
1721 std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1722 ASSERT_EQ(num_loops, 2);
1723 Code code(copy, "");
1724 InterpreterState interpreter{code};
1725 interpreter.run(stack);
1726 ASSERT_EQ(stack.back().toInt(), 45);
1727 }
1728 }
1729
1730 TEST(LoopPeelerTest, LoopWithTerminationCondition) {
1731 // tests with explicit termination conditions
1732 static const auto str_func_def = R"JIT(
1733 def test_with_cond_times():
1734 sum = 0
1735 i = 0
1736 while (sum < 2):
1737 sum += i
1738 i += 1
1739 return sum
1740 )JIT";
1741
1742   // Peeling enough iterations drives the termination condition to false,
1743   // so the original (un-peeled) loop doesn't run.
1744 auto cu = compile(str_func_def);
1745 auto& f = toGraphFunction(cu->get_function("test_with_cond_times"));
1746 auto stack = createStack({});
1747 // peeling 5 iterations should update the termination
1748 // condition to false
1749 {
1750 LoopsPeeler peeler(true_pred, 5);
1751 auto copy = f.graph()->copy();
1752 peeler.run(copy);
1753 int num_loops =
1754 std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1755 ASSERT_EQ(num_loops, 2);
1756 Code code(copy, "");
1757 InterpreterState interpreter{code};
1758 interpreter.run(stack);
1759 ASSERT_EQ(stack.back().toInt(), 3);
1760 }
1761
1762 // the termination condition remains true
1763 {
1764 LoopsPeeler peeler(true_pred, 1);
1765 auto copy = f.graph()->copy();
1766 peeler.run(copy);
1767 int num_loops =
1768 std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1769 ASSERT_EQ(num_loops, 2);
1770 Code code(copy, "");
1771 InterpreterState interpreter{code};
1772 interpreter.run(stack);
1773 ASSERT_EQ(stack.back().toInt(), 3);
1774 }
1775 }
1776
1777 // tests simple nested loops
1778 TEST(LoopPeelerTest, SimpleNestedLoops) {
1779 static const auto str_func_def = R"JIT(
1780 def test_nested_loops():
1781 sum = 0
1782 i = 0
1783 for i in range(10):
1784 for j in range(10):
1785 sum += i + j
1786 return sum
1787 )JIT";
1788
1789 auto cu = compile(str_func_def);
1790 auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
1791 auto stack = createStack({});
1792
1793 {
1794 LoopsPeeler peeler(true_pred, 1);
1795 auto copy = f.graph()->copy();
1796 peeler.run(copy);
1797 ASSERT_EQ(countNodes(copy, is_loop), 5);
1798 Code code(copy, "");
1799 InterpreterState interpreter{code};
1800 interpreter.run(stack);
1801 ASSERT_EQ(stack.back().toInt(), 900);
1802 }
1803
1804 {
1805 LoopsPeeler peeler(true_pred, 5);
1806 auto copy = f.graph()->copy();
1807 peeler.run(copy);
1808 ASSERT_EQ(countNodes(copy, is_loop), 5);
1809 Code code(copy, "");
1810 InterpreterState interpreter{code};
1811 interpreter.run(stack);
1812 ASSERT_EQ(stack.back().toInt(), 900);
1813 }
1814 }
1815
1816 TEST(LoopPeelerTest, SimpleNestedLoops2) {
1817 static const auto str_func_def = R"JIT(
1818 def test_nested_loops():
1819 sum = 0
1820 i = 0
1821 for i in range(10):
1822 j = 0
1823 while sum < 2:
1824 sum += i + j
1825 j += 1
1826 return sum
1827 )JIT";
1828
1829 auto cu = compile(str_func_def);
1830 auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
1831 auto stack = createStack({});
1832 {
1833 LoopsPeeler peeler(true_pred, 1);
1834 auto copy = f.graph()->copy();
1835 peeler.run(copy);
1836 ASSERT_EQ(countNodes(copy, is_loop), 5);
1837 Code code(copy, "");
1838 InterpreterState interpreter{code};
1839 interpreter.run(stack);
1840 ASSERT_EQ(stack.back().toInt(), 3);
1841 }
1842
1843 {
1844 LoopsPeeler peeler(true_pred, 5);
1845 auto copy = f.graph()->copy();
1846 peeler.run(copy);
1847 ASSERT_EQ(countNodes(copy, is_loop), 5);
1848 Code code(copy, "");
1849 InterpreterState interpreter{code};
1850 interpreter.run(stack);
1851 ASSERT_EQ(stack.back().toInt(), 3);
1852 }
1853 }
1854
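// Traces the LSTM test graph on concrete inputs via TraceGraph, then checks
// that the traced graph's input types match those tensors and that running
// both the traced and the original graph reproduces the recorded output.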
1855 TEST(JitTracing, Basic) {
1856 constexpr int batch_size = 4;
1857 constexpr int input_size = 256;
1858
1859 int hidden_size = 2 * input_size;
1860
1861 auto input = at::randn({batch_size, input_size}, at::kCPU);
1862 auto hx = at::randn({batch_size, hidden_size}, at::kCPU);
1863 auto cx = at::randn({batch_size, hidden_size}, at::kCPU);
1864 auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU));
1865 auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU));
1866
1867 auto graph = build_lstm();
1868 auto stack = createStack({input, hx, cx, w_ih, w_hh});
1869 auto traced = TraceGraph(graph, stack);
1870
1871 // Check that the inputs of traced graph have the same type as the inputs
1872 // specified here.
1873 ASSERT_EQ(*traced->inputs().at(0)->type(), *TensorType::create(input));
1874 ASSERT_EQ(*traced->inputs().at(1)->type(), *TensorType::create(hx));
1875 ASSERT_EQ(*traced->inputs().at(2)->type(), *TensorType::create(cx));
1876 ASSERT_EQ(*traced->inputs().at(3)->type(), *TensorType::create(w_ih));
1877 ASSERT_EQ(*traced->inputs().at(4)->type(), *TensorType::create(w_hh));
1878
1879 Tensor prof_out;
1880 pop(stack, prof_out);
1881
1882 {
1883 stack = createStack({input, hx, cx, w_ih, w_hh});
1884 Code cd(traced, "traced");
1885 InterpreterState is{cd};
1886 is.run(stack);
1887 Tensor traced_out;
1888 pop(stack, traced_out);
1889     ASSERT_TRUE(torch::allclose(prof_out, traced_out));
1890 }
1891
1892 {
1893 stack = createStack({input, hx, cx, w_ih, w_hh});
1894 Code cd(graph, "graph");
1895 InterpreterState is{cd};
1896 is.run(stack);
1897 Tensor scripted_out;
1898 pop(stack, scripted_out);
1899     ASSERT_TRUE(torch::allclose(prof_out, scripted_out));
1900 }
1901 }
1902
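// The guard tests first run the profiled graph once so ProfilingRecord can
// record tensor shapes, then turn the recorded prim::profile nodes into
// prim::Guard nodes with InsertGuards. EliminateRedundantGuards removes
// guards implied by earlier ones; for `basic` only the guards on the
// definitions of x and y should remain.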
1903 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
1904 TEST(InsertAndEliminateRedundantGuardsTest, Basic) {
1905 static const auto basic_example = R"JIT(
1906 def basic(x, y):
1907 a = x + y
1908 b = x * y
1909 c = x + 1
1910 d = a - c
1911 e = b - c
1912 return d + e
1913 )JIT";
1914
1915 auto cu = compile(basic_example);
1916 auto& fun = toGraphFunction(cu->get_function("basic"));
1917 auto pr = ProfilingRecord::instrumentGraph(fun.graph());
1918 auto x = at::randn({2, 3}, at::kCPU);
1919 auto y = at::randn({2, 3}, at::kCPU);
1920 auto stack = createStack({x, y});
1921 // introduce some profiling information
1922 Code cd(pr->profiled_graph_, "");
1923 InterpreterState is{cd};
1924 is.run(stack);
1925 auto copy = pr->profiled_graph_->copy();
1926 ProfilingRecord::removeProfileCounter(copy->block());
1927 InsertGuards(copy);
1928 auto nodes = copy->block()->nodes();
1929 auto guard = std::find_if(nodes.begin(), nodes.end(), [](Node* n) {
1930 return n->kind() == prim::Guard;
1931 });
1932 ASSERT_NE(guard, nodes.end());
1933 ASSERT_EQ(
1934 guard->input()->type()->expectRef<TensorType>().sizes().size(),
1935 std::nullopt);
1936 checkShape(*guard, {2, 3}, false);
1937 auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
1938 int num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
1939 ASSERT_EQ(num_guards, 12);
1940 // now eliminate as many guards as possible
1941 // we should be left with two guards on x and y's defs
1942 EliminateRedundantGuards(copy);
1943 num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
1944 ASSERT_EQ(num_guards, 2);
1945 }
1946
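// InsertBailOuts replaces each remaining prim::Guard with a prim::BailOut
// node whose first input is a prim::BailoutTemplate, so the number of
// bailouts must match the number of guards.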
1947 TEST(InsertBailOutsTest, Basic) {
1948 static const auto basic_example = R"JIT(
1949 def basic_loop(x, y):
1950
1951 a = x + 1
1952 b = y + 2
1953 c = x + y + 3
1954
1955 for i in range(10):
1956 a = a + b
1957 # invariant
1958 d = b * c
1959 #
1960 a = a - d
1961
1962 e = a + 4
1963 return e
1964 )JIT";
1965
1966 auto cu = compile(basic_example);
1967 auto& fun = toGraphFunction(cu->get_function("basic_loop"));
1968 auto pr = ProfilingRecord::instrumentGraph(fun.graph());
1969 auto x = at::randn({2, 3}, at::kCPU);
1970 auto y = at::randn({2, 3}, at::kCPU);
1971 auto stack = createStack({x, y});
1972 // introduce some profiling information
1973 Code cd(pr->profiled_graph_, "");
1974 InterpreterState is{cd};
1975 is.run(stack);
1976 auto copy = pr->profiled_graph_->copy();
1977 ProfilingRecord::removeProfileCounter(copy->block());
1978 InsertGuards(copy);
1979 EliminateRedundantGuards(copy);
1980 auto nodes = copy->block()->nodes();
1981 auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
1982 auto num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
1983 ASSERT_EQ(num_guards, 3);
1984 InsertBailOuts(copy);
1985 auto is_bailout = [](Node* n) { return n->kind() == prim::BailOut; };
1986 auto num_bailouts = std::count_if(nodes.begin(), nodes.end(), is_bailout);
1987 ASSERT_EQ(num_guards, num_bailouts);
1988 std::vector<Node*> bailouts(num_bailouts);
1989 std::copy_if(nodes.begin(), nodes.end(), bailouts.begin(), is_bailout);
1990
1991 for (auto blo : bailouts) {
1992 ASSERT_EQ(blo->inputs().at(0)->node()->kind(), prim::BailoutTemplate);
1993 }
1994 }
1995
1996 TEST(ProfilerTest, Basic) {
1997 constexpr int batch_size = 4;
1998 constexpr int input_size = 256;
1999
2000 int hidden_size = 2 * input_size;
2001
2002 auto input = at::randn({batch_size, input_size}, at::kCPU);
2003 auto hx = at::randn({batch_size, hidden_size}, at::kCPU);
2004 auto cx = at::randn({batch_size, hidden_size}, at::kCPU);
2005 auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU));
2006 auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU));
2007
2008 auto g = build_lstm();
2009 auto stack = createStack({input, hx, cx, w_ih, w_hh});
2010
2011 auto& opt_graph = *g.get();
2012 ArgumentSpecCreator arg_spec_creator(opt_graph);
2013 ArgumentSpec spec =
2014 arg_spec_creator.create(autograd::GradMode::is_enabled(), stack);
2015 arg_spec_creator.specializeTypes(opt_graph, spec);
2016 auto pr = ProfilingRecord::instrumentGraph(g);
2017 Code cd(pr->profiled_graph_, "");
2018 InterpreterState is{cd};
2019 is.run(stack);
2020
2021 // profiled types are stored as attributes and show up in the dump, e.g.
2022 // Tensor = prim::profile[profiled_type=Double(4, 256, strides=[256, 1],
2023 // requires_grad=0, device=cpu)
2024 testing::FileCheck()
2025 .check("Tensor = prim::profile[profiled_type")
2026 ->check_same("256")
2027 ->run(*pr->profiled_graph_);
2028
2029 auto begin = pr->profiled_graph_->block()->nodes().begin();
2030 auto end = pr->profiled_graph_->block()->nodes().end();
2031 auto mm =
2032 std::find_if(begin, end, [](Node* n) { return n->kind() == aten::add; });
2033 ASSERT_NE(mm, end);
2034 std::vector<int64_t> mm_expected{4, 2048};
2035 std::vector<int64_t> eltwise{4, 512};
2036 checkShape(mm->inputs().at(0)->node()->ty(attr::profiled_type), mm_expected);
2037 auto mul_n =
2038 std::find_if(begin, end, [](Node* n) { return n->kind() == aten::mul; });
2039 ASSERT_NE(mul_n, end);
2040 checkShape(mul_n->inputs().at(0)->node()->ty(attr::profiled_type), eltwise);
2041 auto tanh_n =
2042 std::find_if(begin, end, [](Node* n) { return n->kind() == aten::tanh; });
2043 checkShape(tanh_n->inputs().at(0)->node()->ty(attr::profiled_type), eltwise);
2044 }
2045
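// Profiling an optional (Tensor?) input: the first run records the bias
// shape under profiled_type with seen_none=0, and a second run that passes
// None flips seen_none to 1 while keeping the recorded shape.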
2046 TEST(ProfilerTest, OptionalProfiling) {
2047 auto graph = std::make_shared<Graph>();
2048 std::unordered_map<std::string, Value*> vmap;
2049 parseIR(
2050 R"IR(
2051 graph(%inp : Tensor,
2052 %weight : Tensor,
2053 %bias : Tensor?):
2054 %1 : Tensor = aten::linear(%inp, %weight, %bias)
2055 return (%1))IR",
2056 &*graph,
2057 vmap);
2058
2059 auto pr = ProfilingRecord::instrumentGraph(graph);
2060 pr->profiling_count_ = 2;
2061
2062 auto input = torch::randn({1, 2});
2063 auto weight = torch::randn({2, 2});
2064 auto bias = torch::randn({1, 2});
2065
2066 auto stack = createStack({input, weight, bias});
2067 Code cd(pr->profiled_graph_, "");
2068 InterpreterState is{cd};
2069 is.run(stack);
2070
2071 testing::FileCheck()
2072 .check_count("Tensor? = prim::profile[profiled_type", 1, true)
2073 ->run(*pr->profiled_graph_);
2074
2075 // make sure we recorded the shape
2076 auto begin = pr->profiled_graph_->block()->nodes().begin();
2077 auto end = pr->profiled_graph_->block()->nodes().end();
2078 auto linear = std::find_if(
2079 begin, end, [](Node* n) { return n->kind() == aten::linear; });
2080 ASSERT_NE(linear, end);
2081 std::vector<int64_t> bias_expected_shape = {1, 2};
2082 auto profiled_bias = linear->namedInput("bias")->node();
2083 checkShape(profiled_bias->ty(attr::profiled_type), bias_expected_shape);
2084 ASSERT_EQ(0, profiled_bias->i(attr::seen_none));
2085
2086 auto none_bias = c10::IValue();
2087
2088 stack.clear();
2089 stack.emplace_back(input);
2090 stack.emplace_back(weight);
2091 stack.emplace_back(none_bias);
2092 is = InterpreterState{cd};
2093 is.run(stack);
2094
2095 // make sure we recorded that "None" was seen.
2096 begin = pr->profiled_graph_->block()->nodes().begin();
2097 end = pr->profiled_graph_->block()->nodes().end();
2098 linear = std::find_if(
2099 begin, end, [](Node* n) { return n->kind() == aten::linear; });
2100 ASSERT_NE(linear, end);
2101 profiled_bias = linear->namedInput("bias")->node();
2102 checkShape(profiled_bias->ty(attr::profiled_type), bias_expected_shape);
2103 ASSERT_EQ(1, profiled_bias->i(attr::seen_none));
2104 }
2105
2106 TEST(CallStackTest, Basic) {
2107 const auto text = R"(
2108 def ham(x):
2109 return x/7
2110
2111 def bar(x):
2112 return x*3
2113
2114 def baz(x):
2115 return ham(x)*x
2116
2117 def foo(x):
2118 return bar(x)*baz(x)*11
2119 )";
2120 auto cu = compile(text);
2121 const auto& foo = toGraphFunction(cu->get_function("foo"));
2122 for (Node* n : foo.optimized_graph()->nodes()) {
2123 if (n->kind() == prim::Constant) {
2124 if (!n->hasAttribute(attr::value) ||
2125 n->kindOf(attr::value) != AttributeKind::i) {
2126 continue;
2127 }
2128 int v = n->i(attr::value);
2129 switch (v) {
2130 case 3: {
2131 // Const 3 comes from function 'bar', which gets inlined to 'foo'.
2132 // The callstack for the corresponding node should contain only the
2133 // function 'bar'.
2134 ASSERT_TRUE(n->callstack());
2135 auto callstack_vector = (*n->callstack())->vec();
2136 ASSERT_EQ(callstack_vector.size(), 1);
2137 ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("bar"));
2138 break;
2139 }
2140 case 7: {
2141 // Const 7 comes from function 'ham', which gets inlined to 'baz',
2142 // which is then inlined to 'foo'. The callstack for the corresponding
2143 // node should contain these two functions.
2144 ASSERT_TRUE(n->callstack());
2145 auto callstack_vector = (*n->callstack())->vec();
2146 ASSERT_EQ(callstack_vector.size(), 2);
2147 ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("baz"));
2148 ASSERT_EQ(std::get<0>(callstack_vector[1]), &cu->get_function("ham"));
2149 break;
2150 }
2151 case 11: {
2152 // Const 11 comes from function 'foo', which is not inlined anywhere
2153 // and thus it should not have a callstack.
2154 ASSERT_FALSE(n->callstack());
2155 break;
2156 }
2157 }
2158 }
2159 }
2160
2161 // Check that inlining doesn't corrupt callstack of the callee's nodes.
2162 const auto& baz = toGraphFunction(cu->get_function("baz"));
2163 for (Node* n : baz.optimized_graph()->nodes()) {
2164 if (n->kind() == prim::Constant) {
2165 if (!n->hasAttribute(attr::value) ||
2166 n->kindOf(attr::value) != AttributeKind::i) {
2167 continue;
2168 }
2169 int v = n->i(attr::value);
2170 ASSERT_TRUE(v == 7);
2171 // Const 7 comes from function 'ham', which gets inlined to 'baz'. 'baz'
2172 // was also inlined into 'foo', but when looking at the graph of 'baz' we
2173 // should only see a callstack of depth 1 (containing only 'ham').
2174 ASSERT_TRUE(n->callstack());
2175 auto callstack_vector = (*n->callstack())->vec();
2176 ASSERT_EQ(callstack_vector.size(), 1);
2177 ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("ham"));
2178 }
2179 }
2180 }
2181
2182 TEST(CallStackTest, Caching) {
2183 const auto text = R"(
2184
2185 def a(x):
2186 print("a1")
2187 print("a2")
2188 return x
2189
2190 def b(x):
2191 print("b1")
2192 print("b2")
2193 a(x)
2194 return x
2195
2196 def c(x):
2197 print("c1")
2198 print("c2")
2199 b(x)
2200 return x
2201 )";
2202 auto cu = compile(text);
2203 const auto& baz = toGraphFunction(cu->get_function("c"));
2204 std::unordered_map<std::string, InlinedCallStack*> callstack_objects;
2205 for (Node* n : baz.optimized_graph()->nodes()) {
2206 if (n->kind() == prim::Constant) {
2207 if (!n->hasAttribute(attr::value) ||
2208 n->kindOf(attr::value) != AttributeKind::s) {
2209 continue;
2210 }
2211 // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
2212 std::string v = n->s(attr::value);
2213 if (n->callstack()) {
2214 callstack_objects[v] = n->callstack()->get();
2215 }
2216 }
2217 }
2218 // We expect to see nodes prim::Constant[value="a1"] and
2219 // prim::Constant[value="a2"] inlined to function 'c'. Their callstacks are
2220 // the same (a->b->c), so we want to make sure we're not creating different
2221 // callstack entries for them.
2222 ASSERT_TRUE(callstack_objects.count("a1") && callstack_objects.count("a2"));
2223 ASSERT_TRUE(callstack_objects.at("a1") == callstack_objects.at("a2"));
2224 }
2225
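// Verifies that nodes inlined into the blocks of a prim::If keep an inlined
// callstack whose source ranges point both at the call site in C.forward and
// at the original lines inside A.forward.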
2226 TEST(InlinedCallStackTest, BlockAnnotation) {
2227 Module a("A");
2228 a.define(R"(
2229 def forward(self, x, y, z: int):
2230 if (z == 1):
2231 return x + y
2232 else:
2233 return x * y
2234 )");
2235 Module b("B");
2236 b.define(R"(
2237 def forward(self, x):
2238 return x + 2
2239 )");
2240 Module c("C");
2241 c.register_module("A0", a);
2242 c.register_module("B0", b);
2243 c.define(R"(
2244 def forward(self, x, y, z: int):
2245 return self.A0.forward(x, y, z) + self.B0.forward(x)
2246 )");
2247
2248 auto graph =
2249 toGraphFunction(c.get_method("forward").function()).optimized_graph();
2250 std::stringstream add_ss, mul_ss;
2251 for (Node* n : graph->nodes()) {
2252 if (n->kind() == prim::If) {
2253 for (Block* block : n->blocks()) {
2254 for (Node* if_node : block->nodes()) {
2255 if (if_node->kind() == aten::add) {
2256 for (const auto& e : if_node->callstack().value()->vec()) {
2257 add_ss << std::get<1>(e);
2258 }
2259 add_ss << if_node->sourceRange();
2260 }
2261 if (if_node->kind() == aten::mul) {
2262 for (const auto& e : if_node->callstack().value()->vec()) {
2263 mul_ss << std::get<1>(e);
2264 }
2265 mul_ss << if_node->sourceRange();
2266 }
2267 }
2268 }
2269 }
2270 }
2271 ASSERT_NE(add_ss.str().find("line 3"), std::string::npos);
2272 ASSERT_NE(add_ss.str().find("line 4"), std::string::npos);
2273 ASSERT_NE(
2274 add_ss.str().find("return self.A0.forward(x, y, z)"), std::string::npos);
2275 ASSERT_NE(add_ss.str().find("return x + y"), std::string::npos);
2276 ASSERT_NE(mul_ss.str().find("line 3"), std::string::npos);
2277 ASSERT_NE(mul_ss.str().find("line 6"), std::string::npos);
2278 ASSERT_NE(
2279 mul_ss.str().find("return self.A0.forward(x, y, z)"), std::string::npos);
2280 ASSERT_NE(mul_ss.str().find("return x * y"), std::string::npos);
2281 }
2282
2283 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
2284 TEST(InlinedCallStackTest, SelfCallMethods) {
2285 Module a("A");
2286 a.define(R"(
2287 def my_new_method(self, x):
2288 return x * 3
2289 def forward_impl_(self, x, y):
2290 return self.my_new_method(x) + y
2291 def forward(self, x, y):
2292 y = y + 2
2293 return self.forward_impl_(x, y)
2294 )");
2295 Module b("B");
2296 b.define(R"(
2297 def forward(self, x):
2298 return x + 2
2299 )");
2300 Module c("C");
2301 c.register_module("A0", a);
2302 c.register_module("B0", b);
2303 c.define(R"(
2304 def call_b(self, x):
2305 return self.B0.forward(x)
2306 def forward(self, x, y):
2307 return self.A0.forward(x, y) + self.call_b(x)
2308 )");
2309
2310 auto graph =
2311 toGraphFunction(c.get_method("forward").function()).optimized_graph();
2312 std::unordered_map<std::string, size_t> module_hierarchies;
2313 for (Node* n : graph->nodes()) {
2314 auto hierarchy = torch::jit::utils::getNodesModuleHierarchy(*n);
2315 if (module_hierarchies.count(hierarchy) == 0) {
2316 module_hierarchies[hierarchy] = 0;
2317 }
2318 module_hierarchies[hierarchy] += 1;
2319 }
2320 ASSERT_EQ(module_hierarchies["A0(A)"], 2);
2321 ASSERT_EQ(module_hierarchies["A0(A).SELF(A).SELF(A)"], 2);
2322 ASSERT_EQ(module_hierarchies["A0(A).SELF(A)"], 1);
2323 ASSERT_EQ(module_hierarchies["SELF(C)"], 1);
2324 ASSERT_EQ(module_hierarchies["SELF(C).B0(B)"], 1);
2325 }
2326
2327 TEST(AutogradSymbolsTest, Basic) {
2328 Symbol sym = Symbol::fromQualString("aten::test_symbol");
2329 Graph graph;
2330 auto node = graph.create(sym);
2331 TORCH_CHECK(canRunWithAutograd(node));
2332
2333 sym = Symbol::fromQualString("prim::test_symbol");
2334 node = graph.create(sym);
2335 TORCH_CHECK(canRunWithAutograd(node));
2336
2337 sym = Symbol::fromQualString("prim::FusionGroup");
2338 node = graph.create(sym);
2339 TORCH_CHECK(!canRunWithAutograd(node));
2340
2341 sym = Symbol::fromQualString("custom::test_symbol");
2342 node = graph.create(sym);
2343 TORCH_CHECK(!canRunWithAutograd(node));
2344 }
2345
2346 TEST(DefaultArgTypeHintingTest, Basic) {
2347 const auto text_non_hinted = R"(
2348
2349 def a(x, y=1):
2350 print("a1")
2351 print("a2")
2352 return x
2353 )";
2354
2355 const auto text_hinted = R"(
2356
2357 def a(x, y:int=1):
2358 print("a1")
2359 print("a2")
2360 return x
2361 )";
2362
2363 try {
2364 compile(text_non_hinted);
2365 ASSERT_TRUE(0);
2366 } catch (const std::exception& c) {
2367 }
2368
2369 auto cu = compile(text_hinted);
2370 }
2371
2372 // Basic set case.
2373 TEST(FuturesTest, Basic) {
2374 auto f1 = c10::make_intrusive<Future>(IntType::get());
2375 ASSERT_FALSE(f1->completed());
2376 ASSERT_FALSE(f1->hasValue());
2377 int32_t sat1 = 0;
2378 int32_t sat2 = 0;
2379 f1->addCallback([&](Future& /* unused */) { ++sat1; });
2380 f1->markCompleted(43);
2381 ASSERT_TRUE(f1->completed());
2382 ASSERT_TRUE(f1->hasValue());
2383 ASSERT_FALSE(f1->hasError());
2384 ASSERT_EQ(sat1, 1);
2385 ASSERT_EQ(f1->constValue().toInt(), 43);
2386 ASSERT_EQ(f1->value().toInt(), 43);
2387 f1->addCallback([&](Future& /* unused */) { ++sat2; });
2388 ASSERT_EQ(sat1, 1);
2389 ASSERT_EQ(sat2, 1);
2390 }
2391
2392 // Sparse CUDA tensor test
2393 TEST(FutureTest, SparseTensor) {
2394 // Skip test if CUDA is not available.
2395 bool has_cuda = at::globalContext().hasCUDA();
2396   if (!has_cuda) {
2397     LOG(INFO) << "CUDA not available, skipping test";
    return;
2398   }
2399 for (int i = 0; i < 2; ++i) {
2400 auto f = c10::make_intrusive<Future>(TensorType::get());
2401 at::TensorOptions opts = at::TensorOptions().device(at::DeviceType::CUDA);
2402 auto sparse_tensor = i == 0 ? at::ones(10).to_sparse()
2403 : at::sparse_coo_tensor(
2404 at::arange(10).unsqueeze(0).to(at::kLong),
2405 at::ones({10, 10}),
2406 opts);
2407 // Runs storage extraction for sparse CUDA tensors
2408 f->markCompleted(sparse_tensor);
2409 ASSERT_TRUE(f->completed());
2410 ASSERT_FALSE(f->hasError());
2411 }
2412 }
2413
2414 // Basic error cases.
2415 TEST(FuturesTest, Error) {
2416 auto f1 = c10::make_intrusive<Future>(IntType::get());
2417 int sat1 = 0;
2418 int sat2 = 0;
2419 f1->addCallback([&](Future& /* unused */) { ++sat1; });
2420 f1->setError(
2421 std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
2422 ASSERT_EQ(sat1, 1);
2423 ASSERT_TRUE(f1->completed());
2424 ASSERT_TRUE(f1->hasError());
2425 ASSERT_FALSE(f1->hasValue());
2426 try {
2427 (void)f1->value();
2428 ASSERT_TRUE(false); // Supposed to throw.
2429 } catch (const std::exception& e) {
2430 ASSERT_TRUE(strcmp(e.what(), "Failed") == 0);
2431 }
2432 f1->addCallback([&](Future& /* unused */) { ++sat2; });
2433 ASSERT_EQ(sat1, 1);
2434 ASSERT_EQ(sat2, 1);
2435 f1->setErrorIfNeeded(
2436 std::make_exception_ptr(c10::ivalue::Future::FutureError("Dup")));
2437 ASSERT_TRUE(strcmp(f1->tryRetrieveErrorMessage().c_str(), "Failed") == 0);
2438 ASSERT_EQ(sat1, 1);
2439 ASSERT_EQ(sat2, 1);
2440 try {
2441 (void)f1->constValue();
2442 ASSERT_TRUE(false); // Supposed to throw.
2443 } catch (const std::exception& e) {
2444 // Original error should be logged.
2445 ASSERT_TRUE(std::string(e.what()).find("Failed") != std::string::npos);
2446 }
2447 }
2448
2449 // then
2450 TEST(FuturesTest, Then) {
2451 auto f1 = c10::make_intrusive<Future>(IntType::get());
2452 auto f2 = f1->then(
2453 [](Future& f1) -> IValue { return f1.constValue().toInt() + 1; },
2454 IntType::get());
2455 auto f3 = f2->then(
2456 [](Future& f2) -> IValue { return f2.constValue().toInt() * 3; },
2457 IntType::get());
2458 bool done = false;
2459 f3->addCallback([&done](Future& f3) {
2460 ASSERT_EQ(f3.constValue().toInt(), (42 + 1) * 3);
2461 done = true;
2462 });
2463 ASSERT_FALSE(done);
2464 f1->markCompleted(42);
2465 ASSERT_TRUE(done);
2466 }
2467
2468 // collectAll()
2469 TEST(FuturesTest, CollectAll) {
2470 auto s1 = c10::make_intrusive<Future>(IntType::get());
2471 auto s2 = c10::make_intrusive<Future>(IntType::get());
2472 auto s3 = c10::make_intrusive<Future>(IntType::get());
2473
2474 // Empty case
2475 c10::List<intrusive_ptr<ivalue::Future>> futures(
2476 FutureType::create(IntType::get()));
2477 auto c1 = collectAll(futures);
2478 ASSERT_TRUE(c1->completed());
2479 ASSERT_EQ(c1->value().toList().size(), 0);
2480 ASSERT_TRUE(
2481 *(c1->value().toList().elementType()) ==
2482 *FutureType::create(IntType::get()));
2483
2484 // 1-element, initially not completed.
2485 futures.push_back(s1);
2486 auto c2 = collectAll(futures);
2487 ASSERT_FALSE(c2->completed());
2488 s1->markCompleted(5);
2489 ASSERT_TRUE(c2->completed());
2490 ASSERT_EQ(c2->value().toList().size(), 1);
2491 ASSERT_TRUE(
2492 *(c2->value().toList().elementType()) ==
2493 *FutureType::create(IntType::get()));
2494 ASSERT_EQ(c2->value().toList().get(0).toFuture()->value().toInt(), 5);
2495
2496 // 1-element, already completed
2497 auto c3 = collectAll(futures);
2498 ASSERT_TRUE(c3->completed());
2499 ASSERT_EQ(c3->value().toList().size(), 1);
2500 ASSERT_EQ(c3->value().toList().get(0).toFuture()->value().toInt(), 5);
2501
2502 // 3 elements.
2503 futures.push_back(s2);
2504 futures.push_back(s3);
2505 auto c4 = collectAll(futures);
2506 ASSERT_FALSE(c4->completed());
2507 s3->markCompleted(7);
2508 ASSERT_FALSE(c4->completed());
2509 s2->markCompleted(6);
2510 ASSERT_TRUE(c4->completed());
2511 ASSERT_EQ(c4->value().toList().size(), 3);
2512 ASSERT_EQ(c4->value().toList().get(0).toFuture()->value().toInt(), 5);
2513 ASSERT_EQ(c4->value().toList().get(1).toFuture()->value().toInt(), 6);
2514 ASSERT_EQ(c4->value().toList().get(2).toFuture()->value().toInt(), 7);
2515 ASSERT_TRUE(
2516 *(c4->value().toList().elementType()) ==
2517 *FutureType::create(IntType::get()));
2518
2519 // Handle exception in the list.
2520 auto s4 = c10::make_intrusive<Future>(IntType::get());
2521 futures.push_back(s4);
2522 auto c5 = collectAll(futures);
2523 ASSERT_FALSE(c5->completed());
2524 s4->setError(
2525 std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
2526 ASSERT_TRUE(c5->completed());
2527 try {
2528 c5->value();
2529 ASSERT_TRUE(false); // supposed to throw
2530 } catch (const std::exception& e) {
2531 ASSERT_EQ(std::string(e.what()), "Failed");
2532 }
2533 }
2534
2535 // collectAny()
2536 TEST(FuturesTest, CollectAny) {
2537 auto s1 = c10::make_intrusive<Future>(IntType::get());
2538
2539 // Empty case
2540 c10::List<intrusive_ptr<ivalue::Future>> futures(
2541 FutureType::create(IntType::get()));
2542 auto c1 = collectAny(futures);
2543 ASSERT_TRUE(c1->completed());
2544
2545 // 1 element, not yet satisfied
2546 futures.push_back(s1);
2547 auto c2 = collectAny(futures);
2548 ASSERT_FALSE(c2->completed());
2549 s1->markCompleted(5);
2550 ASSERT_TRUE(c2->completed());
2551 ASSERT_TRUE(c2->value().isInt());
2552 ASSERT_EQ(c2->value().toInt(), 5);
2553
2554 // 1 element already satisfied.
2555 auto c3 = collectAny(futures);
2556 ASSERT_TRUE(c3->completed());
2557 ASSERT_TRUE(c3->value().isInt());
2558 ASSERT_EQ(c3->value().toInt(), 5);
2559
2560 // 2 elements
2561 futures.clear();
2562 auto s2 = c10::make_intrusive<Future>(IntType::get());
2563 auto s3 = c10::make_intrusive<Future>(IntType::get());
2564 futures.push_back(s2);
2565 futures.push_back(s3);
2566 auto c4 = collectAny(futures);
2567 ASSERT_FALSE(c4->completed());
2568 s3->markCompleted(7);
2569 ASSERT_TRUE(c4->completed());
2570 ASSERT_EQ(c4->value().toInt(), 7);
2571 s2->markCompleted(1);
2572 ASSERT_EQ(c4->value().toInt(), 7);
2573 }
2574
2575 TEST(TLSFutureCallbacksTest, Basic) {
2576 // cb that verifies the profiler is enabled
2577 auto profilerEnabledCb = [](Future& /* unused */) {
2578 ASSERT_TRUE(torch::autograd::profiler::profilerEnabled());
2579 };
2580 // test running callbacks with propagation of TLS state.
2581 {
2582 // Enable the profiler in this thread
2583 torch::autograd::profiler::enableProfilerLegacy(
2584 torch::autograd::profiler::ProfilerConfig(
2585 torch::autograd::profiler::ProfilerState::CPU, false, false));
2586 auto s1 = c10::make_intrusive<Future>(IntType::get());
2587 s1->addCallback(wrapPropagateTLSState(profilerEnabledCb));
2588 std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); });
2589 // Since we join here, we can ensure that all callbacks corresponding to
2590 // markCompleted() have finished.
2591 t.join();
2592 torch::autograd::profiler::disableProfilerLegacy();
2593 }
2594 // then() with TLS State
2595 {
2596 // Enable the profiler in this thread
2597 torch::autograd::profiler::enableProfilerLegacy(
2598 torch::autograd::profiler::ProfilerConfig(
2599 torch::autograd::profiler::ProfilerState::CPU, false, false));
2600 auto s1 = c10::make_intrusive<Future>(IntType::get());
2601 auto s2 = s1->then(
2602 wrapPropagateTLSState([&profilerEnabledCb](Future& s1) {
2603 profilerEnabledCb(s1);
2604 return at::IValue(1);
2605 }),
2606 IntType::get());
2607 std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); });
2608 t.join();
2609 s2->wait();
2610 torch::autograd::profiler::disableProfilerLegacy();
2611 }
2612 }
2613
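// Here the future callback itself runs aten::ones and aten::add and then
// calls disableProfilerLegacy() without cleaning up TLS state; the event
// lists it gets back must still contain both ops, whether the callback runs
// on a spawned thread or inline on the main thread.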
2614 TEST(ProfilerDisableInCallbackTest, Basic) {
2615 // cb that verifies the profiler is enabled
2616 auto profilerEnabledCb = []() {
2617 ASSERT_TRUE(torch::autograd::profiler::profilerEnabled());
2618 };
2619 torch::autograd::profiler::enableProfilerLegacy(
2620 torch::autograd::profiler::ProfilerConfig(
2621 torch::autograd::profiler::ProfilerState::CPU, false, false));
2622 auto s1 = c10::make_intrusive<Future>(IntType::get());
2623 auto verifyProfilerCb =
2624 wrapPropagateTLSState([&profilerEnabledCb](Future& /* unused */) {
2625 // Ensure the profiler is still enabled in this thread.
2626 profilerEnabledCb();
2627 auto t1 = torch::ones({2, 2});
2628 auto t2 = torch::ones({2, 2});
2629 torch::add(t1, t2);
2630           // Don't clean up TLS state; just consolidate the results.
2631 auto opts =
2632 torch::autograd::profiler::ProfilerDisableOptions(false, true);
2633 auto thread_event_lists =
2634 // NOLINTNEXTLINE(performance-move-const-arg)
2635 torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
2636 // Ensure that the events from this thread are still profiled and we
2637           // obtain the expected events in our consolidated list when calling
2638 // disableProfilerLegacy().
2639 bool found_ones = false;
2640 bool found_add = false;
2641 for (const auto& li : thread_event_lists) {
2642 for (const auto& evt : li) {
2643 if (strcmp(evt.name(), "aten::add") == 0) {
2644 found_add = true;
2645 } else if (strcmp(evt.name(), "aten::ones") == 0) {
2646 found_ones = true;
2647 }
2648 }
2649 if (found_add && found_ones) {
2650 break;
2651 }
2652 }
2653 ASSERT_TRUE(found_ones);
2654 ASSERT_TRUE(found_add);
2655 });
2656
2657 s1->addCallback(verifyProfilerCb);
2658 // Disable the profiler, but do not consolidate results in the main thread.
2659 auto opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
2660 // NOLINTNEXTLINE(performance-move-const-arg)
2661 torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
2662 std::thread t([s1 = std::move(s1)]() { s1->markCompleted(at::IValue(1)); });
2663 t.join();
2664
2665 // Similar to above test, but verifies correctness in the case where
2666 // continuation runs on the main thread.
2667 torch::autograd::profiler::enableProfilerLegacy(
2668 torch::autograd::profiler::ProfilerConfig(
2669 torch::autograd::profiler::ProfilerState::CPU, false, false));
2670 s1 = c10::make_intrusive<Future>(IntType::get());
2671 s1->addCallback(verifyProfilerCb);
2672 // Runs callback inline
2673 s1->markCompleted(at::IValue(1));
2674 opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
2675 // NOLINTNEXTLINE(performance-move-const-arg)
2676 torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
2677 }
2678
2679 TEST(RecordDebugHandles, Basic) {
2680 // Enable the profiler in this thread
2681 const std::set<torch::autograd::profiler::ActivityType> activities(
2682 {torch::autograd::profiler::ActivityType::CPU});
2683 torch::autograd::profiler::prepareProfiler(
2684 torch::autograd::profiler::ProfilerConfig(
2685 torch::autograd::profiler::ProfilerState::KINETO, false, false),
2686 activities);
2687 torch::autograd::profiler::enableProfiler(
2688 torch::autograd::profiler::ProfilerConfig(
2689 torch::autograd::profiler::ProfilerState::KINETO, false, false),
2690 activities);
2691 {
2692 RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function", 42, {});
2693 float x{5.9999}, y{2.1212};
2694 float z = x / y;
2695 (void)z;
2696 }
2697 {
2698 RECORD_USER_SCOPE_WITH_INPUTS("not_my_function", {});
2699 float x{5.9999}, y{2.1212};
2700 float z = x / y;
2701 (void)z;
2702 }
2703 auto profiler_results_ptr = torch::autograd::profiler::disableProfiler();
2704 const auto& kineto_events = profiler_results_ptr->events();
2705 size_t my_events{0};
2706 for (const auto& e : kineto_events) {
2707 if (e.name() == "my_function") {
2708 ASSERT_EQ(e.debugHandle(), 42);
2709 my_events++;
2710 } else if (e.name() == "not_my_function") {
2711 ASSERT_EQ(e.debugHandle(), -1);
2712 my_events++;
2713 }
2714 }
2715 ASSERT_EQ(my_events, 2);
2716 }
2717
2718 TEST(RecordDebugHandles, ScopedCallbacks) {
2719 // Enable the profiler in this thread
2720 torch::autograd::profiler::prepareProfiler(
2721 torch::autograd::profiler::ProfilerConfig(
2722 torch::autograd::profiler::ProfilerState::KINETO, false, false),
2723 {torch::autograd::profiler::ActivityType::CPU});
2724 torch::autograd::profiler::enableProfiler(
2725 torch::autograd::profiler::ProfilerConfig(
2726 torch::autograd::profiler::ProfilerState::KINETO, false, false),
2727 {torch::autograd::profiler::ActivityType::CPU});
2728
2729 {
2730 auto a = torch::rand({128, 128});
2731 auto b = torch::rand({128, 128});
2732 auto c = a + b;
2733 }
2734 auto profiler_results_ptr = torch::autograd::profiler::disableProfiler();
2735 ASSERT_TRUE(profiler_results_ptr->events().size() > 0);
2736
2737 // Enable the profiler in this thread
2738 torch::autograd::profiler::prepareProfiler(
2739 torch::autograd::profiler::ProfilerConfig(
2740 torch::autograd::profiler::ProfilerState::KINETO, false, false),
2741 {torch::autograd::profiler::ActivityType::CPU});
2742 torch::autograd::profiler::enableProfiler(
2743 torch::autograd::profiler::ProfilerConfig(
2744 torch::autograd::profiler::ProfilerState::KINETO, false, false),
2745 {torch::autograd::profiler::ActivityType::CPU},
2746 {at::RecordScope::LITE_INTERPRETER});
2747 {
2748 auto a = torch::rand({128, 128});
2749 auto b = torch::rand({128, 128});
2750 auto c = a + b;
2751 }
2752 profiler_results_ptr = torch::autograd::profiler::disableProfiler();
2753 ASSERT_TRUE(profiler_results_ptr->events().size() == 0);
2754
2755 torch::autograd::profiler::prepareProfiler(
2756 torch::autograd::profiler::ProfilerConfig(
2757 torch::autograd::profiler::ProfilerState::KINETO, false, false),
2758 {torch::autograd::profiler::ActivityType::CPU});
2759 torch::autograd::profiler::enableProfiler(
2760 torch::autograd::profiler::ProfilerConfig(
2761 torch::autograd::profiler::ProfilerState::KINETO, false, false),
2762 {torch::autograd::profiler::ActivityType::CPU},
2763 {at::RecordScope::LITE_INTERPRETER});
2764 {
2765 RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function", 42, {});
2766 auto a = torch::rand({128, 128});
2767 auto b = torch::rand({128, 128});
2768 auto c = a + b;
2769 }
2770 {
2771 RECORD_USER_SCOPE_WITH_INPUTS("not_my_function", {});
2772 auto a = torch::rand({128, 128});
2773 auto b = torch::rand({128, 128});
2774 auto c = a + b;
2775 }
2776 profiler_results_ptr = torch::autograd::profiler::disableProfiler();
2777 const auto& kineto_events = profiler_results_ptr->events();
2778 for (const auto& e : kineto_events) {
2779 if (e.name() == "my_function") {
2780 ASSERT_EQ(e.debugHandle(), 42);
2781 }
2782 }
2783 ASSERT_TRUE(profiler_results_ptr->events().size() == 1);
2784 }
2785
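// foo is called with a=1 positionally and b=3 as a keyword argument; c keeps
// its default of 4, so the expected result is 1 + 2*3 + 3*4 = 19.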
2786 TEST(IValueKWargsTest, Basic) {
2787 const auto text = R"(
2788 def foo(a : int, b : int, c : int = 4):
2789 return a + 2*b + 3*c
2790 )";
2791 auto cu = compile(text);
2792 auto result = cu->get_function("foo")({1}, {{"b", 3}});
2793 ASSERT_EQ(result.toInt(), 19);
2794 }
2795
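// computeFlops derives FLOP estimates from the shape arguments saved in
// extra_args; it returns 0 for unknown operators and whenever the required
// size entries are missing or malformed.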
2796 TEST(ComputeFlopsTest, Basic) {
2797 uint64_t flops = 0;
2798
2799 // Test unknown operator
2800 std::unordered_map<std::string, c10::IValue> extra_args;
2801 flops = torch::profiler::impl::computeFlops(
2802 std::string("aten::unknown"), extra_args);
2803 ASSERT_EQ(flops, 0);
2804
2805 // Test aten::conv2d
2806 extra_args.clear();
2807 std::vector<int64_t> input_size = {4, 5, 6, 7};
2808 std::vector<int64_t> weight_size = {3, 5, 2, 1};
2809 std::vector<int64_t> padding = {1, 0};
2810 std::vector<int64_t> stride = {1, 1};
2811 std::vector<int64_t> dilation = {0, 0};
2812 extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
2813 extra_args["weight_size"] = at::IValue(at::IntArrayRef(weight_size));
2814 extra_args["groups"] = 1;
2815 extra_args["padding"] = at::IValue(at::IntArrayRef(padding));
2816 extra_args["stride"] = at::IValue(at::IntArrayRef(stride));
2817 extra_args["dilation"] = at::IValue(at::IntArrayRef(dilation));
2818 flops = torch::profiler::impl::computeFlops(
2819 std::string("aten::conv2d"), extra_args);
2820 ASSERT_EQ(flops, 13440);
2821
2822 // Test aten::conv2d fail
2823 input_size = {4, 5, 6, 7};
2824 weight_size = {4, 5, 6};
2825 extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
2826 extra_args["weight_size"] = at::IValue(at::IntArrayRef(weight_size));
2827 flops = torch::profiler::impl::computeFlops(
2828 std::string("aten::conv2d"), extra_args);
2829 ASSERT_EQ(flops, 0);
2830
2831 // Test aten::conv2d fail 2
2832 weight_size = {3, 5, 2, 1};
2833 stride = {0, 0};
2834 extra_args["weight_size"] = at::IValue(at::IntArrayRef(input_size));
2835 extra_args["stride"] = at::IValue(at::IntArrayRef(stride));
2836 flops = torch::profiler::impl::computeFlops(
2837 std::string("aten::conv2d"), extra_args);
2838 ASSERT_EQ(flops, 0);
2839
2840 // Test aten::conv2d fail 3
2841 extra_args.clear();
2842 input_size = {4, 5, 6, 7};
2843 extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
2844 flops = torch::profiler::impl::computeFlops(
2845 std::string("aten::conv2d"), extra_args);
2846 ASSERT_EQ(flops, 0);
2847
2848 // Test aten::mm
2849 extra_args.clear();
2850 std::vector<int64_t> mat1_sizes = {3, 4, 5, 6};
2851 std::vector<int64_t> mat2_sizes = {6, 5, 4, 3};
2852 extra_args["mat1_size"] = at::IValue(at::IntArrayRef(mat1_sizes));
2853 extra_args["mat2_size"] = at::IValue(at::IntArrayRef(mat2_sizes));
2854 flops =
2855 torch::profiler::impl::computeFlops(std::string("aten::mm"), extra_args);
2856 ASSERT_EQ(flops, 43200);
2857
2858 // Test aten::addmm
2859 flops = torch::profiler::impl::computeFlops(
2860 std::string("aten::addmm"), extra_args);
2861 ASSERT_EQ(flops, 43200);
2862
2863 // Test aten::bmm
2864 extra_args.clear();
2865 mat1_sizes = {7, 5, 6};
2866 mat2_sizes = {7, 6, 3};
2867 extra_args["mat1_size"] = at::IValue(at::IntArrayRef(mat1_sizes));
2868 extra_args["mat2_size"] = at::IValue(at::IntArrayRef(mat2_sizes));
2869 flops =
2870 torch::profiler::impl::computeFlops(std::string("aten::bmm"), extra_args);
2871 ASSERT_EQ(flops, 1260);
2872
2873 // Test aten::baddbmm
2874 flops = torch::profiler::impl::computeFlops(
2875 std::string("aten::baddbmm"), extra_args);
2876 ASSERT_EQ(flops, 1260);
2877
2878 // Test mm out of range
2879 extra_args.clear();
2880 flops =
2881 torch::profiler::impl::computeFlops(std::string("aten::mm"), extra_args);
2882 ASSERT_EQ(flops, 0);
2883
2884 // Test aten::add.Tensor
2885 extra_args.clear();
2886 std::vector<int64_t> mat_sizes = {3, 4, 5, 6};
2887 extra_args["mat_size"] = at::IValue(at::IntArrayRef(mat_sizes));
2888 flops =
2889 torch::profiler::impl::computeFlops(std::string("aten::add"), extra_args);
2890 ASSERT_EQ(flops, 360);
2891
2892 // Test aten::mul.Tensor
2893 extra_args.clear();
2894 mat_sizes = {3, 4, 5, 6};
2895 extra_args["mat_size"] = at::IValue(at::IntArrayRef(mat_sizes));
2896 flops =
2897 torch::profiler::impl::computeFlops(std::string("aten::mul"), extra_args);
2898 ASSERT_EQ(flops, 360);
2899 }
2900
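// A tensor that requires grad cannot be folded into the graph as a constant,
// so tryInsertConstant is expected to return nullopt.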
2901 TEST(TestConstant, TensorGrad) {
2902 auto graph = std::make_shared<Graph>();
2903 IValue ten = torch::randn({3, 5}).requires_grad_(true);
2904 auto con = tryInsertConstant(*graph, ten);
2905 ASSERT_TRUE(con == std::nullopt);
2906 }
2907
2908 TEST(TestMutation, Basic) {
2909 auto graph = std::make_shared<Graph>();
2910 std::unordered_map<std::string, Value*> vmap;
2911 parseIR(
2912 R"IR(
2913 graph(%x.1 : Tensor):
2914 %2 : int = prim::Constant[value=1]()
2915 %9 : int = prim::Constant[value=4]()
2916 %x.3 : Tensor = aten::add(%x.1, %2, %2)
2917 %7 : Tensor = aten::add_(%x.3, %2, %2)
2918 %y.1 : Tensor = aten::add(%x.3, %9, %2)
2919 return (%y.1))IR",
2920 &*graph,
2921 vmap);
2922 RemoveTensorMutation(graph, [](Node*) { return false; });
2923 testing::FileCheck().check("aten::add_")->run(*graph);
2924 RemoveTensorMutation(graph, [](Node*) { return true; });
2925 testing::FileCheck().check_not("aten::add_")->run(*graph);
2926 }
2927
2928 TEST(TestInplaceToFunctionalActivation, Basic) {
2929 auto graph = std::make_shared<Graph>();
2930 std::unordered_map<std::string, Value*> vmap;
2931 parseIR(
2932 R"IR(
2933 graph(%x.1 : Tensor):
2934 %2 : int = prim::Constant[value=1]()
2935 %x.3 : Tensor = aten::add(%x.1, %2, %2)
2936 %y : Tensor = aten::relu_(%x.3)
2937 return (%y))IR",
2938 &*graph,
2939 vmap);
2940 InplaceToFunctionalActivation(graph);
2941 testing::FileCheck().check("aten::relu")->run(*graph);
2942 testing::FileCheck().check_not("aten::relu_")->run(*graph);
2943 }
2944
2945 TEST(TestRegisterShapeOp, Basic) {
2946 auto graph = std::make_shared<Graph>();
2947 std::unordered_map<std::string, Value*> vmap;
2948 parseIR(
2949 R"IR(
2950 graph():
2951 %2 : int = prim::Constant[value=5]()
2952 %3: int[] = prim::ListConstruct(%2, %2)
2953 return (%3))IR",
2954 &*graph,
2955 vmap);
2956
2957 auto g2 = std::make_shared<Graph>();
2958 parseIR(
2959 R"IR(
2960 graph():
2961 %2 : Tensor = prim::MakeTestTensor()
2962 return (%2))IR",
2963 &*g2,
2964 vmap);
2965
2966 const FunctionSchema& schema = g2->nodes().begin()->schema();
2967 torch::jit::RegisterShapeComputeGraphForSchema(schema, graph);
2968 PropagateShapesOnGraph(g2);
2969 testing::FileCheck().check("5, 5")->run(*g2);
2970 }
2971
2972 TEST(TestFunctionalToInplaceActivation, Basic) {
2973 auto graph = std::make_shared<Graph>();
2974 std::unordered_map<std::string, Value*> vmap;
2975 parseIR(
2976 R"IR(
2977 graph(%x.1 : Tensor):
2978 %2 : int = prim::Constant[value=1]()
2979 %x.3 : Tensor = aten::add(%x.1, %2, %2)
2980 %y : Tensor = aten::relu(%x.3)
2981 return (%y))IR",
2982 &*graph,
2983 vmap);
2984 FunctionalToInplaceActivation(graph);
2985 testing::FileCheck().check("aten::relu_")->run(*graph);
2986 testing::FileCheck().check_not("aten::relu(")->run(*graph);
2987 }
2988
2989 TEST(TestFunctionExecutor, SimpleExecutorTest) {
2990 auto graph = std::make_shared<Graph>();
2991 parseIR(
2992 R"IR(
2993 graph(%x.1 : Tensor):
2994 %2 : int = prim::Constant[value=1]()
2995 %x.3 : Tensor = aten::add(%x.1, %2, %2)
2996 %y : Tensor = aten::relu(%x.3)
2997 return (%y))IR",
2998 &*graph);
2999 {
3000 auto func = std::make_unique<GraphFunction>(
3001 "name", graph, [](GraphFunction&) {}, ExecutorExecutionMode::PROFILING);
3002 auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
3003 Stack stack = {a};
3004 func->run(stack);
3005 auto g = lastExecutedOptimizedGraph();
3006 testing::FileCheck()
3007 .check("prim::profile")
3008 ->check("aten::add")
3009 ->check("aten::relu")
3010 ->run(*g);
3011 }
3012 {
3013 auto func = std::make_unique<GraphFunction>(
3014 "name", graph, [](GraphFunction&) {}, ExecutorExecutionMode::SIMPLE);
3015 auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
3016 Stack stack = {a};
3017 func->run(stack);
3018 auto g = func->getDebugState().graph;
3019 testing::FileCheck()
3020 .check_not("prim::profile")
3021 ->check("aten::add")
3022 ->check("aten::relu")
3023 ->run(*g);
3024 }
3025 }
3026
3027 TEST(TestFunctionExecutor, RunDecompositionTest) {
3028 static auto* func = torch::jit::GetDecompositionExecutor(
3029 "aten::var(Tensor self, bool unbiased=True) -> Tensor");
3030 for (bool unbiased : {true, false}) {
3031 auto input = at::rand({4, 4});
3032 Stack stack = {input, unbiased};
3033 func->run(stack);
3034 at::Tensor out = pop(stack).toTensor();
3035 ASSERT_TRUE(at::allclose(out, input.var(unbiased)));
3036 }
3037 }
3038
3039 TEST(TestShapeGraphLinting, Basic) {
3040 auto schemas = RegisteredShapeComputeSchemas();
3041 for (const auto& schema : schemas) {
3042 // arange does not actually support complex, leave as
3043 // union[int, float] for now
3044 if (schema->name() == "aten::arange") {
3045 continue;
3046 }
3047 auto g = shapeComputeGraphForSchema(*schema);
3048 TORCH_INTERNAL_ASSERT(g);
3049 LintShapeComputeGraph(schema, *g);
3050 }
3051 }
3052
3053 // TODO: move to test_kernel when global settings are explicit
3054 // fusion parameters
3055 class Composed : public ::testing::Test {
3056 public:
3057   void SetUp() override {
3058 torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = false;
3059 }
3060 };
3061
3062 TEST_F(Composed, ComposedOp) {
3063 struct WithCPUFuser {
3064 WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
3065 overrideCanFuseOnCPU(val);
3066 }
3067
3068 ~WithCPUFuser() {
3069 overrideCanFuseOnCPU(cpuFuserEnabled);
3070 }
3071
3072 bool cpuFuserEnabled;
3073 };
3074
3075 #ifdef TORCH_ENABLE_LLVM
3076 const auto graph_string = R"IR(
3077 graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
3078 %1 : Float(5, 3, strides=[1, 5], device=cpu)):
3079 %2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1)
3080 %3 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %2)
3081 %4 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %3)
3082 return (%3, %4))IR";
3083 auto graph = std::make_shared<Graph>();
3084 parseIR(graph_string, &*graph);
3085
3086 // wrong input sizes so we hit the fallback path
3087 auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
3088 auto b = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat))
3089 .transpose(0, 1);
3090 auto ref1 = a * (a * b);
3091 auto ref2 = a * ref1;
3092 WithCPUFuser g(true);
3093 bool fusable_on_device = torch::jit::tensorexpr::getTEMustUseLLVMOnCPU();
3094 torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = false;
3095 FuseTensorExprs(
3096 graph,
3097 /*min_group_size*/ 2,
3098 /*add_composed_op*/ true,
3099 /*fuse_to_dynamic_shapes*/ true);
3100 Code code(graph, "");
3101 InterpreterState interpreter{code};
3102 std::vector<IValue> stack = {a, b};
3103 interpreter.run(stack);
3104 at::Tensor out2 = pop(stack).toTensor();
3105 at::Tensor out1 = pop(stack).toTensor();
3106 ASSERT_TRUE(at::allclose(ref1, out1));
3107 ASSERT_TRUE(at::allclose(ref2, out2));
3108
3109 auto inp_1 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
3110 auto inp_2 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
3111 stack = {inp_1, inp_2, a, b};
3112 InterpreterState interpreter2{code};
3113 interpreter2.run(stack);
3114 out2 = pop(stack).toTensor();
3115 out1 = pop(stack).toTensor();
3116 ASSERT_TRUE(at::allclose(ref1, out1));
3117 ASSERT_TRUE(at::allclose(ref2, out2));
3118   // inp_1 is at the bottom of the stack and corresponds to the second
3119   // output; inp_2 is at the top and corresponds to the first output.
3120 ASSERT_TRUE(at::allclose(inp_1, ref2));
3121 ASSERT_TRUE(at::allclose(inp_2, ref1));
3122 torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = fusable_on_device;
3123 #endif
3124 }
3125
3126 TEST(ConstantPropagation, CustomClassesCanBePropagated) {
3127 #ifdef USE_PYTORCH_QNNPACK
3128 const auto src = R"IR(
3129 graph():
3130 %none: NoneType = prim::Constant()
3131 %dim: int = prim::Constant[value=3]()
3132 %shape: int[] = prim::ListConstruct(%dim, %dim)
3133 %weight: Tensor = aten::ones(%shape, %none, %none, %none, %none)
3134 %scale: float = prim::Constant[value=1.]()
3135 %zero_point: int = prim::Constant[value=0]()
3136 %dtype: int = prim::Constant[value=12]()
3137 %weight_q: Tensor = aten::quantize_per_tensor(%weight, %scale, %zero_point, %dtype)
3138 %params: __torch__.torch.classes.quantized.LinearPackedParamsBase = quantized::linear_prepack(%weight_q, %none)
3139 return (%params)
3140 )IR";
3141 auto graph = std::make_shared<Graph>();
3142 std::unordered_map<std::string, Value*> vmap;
3143 parseIR(src, graph.get(), vmap);
3144
3145 ConstantPropagation(graph);
3146
3147 testing::FileCheck().check_not("quantized::linear_prepack")->run(*graph);
3148 #endif
3149 }
3150
3151 } // namespace jit
3152 } // namespace torch
3153