#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type_base.h>
#include <c10/macros/Macros.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>

#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/codegen/fuser/interface.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/ir/scope.h>
#include <torch/csrc/jit/ir/type_hashing.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/bailout_graph.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/guard_elimination.h>
#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/insert_guards.h>
#include <torch/csrc/jit/passes/liveness.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_grad_of.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
#include <torch/csrc/jit/passes/restore_mutation.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/argument_spec.h>
#include <torch/csrc/jit/runtime/autodiff.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/decomposition_registry.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/jit_trace.h>
#include <torch/csrc/jit/runtime/profiling_record.h>
#include <torch/csrc/jit/runtime/symbolic_script.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/jit.h>
#include <torch/script.h>

#include <onnx/onnx_pb.h>

#include <c10/util/Exception.h>
#include <c10/util/ThreadLocalDebugInfo.h>

#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace torch {
namespace jit {
inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
  return c10::AliasAnalysisKind::FROM_SCHEMA;
}

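// Pretty-printer for std::vector so that test code and failure messages can
// stream vectors directly, e.g. {1, 2, 3} prints as "{1, 2, 3}".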
template <typename T>
std::ostream& operator<<(std::ostream& out, const std::vector<T>& list) {
  size_t i = 0;
  out << "{";
  for (auto&& e : list) {
    if (i++ > 0)
      out << ", ";
    out << e;
  }
  out << "}";
  return out;
}

TEST(InternedStringsTest, Basic) {
  ASSERT_EQ(prim::Param, Symbol::prim("Param"));
  ASSERT_EQ(prim::Return, Symbol::prim("Return"));
  ASSERT_EQ(prim::Return.toUnqualString(), std::string("Return"));
  ASSERT_EQ(prim::Return.toQualString(), std::string("prim::Return"));
  Symbol newsym = Symbol::aten("__NEW_SYMBOL");
  size_t symstart = newsym;
  ASSERT_EQ(newsym.toQualString(), std::string("aten::__NEW_SYMBOL"));
  // TODO: This test is a bit too close to the implementation details.
  ASSERT_EQ(Symbol::aten("What"), symstart + 1);
  ASSERT_EQ(Symbol::aten("What2"), symstart + 2);
  ASSERT_EQ(Symbol::aten("What"), symstart + 1);
  ASSERT_EQ(Symbol::aten("What2"), symstart + 2);
  ASSERT_EQ(Symbol(symstart + 2).toUnqualString(), std::string("What2"));
}

TEST(FromQualStringTest, Basic) {
  ASSERT_EQ(Symbol::fromQualString("prim::Param"), Symbol::prim("Param"));
  ASSERT_EQ(Symbol::fromQualString("aten::mm"), Symbol::aten("mm"));
  ASSERT_EQ(Symbol::fromQualString("onnx::LSTM"), Symbol::onnx("LSTM"));
  ASSERT_EQ(Symbol::fromQualString("attr::value"), Symbol::attr("value"));
  ASSERT_EQ(Symbol::fromQualString("scope::"), Symbol::scope(""));
  ASSERT_EQ(Symbol::fromQualString("::").toUnqualString(), std::string(""));
  ASSERT_EQ(
      Symbol::fromQualString("::").ns().toQualString(),
      std::string("namespaces::"));
  ASSERT_EQ(
      Symbol::fromQualString("new_ns::param").toUnqualString(),
      std::string("param"));
  ASSERT_EQ(
      Symbol::fromQualString("new_ns::param").ns().toUnqualString(),
      std::string("new_ns"));
  ASSERT_EQ(
      Symbol::fromQualString("new_ns::param").ns(),
      Symbol::fromQualString("namespaces::new_ns"));

  auto bad_inputs = {"scope", ":", ""};
  for (auto input : bad_inputs) {
    try {
      Symbol::fromQualString(input);
      ASSERT_TRUE(0);
    } catch (const std::exception& c) {
    }
  }
}

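// Runs aten::_slow_conv2d forward/backward eagerly, then rebuilds the same
// computation as a JIT graph, differentiates it, and checks that the
// interpreter produces matching outputs and gradients.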
TEST(THNNConvTest, Basic) {
  std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
  std::vector<int64_t> kernel_size = {3, 5};
  std::vector<int64_t> stride = {1, 2};
  std::vector<int64_t> padding = {2, 1};
  constexpr int out_channels = 5;

  // make inputs
  at::Tensor input = torch::randn(input_size);
  at::Tensor weight = torch::randn(
      {out_channels, input_size[1], kernel_size[0], kernel_size[1]});
  at::Tensor bias = torch::randn({out_channels});

  // run forward eagerly
  at::Tensor output = at::_slow_conv2d_forward(
      input, weight, kernel_size, bias, stride, padding);

  // make grad_outputs
  at::Tensor grad_output =
      torch::randn_like(output, at::MemoryFormat::Preserve);

  // run backward eagerly
  auto [grad_input, grad_weight, grad_bias] = at::_slow_conv2d_backward(
      grad_output,
      input,
      weight,
      kernel_size,
      stride,
      padding,
      {true, true, true});

  // make JIT graph
  auto graph = std::make_shared<Graph>();
  auto ksz_val = graph->insertConstant(kernel_size);
  auto kst_val = graph->insertConstant(stride);
  auto pad_val = graph->insertConstant(padding);

  auto inputg = graph->addInput("self");
  auto weightg = graph->addInput("weight");
  auto biasg = graph->addInput("bias");

  Value* conv = graph->insert(
      aten::_slow_conv2d_forward,
      {inputg, weightg, ksz_val, biasg, kst_val, pad_val});
  auto outputs = conv->node()->outputs();
  for (auto output : outputs) {
    graph->registerOutput(output);
  }
  LowerAllTuples(graph);
  graph->lint();

  // differentiate JIT graph
  EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
  ConstantPropagation(graph);
  auto grad_spec = differentiate(graph);
  LowerGradOf(*grad_spec.df);

  // prepare JIT inputs / gradients
  tensor_list tensors_in;
  tensors_in.push_back(input);
  tensors_in.push_back(weight);
  tensors_in.push_back(bias);

  tensor_list tensor_grads_in;
  tensor_grads_in.push_back(grad_output);

  // Get outputs from the interpreter
  auto [tensors_out, tensor_grads_out] =
      runGradient(grad_spec, tensors_in, tensor_grads_in);

  // prepare expected structs
  tensor_list expected_tensors_out, expected_tensor_grads_out;
  expected_tensors_out.push_back(output);
  expected_tensor_grads_out.push_back(grad_input);
  expected_tensor_grads_out.push_back(grad_weight);
  expected_tensor_grads_out.push_back(grad_bias);

  // Compare results
  assertAllClose(tensors_out, expected_tensors_out);
  assertAllClose(tensor_grads_out, expected_tensor_grads_out);
}

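// Same eager-vs-JIT gradient comparison as the conv test above, but for
// aten::native_batch_norm, which additionally updates running_mean and
// running_var in-place (hence the separate eager/JIT clones below).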
TEST(ATenNativeBatchNormTest, Basic) {
  // aten::native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor
  // running_mean, Tensor running_var, bool training, float momentum, float eps)
  // -> (Tensor, Tensor, Tensor)
  std::vector<int64_t> input_size = {4, 3, 15, 17}; // B x C x H x W
  bool training = true;
  float momentum = 0.9;
  float eps = 1e-5;

  // make inputs
  at::Tensor input = torch::randn(input_size);
  at::Tensor weight = torch::randn({input_size[1]});
  at::Tensor bias = torch::randn({input_size[1]});
  at::Tensor running_mean = torch::randn({input_size[1]});
  at::Tensor running_var = torch::randn({input_size[1]});

  // running_mean and running_var are changed in-place, so clone them and pass
  // separate copies to the eager and JIT runs
  at::Tensor running_mean_eager = running_mean.clone();
  at::Tensor running_var_eager = running_var.clone();
  at::Tensor running_mean_jit = running_mean.clone();
  at::Tensor running_var_jit = running_var.clone();

  // run forward eagerly
  auto [output, savemean, saveinvstd] = at::native_batch_norm(
      input,
      weight,
      bias,
      running_mean_eager,
      running_var_eager,
      training,
      momentum,
      eps);

  // make grad_outputs
  at::Tensor grad_output =
      torch::randn_like(output, at::MemoryFormat::Preserve);
  at::Tensor grad_savemean =
      torch::zeros_like(savemean, at::MemoryFormat::Preserve);
  at::Tensor grad_saveinvstd =
      torch::zeros_like(saveinvstd, at::MemoryFormat::Preserve);

  // run backward eagerly
  // aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor
  // weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor
  // save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor,
  // Tensor, Tensor)
  auto [grad_input, grad_weight, grad_bias] = at::native_batch_norm_backward(
      grad_output,
      input,
      weight,
      running_mean_eager,
      running_var_eager,
      savemean,
      saveinvstd,
      training,
      eps,
      {true, true, true});

  // make JIT graph
  auto graph = std::make_shared<Graph>();
  auto training_val = graph->insertConstant(IValue(training));
  auto momentum_val = graph->insertConstant(IValue(momentum));
  auto eps_val = graph->insertConstant(IValue(eps));

  auto inputg = graph->addInput("self");
  auto weightg = graph->addInput("weight");
  auto biasg = graph->addInput("bias");
  auto running_meang = graph->addInput("running_mean");
  auto running_varg = graph->addInput("running_var");

  Value* bn = graph->insert(
      aten::native_batch_norm,
      {inputg,
       weightg,
       biasg,
       running_meang,
       running_varg,
       training_val,
       momentum_val,
       eps_val});
  auto outputs = bn->node()->outputs();
  for (auto output : outputs) {
    graph->registerOutput(output);
  }
  LowerAllTuples(graph);
  graph->lint();

  // differentiate JIT graph
  EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
  ConstantPropagation(graph);
  auto grad_spec = differentiate(graph);
  LowerGradOf(*grad_spec.df);

  // prepare JIT inputs / gradients
  tensor_list tensors_in;
  tensors_in.push_back(input);
  tensors_in.push_back(weight);
  tensors_in.push_back(bias);
  tensors_in.push_back(running_mean_jit);
  tensors_in.push_back(running_var_jit);

  tensor_list tensor_grads_in;
  tensor_grads_in.push_back(grad_output);
  tensor_grads_in.push_back(grad_savemean);
  tensor_grads_in.push_back(grad_saveinvstd);

  // Get outputs from the interpreter
  auto [tensors_out, tensor_grads_out] =
      runGradient(grad_spec, tensors_in, tensor_grads_in);

  // prepare expected structs
  tensor_list expected_tensors_out, expected_tensor_grads_out;
  expected_tensors_out.push_back(output);
  expected_tensors_out.push_back(savemean);
  expected_tensors_out.push_back(saveinvstd);
  expected_tensors_out.push_back(running_mean_eager);
  expected_tensors_out.push_back(running_var_eager);
  expected_tensor_grads_out.push_back(grad_input);
  expected_tensor_grads_out.push_back(grad_weight);
  expected_tensor_grads_out.push_back(grad_bias);

  tensors_out.push_back(running_mean_jit);
  tensors_out.push_back(running_var_jit);

  // Compare results
  assertAllClose(tensors_out, expected_tensors_out);
  assertAllClose(tensor_grads_out, expected_tensor_grads_out);
}

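// CustomFuseGraph groups nodes accepted by the callback into a subgraph node
// of the given kind; here both aten::mul nodes should land in a single
// prim::FusionGroup.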
TEST(CustomFusionTest, Basic) {
#if defined(FBCODE_CAFFE2)
  return;
#endif

  auto graph_string = R"IR(
    graph(%0 : Float(2, 3, 4),
          %1 : Float(2, 3, 4)):
      %2 : Tensor = aten::mul(%0, %1)
      %3 : Tensor = aten::mul(%2, %0)
      return (%3))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  torch::jit::overrideCanFuseOnCPU(true);
  CustomFuseGraph(
      g,
      [](Node* n) { return n->kind() != prim::Param; },
      Symbol::fromQualString("prim::FusionGroup"));
  torch::jit::overrideCanFuseOnCPU(false);

  const auto& nodes = g->nodes();
  auto fusion_group =
      std::find_if(nodes.begin(), nodes.end(), [](const Node* node) {
        return node->kind() == Symbol::fromQualString("prim::FusionGroup");
      });
  AT_ASSERT(fusion_group != nodes.end());

  auto subgraph = fusion_group->g(attr::Subgraph);
  auto hits = 0;
  // two multiplications
  for (const auto& n : subgraph->nodes()) {
    (void)n;
    hits++;
  }
  AT_ASSERT(hits == 2);
}

TEST(CustomFusionTest, NestedBlocks) {
#if defined(FBCODE_CAFFE2)
  return;
#endif

  auto graph_string = R"IR(
  graph(%0 : Float(2, 3, 4),
        %1 : Float(2, 3, 4),
        %2 : Float(2, 3, 4)):
    %3 : int = prim::Constant[value=1]()
    %4 : Tensor = prim::If(%2)
      block0():
        %5 : Tensor = aten::mul(%0, %2)
        %6 : Tensor = aten::mul(%5, %1)
        -> (%6)
      block1():
        %7 : Tensor = aten::add(%0, %2, %3)
        %8 : Tensor = aten::add(%7, %1, %3)
        -> (%8)
    %9 : Tensor = aten::add(%4, %2, %3)
    return (%4))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  CustomFuseGraph(
      g,
      [](Node* n) { return n->kind() == aten::mul; },
      Symbol::fromQualString("prim::FusionGroup"));

  // Could be done in more efficient ways, but this is only a test.
  std::function<bool(const Block*, Symbol)> dfs = [&](const Block* b,
                                                      Symbol s) {
    for (auto node : b->nodes()) {
      if (node->kind() == s)
        return true;
      for (auto nested_b : node->blocks())
        if (dfs(nested_b, s))
          return true;
    }
    return false;
  };

  AT_ASSERT(dfs(g->block(), Symbol::fromQualString("prim::FusionGroup")));
}

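// TorchScript sources exercising if/else and while control flow; compiled
// and run directly through the interpreter in the test below.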
static const auto cf_examples = R"JIT(
  def if_test(a, b):
      # FIXME: use 0 instead of a.
      # c = 0
      c = a
      if bool(a < b):
        c = b
      else:
        c = a
      return c
  def if_one(a, b):
    c = b
    if bool(a < b):
      c = a
    return c
  def while_test(a, i):
    while bool(i < 3):
      a *= a
      i += 1
    return a
)JIT";

TEST(ControlFlowTest, Basic) {
  auto cu = compile(cf_examples);

  auto run = [&](const std::string& name, std::vector<IValue> stack) {
    auto graph = toGraphFunction(cu->get_function(name)).graph();
    Code code(graph, "");
    InterpreterState interp(code);
    interp.run(stack);
    return stack;
  };

  auto L = [](int64_t l) { return IValue(scalar_to_tensor(at::Scalar(l))); };
  auto V = [](IValue t) { return std::move(t).toTensor().item<int64_t>(); };
  auto run_binary = [&](const std::string& name, int64_t a, int64_t b) {
    return V(run(name, {L(a), L(b)})[0]);
  };
  ASSERT_EQ(2, run_binary("if_test", 1, 2));
  ASSERT_EQ(3, run_binary("if_test", 3, 2));
  ASSERT_EQ(2, run_binary("if_one", 2, 3));
  ASSERT_EQ(2, run_binary("if_one", 3, 2));
  ASSERT_EQ(256, run_binary("while_test", 2, 0));
}

#if !(C10_ASAN_ENABLED || C10_UBSAN_ENABLED)
// This test fails vptr UBSAN checks

TEST(ProtoTest, Basic) {
  ::ONNX_NAMESPACE::ModelProto proto;
  proto.set_producer_name("foo");
}
#endif

// test a few features that are not directly used in schemas yet
TEST(SchemaParserTest, NestedArrays) {
  // nested arrays
  auto s = parseSchema("at::what(int[][4] foo) -> ()");
  ASSERT_TRUE(s.arguments().at(0).N() == 4);
  ASSERT_TRUE(IntType::get()->isSubtypeOf(*s.arguments()
                                               .at(0)
                                               .type()
                                               ->expectRef<ListType>()
                                               .getElementType()
                                               ->expectRef<ListType>()
                                               .getElementType()));
  auto s2 = parseSchema("at::what(int[][] foo) -> ()");
  ASSERT_TRUE(IntType::get()->isSubtypeOf(*s2.arguments()
                                               .at(0)
                                               .type()
                                               ->expectRef<ListType>()
                                               .getElementType()
                                               ->expectRef<ListType>()
                                               .getElementType()));
}

TEST(SchemaParserTest, OutVariant) {
  auto schema_with_out = parseSchema(
      "at::foo(Tensor self, *, Tensor(a!) f, Tensor(b!) l) -> (Tensor(a!) f, Tensor(b!) l)");
  ASSERT_TRUE(schema_with_out.arguments().at(1).is_out());
  ASSERT_TRUE(schema_with_out.arguments().at(2).is_out());

  auto schema_without_out =
      parseSchema("at::foo(Tensor self, *, int scalar) -> (int)");

  for (const auto& arg : schema_without_out.arguments()) {
    ASSERT_TRUE(!arg.is_out());
  }

  auto schema_with_is_write = parseSchema(
      "aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))");

  for (const auto& arg : schema_with_is_write.arguments()) {
    ASSERT_TRUE(!arg.is_out());
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(SchemaParserTest, NamedReturns) {
  // named returns
  parseSchema("at::what(Tensor! i_will_be_written_to) -> ()");
  auto s3 =
      parseSchema("at::what() -> (Tensor the_return, Tensor the_return2)");
  ASSERT_TRUE(s3.returns().at(0).name() == "the_return");
  ASSERT_TRUE(s3.returns().at(1).name() == "the_return2");
}

TEST(SchemaParserTest, Futures) {
  // futures
  auto s4 = parseSchema("at::what(Future(int) foo) -> ()");
  ASSERT_TRUE(IntType::get()->isSubtypeOf(
      *s4.arguments().at(0).type()->expectRef<FutureType>().getElementType()));
}

TEST(SchemaParserTest, AnnotatedAliasSets) {
  // test tensor with annotated alias sets
  parseSchema("at::what(Tensor(a) foo) -> (Tensor(a))");
}

TEST(SchemaParserTest, TensorListAnnotatedAliasSets) {
  const auto s = parseSchema(
      "at::foo(Tensor(a!) self, Tensor(b!)[] out)"
      " -> ()");
  const AliasInfo* selfAliasInfo = s.arguments().at(0).alias_info();
  const AliasInfo* outAliasInfo = s.arguments().at(1).alias_info();
  ASSERT_TRUE(
      selfAliasInfo->beforeSets() ==
      std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
  ASSERT_TRUE(selfAliasInfo->isWrite());

  ASSERT_TRUE(outAliasInfo->isWrite());
  ASSERT_TRUE(outAliasInfo->beforeSets().empty());
  ASSERT_EQ(outAliasInfo->containedTypes().size(), 1);

  auto containedType = outAliasInfo->containedTypes()[0];

  ASSERT_TRUE(containedType.isWrite());
  ASSERT_TRUE(
      containedType.beforeSets() ==
      std::unordered_set<Symbol>{Symbol::fromQualString("alias::b")});
}

TEST(SchemaParserTest, AnnotatedAliasWithoutBeforeSet) {
  EXPECT_THAT(
      []() { parseSchema("at::foo(Tensor(!) self) -> Tensor"); },
      ::testing::Throws<std::runtime_error>(::testing::Property(
          &std::runtime_error::what,
          ::testing::HasSubstr("expected ident but found '!' here"))));
}

TEST(SchemaParserTest, BeforeAfterSets) {
  const auto s = parseSchema(
      "at::what(Tensor(b|c)[](a!) list, Tensor(c) element)"
      " -> (Tensor(b|c)[](a!))");

  // The list itself is annotated with `a`
  const AliasInfo* aliasInfo = s.arguments().at(0).alias_info();
  ASSERT_NE(aliasInfo, nullptr);
  ASSERT_TRUE(
      aliasInfo->beforeSets() ==
      std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
  ASSERT_TRUE(aliasInfo->isWrite());

  // Check the contained types
  ASSERT_TRUE(!aliasInfo->containedTypes().empty());
  const auto& containedAliasInfo = aliasInfo->containedTypes()[0];
  const auto expected = std::unordered_set<Symbol>{
      Symbol::fromQualString("alias::b"),
      Symbol::fromQualString("alias::c"),
  };
  ASSERT_TRUE(containedAliasInfo.beforeSets() == expected);
  ASSERT_TRUE(containedAliasInfo.afterSets() == expected);
  ASSERT_FALSE(containedAliasInfo.isWrite());
}

TEST(SchemaParserTest, BeforeAfterSets2) {
  const auto s = parseSchema(
      "at::what(Tensor(b -> b|c)[](a!) list, Tensor(c) element)"
      " -> (Tensor(b|c)[](a!))");

  // The list itself is annotated with `a`
  const AliasInfo* aliasInfo = s.arguments().at(0).alias_info();
  ASSERT_NE(aliasInfo, nullptr);
  ASSERT_EQ(
      aliasInfo->beforeSets(),
      std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
  ASSERT_EQ(
      aliasInfo->afterSets(),
      std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
  ASSERT_TRUE(aliasInfo->isWrite());
  ASSERT_EQ(aliasInfo->containedTypes().size(), 1);

  // Check the contained types
  ASSERT_TRUE(!aliasInfo->containedTypes().empty());
  const auto& containedAliasInfo = aliasInfo->containedTypes()[0];
  const auto expectedBefore = std::unordered_set<Symbol>{
      Symbol::fromQualString("alias::b"),
  };
  const auto expectedAfter = std::unordered_set<Symbol>{
      Symbol::fromQualString("alias::b"), Symbol::fromQualString("alias::c")};
  ASSERT_TRUE(containedAliasInfo.beforeSets() == expectedBefore);
  ASSERT_TRUE(containedAliasInfo.afterSets() == expectedAfter);
  ASSERT_FALSE(containedAliasInfo.isWrite());
}

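// Node::isBefore/isAfter queries are backed by a topological index on the
// graph; these tests exercise ordering across insertions, nested blocks,
// deletions, and the bulk-reindexing path.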
TEST(TopologicalIndexTest, Basic) {
  Graph graph;
  auto node1 = graph.create(prim::AutogradZero);
  auto node2 = graph.create(prim::AutogradZero);
  auto node3 = graph.create(prim::AutogradZero);
  auto node4 = graph.create(prim::AutogradZero);

  graph.appendNode(node4);
  graph.prependNode(node1);
  node2->insertAfter(node1);
  node3->insertBefore(node4);

  // nodes should be in numerical order
  ASSERT_TRUE(node1->isBefore(node2));
  ASSERT_TRUE(node1->isBefore(node3));
  ASSERT_TRUE(node1->isBefore(node4));
  ASSERT_TRUE(node2->isAfter(node1));
  ASSERT_TRUE(node2->isBefore(node3));
  ASSERT_TRUE(node2->isBefore(node4));
  ASSERT_FALSE(node3->isBefore(node1));
  ASSERT_FALSE(node3->isBefore(node2));
  ASSERT_FALSE(node3->isAfter(node4));

  // Build up a block structure
  //  node3
  //   /\        ...
  //  A  B     block1
  //      \      ...
  //      C    block2
  auto block1 = node3->addBlock();
  auto A = graph.create(prim::AutogradZero);
  block1->appendNode(A);
  auto B = graph.create(prim::AutogradZero);
  block1->appendNode(B);
  auto block2 = B->addBlock();
  auto C = graph.create(prim::AutogradZero);
  block2->appendNode(C);

  // Check isAfter on different block levels
  ASSERT_TRUE(node1->isBefore(A));
  ASSERT_TRUE(A->isBefore(B));
  ASSERT_TRUE(A->isBefore(C));

  // make sure things don't blow up on deletions
  node2->destroy();
  auto node2p = graph.create(prim::AutogradZero);
  node2p->insertAfter(node1);
  ASSERT_TRUE(node1->isBefore(node2p));
  ASSERT_TRUE(node2p->isBefore(node3));
}

TEST(TopologicalIndexTest, Reindex) {
  // Induce reindexing to test that path
  Graph graph;
  std::map<size_t, Node*> nodes;

  auto anchor = graph.create(prim::AutogradZero);
  graph.appendNode(anchor);
  // Inserting to the same place a lot will trigger reindexing
  for (auto i = 0; i < 100; ++i) {
    auto n = graph.create(prim::AutogradZero);
    n->insertAfter(anchor);
    nodes[i] = n;
  }

  // Nodes should be in reverse order
  for (auto i = 0; i < 100; ++i) {
    for (auto j = i + 1; j < 100; ++j) {
      ASSERT_TRUE(nodes[i]->isAfter(nodes[j]));
    }
  }
}

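// Profiler-observer tests: the helpers below run a small computation (eagerly
// or through TorchScript) under a RECORD_FUNCTION annotation, and the
// registered callbacks verify what was observed.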
at::Tensor invokeTestRecordFunction(at::Tensor& t) {
  RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));

  auto t2 = t.pow(2);
  return t2;
}

static const auto invokeTestRecordFunction_JIT = R"JIT(
  def foo(self, t):
    t2 = t.pow(2)
    return t2

  def forward(self, t):
    return self.foo(t)
)JIT";

at::Tensor invokeTestRecordFunctionJIT(at::Tensor& t) {
  RECORD_FUNCTION("test", std::vector<c10::IValue>({t}));

  auto module = std::make_shared<script::Module>(
      "RecordFunctionTestModule", std::make_shared<script::CompilationUnit>());
  module->define(invokeTestRecordFunction_JIT);
  return module->forward({t}).toTensor();
}

using TracedTestValues =
    std::vector<std::tuple<std::string, std::vector<std::vector<int64_t>>>>;

void checkTracedInputs(const TracedTestValues& inputs) {
  bool found_test = false;
  bool found_pow = false;
  bool found_mul = false;
  for (const auto& input : inputs) {
    const auto& fn = std::get<0>(input);
    const auto& sizes = std::get<1>(input);

    if (fn == "test") {
      found_test = true;
      TORCH_CHECK(sizes.size() == 1);
      TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
    } else if (fn == "aten::pow") {
      found_pow = true;
      TORCH_CHECK(sizes.size() == 2);
      TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
      TORCH_CHECK(sizes[1].empty());
    } else if (fn == "aten::mul") {
      found_mul = true;
      TORCH_CHECK(sizes.size() > 1);
      TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
    }
  }
  TORCH_CHECK(found_test);
  TORCH_CHECK(found_pow);
  TORCH_CHECK(found_mul);
}

void checkTracedOutputs(const TracedTestValues& outputs) {
  bool found_test = false;
  bool found_pow = false;
  bool found_mul = false;
  for (const auto& output : outputs) {
    const auto& fn = std::get<0>(output);
    const auto& sizes = std::get<1>(output);

    if (fn == "test") {
      found_test = true;
      TORCH_CHECK(sizes.empty());
    } else if (fn == "aten::pow") {
      found_pow = true;
      TORCH_CHECK(sizes.size() == 1);
      TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
    } else if (fn == "aten::mul") {
      found_mul = true;
      TORCH_CHECK(sizes.size() == 1);
      TORCH_CHECK(sizes[0] == std::vector<int64_t>({1, 2, 3}));
    }
  }
  TORCH_CHECK(found_test);
  TORCH_CHECK(found_pow);
  TORCH_CHECK(found_mul);
}

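// Registers a callback limited to a single RecordScope and counts how often
// it fires; bad_scope is set if any callback runs outside its declared scope.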
static bool bad_scope = false;
template <RecordScope scope, size_t* cnt>
std::unique_ptr<at::ObserverContext> checkScopeCallback(
    const at::RecordFunction& fn) {
  if (fn.scope() == scope) {
    ++(*cnt);
  } else {
    bad_scope = true;
  }
  return nullptr;
}

template <RecordScope scope, size_t* cnt>
void pushScopedCallback() {
  at::addGlobalCallback(
      at::RecordFunctionCallback(checkScopeCallback<scope, cnt>)
          .scopes({scope}));
}

// These cannot be function-local because that would prohibit them
// from being used as template arguments prior to C++17.
static size_t fun_cnt;
static size_t ts_fun_cnt;
static size_t user_scope_cnt;

void checkScopeCallbacks() {
  static bool found_function_scope;
  static bool found_method_scope;
  static bool found_user_scope;
  found_function_scope = false;
  found_method_scope = false;
  found_user_scope = false;
  at::addGlobalCallback(at::RecordFunctionCallback(
      [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
        if (fn.scope() == at::RecordScope::FUNCTION &&
            std::string(fn.name()) == "test_function") {
          found_function_scope = true;
        }
        if (fn.scope() == at::RecordScope::TORCHSCRIPT_FUNCTION &&
            std::string(fn.name()) == "test_method") {
          found_method_scope = true;
        }
        if (fn.scope() == at::RecordScope::USER_SCOPE &&
            std::string(fn.name()) == "test_user_scope") {
          found_user_scope = true;
        }
        return nullptr;
      }));

  bad_scope = false;
  fun_cnt = 0;
  pushScopedCallback<at::RecordScope::FUNCTION, &fun_cnt>();
  ts_fun_cnt = 0;
  pushScopedCallback<at::RecordScope::TORCHSCRIPT_FUNCTION, &ts_fun_cnt>();
  user_scope_cnt = 0;
  pushScopedCallback<at::RecordScope::USER_SCOPE, &user_scope_cnt>();

  TORCH_CHECK(at::hasCallbacks());

  {
    RECORD_TORCHSCRIPT_FUNCTION("test_method", {});
    { RECORD_FUNCTION("test_function", {}); }
    { RECORD_USER_SCOPE("test_user_scope"); }
  }

  TORCH_CHECK(!bad_scope);
  TORCH_CHECK(fun_cnt == 1);
  TORCH_CHECK(ts_fun_cnt == 1);
  TORCH_CHECK(user_scope_cnt == 1);

  TORCH_CHECK(found_function_scope);
  TORCH_CHECK(found_method_scope);
  TORCH_CHECK(found_user_scope);
}

static TracedTestValues traced_inputs;
static TracedTestValues traced_outputs;
static std::unordered_set<std::string> ts_input_names;
static std::unordered_set<std::string> ts_output_names;

std::unique_ptr<at::ObserverContext> tracedInputsCallback(
    const RecordFunction& fn) {
  if (fn.scope() == RecordScope::FUNCTION) {
    auto inputs = fn.inputs();
    std::vector<std::vector<int64_t>> sizes;
    for (const auto& input : inputs) {
      if (input.isTensor()) {
        sizes.push_back(input.toTensor().sizes().vec());
      } else if (input.isScalar()) {
        // NOLINTNEXTLINE(modernize-use-emplace)
        sizes.push_back(std::vector<int64_t>());
      }
    }
    traced_inputs.push_back(std::make_tuple(fn.name(), sizes));
  } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
    ts_input_names.insert(fn.name());
  }
  return nullptr;
}

void tracedOutputsCallback(const RecordFunction& fn, ObserverContext* ctx_ptr) {
  if (fn.scope() == RecordScope::FUNCTION) {
    auto outputs = fn.outputs();
    std::vector<std::vector<int64_t>> sizes;
    for (const auto& output : outputs) {
      if (output.isTensor()) {
        sizes.push_back(output.toTensor().sizes().vec());
      } else if (output.isScalar()) {
        sizes.emplace_back();
      }
    }
    traced_outputs.push_back(std::make_tuple(fn.name(), sizes));
  } else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
    ts_output_names.insert(fn.name());
  }
}

TEST(RecordFunctionTest, TracedTestInputsOutputs) {
  // disabling the inlining of method calls
  GraphOptimizerEnabledGuard opt_guard(false);

  // [(fn, [[sizes], [sizes], ...]), ...]
  addGlobalCallback(
      RecordFunctionCallback(tracedInputsCallback, tracedOutputsCallback)
          .needsInputs(true)
          .needsOutputs(true));

  TracedTestValues eager_inputs, eager_outputs, jit_inputs, jit_outputs;
  {
    auto t = torch::randn({1, 2, 3}, at::kCPU);
    t.set_requires_grad(true);
    auto t2 = invokeTestRecordFunction(t);
    t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
    eager_inputs = traced_inputs;
    eager_outputs = traced_outputs;
    traced_inputs.clear();
    traced_outputs.clear();

    TORCH_CHECK(ts_input_names.empty());
    TORCH_CHECK(ts_output_names.empty());

    t = torch::randn({1, 2, 3}, at::kCPU);
    t.set_requires_grad(true);
    t2 = invokeTestRecordFunctionJIT(t);
    t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
    jit_inputs = traced_inputs;
    jit_outputs = traced_outputs;
    traced_inputs.clear();
    traced_outputs.clear();
  }

  TORCH_CHECK(ts_input_names.find("forward") != ts_input_names.end());
  TORCH_CHECK(ts_input_names.find("foo") != ts_input_names.end());
  TORCH_CHECK(ts_output_names.find("forward") != ts_output_names.end());
  TORCH_CHECK(ts_output_names.find("foo") != ts_output_names.end());

  checkTracedInputs(eager_inputs);
  checkTracedOutputs(eager_outputs);
  checkTracedInputs(jit_inputs);
  checkTracedOutputs(jit_outputs);
  at::clearCallbacks();
}

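// Counters for the sampling test: a callback registered with samplingProb(p)
// should fire on roughly p * N of N invocations, while an unsampled callback
// fires on every one.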
static int sampled_cb_ctr = 0;
std::unique_ptr<ObserverContext> sampledCallback(const RecordFunction& fn) {
  if (std::string(fn.name()) == "test") {
    ++sampled_cb_ctr;
  }
  return nullptr;
}

static int non_sampled_cb_ctr = 0;
std::unique_ptr<ObserverContext> nonSampledCallback(const RecordFunction& fn) {
  if (std::string(fn.name()) == "test") {
    ++non_sampled_cb_ctr;
  }
  return nullptr;
}

TEST(RecordFunctionTest, SampledCallbacks) {
  // disabling the inlining of method calls
  GraphOptimizerEnabledGuard opt_guard(false);

  // test sampled callbacks
  sampled_cb_ctr = 0;
  auto setup_sampled_callback = [](double sampling_prob) {
    return addGlobalCallback(
        RecordFunctionCallback(sampledCallback).samplingProb(sampling_prob));
  };

  addGlobalCallback(RecordFunctionCallback(nonSampledCallback));

  auto handle = setup_sampled_callback(0.5);

  auto run_test_function = []() {
    auto t = torch::randn({1, 2, 3}, at::kCPU);
    for (auto k = 0; k < 1000; k++) {
      invokeTestRecordFunction(t);
    }
  };

  run_test_function();
  TORCH_CHECK(non_sampled_cb_ctr == 1000);
  TORCH_CHECK(sampled_cb_ctr > 0 && sampled_cb_ctr < 1000);

  sampled_cb_ctr = 0;
  removeCallback(handle);
  handle = setup_sampled_callback(0.0);
  run_test_function();

  TORCH_CHECK(non_sampled_cb_ctr == 2000);
  TORCH_CHECK(sampled_cb_ctr == 0);

  sampled_cb_ctr = 0;
  removeCallback(handle);
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  handle = setup_sampled_callback(1.0);
  run_test_function();

  TORCH_CHECK(non_sampled_cb_ctr == 3000);
  TORCH_CHECK(sampled_cb_ctr == 1000);
  clearCallbacks();

  // test the scope of the callbacks
  checkScopeCallbacks();
  clearCallbacks();
}

TEST(RecordFunctionTest, RecordFunctionGuard) {
  // disabling the inlining of method calls
  GraphOptimizerEnabledGuard opt_guard(false);

  static std::vector<std::string> fn_names;
  static std::mutex guard_mtx;

  // check record function guard
  addGlobalCallback(RecordFunctionCallback(
      [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
        std::lock_guard<std::mutex> lock(guard_mtx);
        // NOLINTNEXTLINE(modernize-use-emplace)
        fn_names.push_back(fn.name());
        return nullptr;
      }));
  {
    RecordFunctionGuard g1(false);
    {
      RECORD_USER_SCOPE("A");
      {
        RecordFunctionGuard g2(true);
        RECORD_USER_SCOPE("B");
        {
          DisableRecordFunctionGuard g3;
          RECORD_USER_SCOPE("C");
        }
      }
      { RECORD_USER_SCOPE("D"); }
    }
  }
  TORCH_CHECK(fn_names.size() == 1);
  TORCH_CHECK(fn_names[0] == "B");
  clearCallbacks();
}

static std::vector<size_t> ids;

template <size_t id>
auto add_remove_test_add_cb() {
  return addGlobalCallback(RecordFunctionCallback(
      [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
        ids.push_back(id);
        return nullptr;
      }));
}

TEST(RecordFunctionTest, Callbacks) {
  // disabling the inlining of method calls
  GraphOptimizerEnabledGuard opt_guard(false);

  auto h1 = add_remove_test_add_cb<1>();
  add_remove_test_add_cb<2>();
  auto h3 = add_remove_test_add_cb<3>();

  { RECORD_USER_SCOPE("test"); }

  TORCH_CHECK(ids.size() == 3);
  TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
  TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
  TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());

  ids.clear();
  removeCallback(h1);

  { RECORD_USER_SCOPE("test"); }

  TORCH_CHECK(ids.size() == 2);
  TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
  TORCH_CHECK(std::find(ids.begin(), ids.end(), 3) != ids.end());

  ids.clear();
  removeCallback(h3);

  { RECORD_USER_SCOPE("test"); }

  TORCH_CHECK(ids.size() == 1);
  TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());

  clearCallbacks();

  // thread local / global callbacks

  ids.clear();
  add_remove_test_add_cb<1>();

  { RECORD_USER_SCOPE("test"); }

  TORCH_CHECK(ids.size() == 1);
  TORCH_CHECK(ids[0] == 1);
  ids.clear();

  auto th = std::thread([]() {
    addThreadLocalCallback(RecordFunctionCallback(
        [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
          ids.push_back(2);
          return nullptr;
        }));

    { RECORD_USER_SCOPE("test_thread"); }
  });
  th.join();
  TORCH_CHECK(ids.size() == 2);
  TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
  TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
  ids.clear();

  { RECORD_USER_SCOPE("test"); }

  TORCH_CHECK(ids.size() == 1);
  TORCH_CHECK(ids[0] == 1);
  ids.clear();

  clearCallbacks();

  // START: thread local / global context check callbacks
  struct TestContext : public ObserverContext {
    int a{0};
    std::string b;
  };
  ids.clear();
  { // START: global test
    addGlobalCallback(RecordFunctionCallback(
        [](const RecordFunction&
           /* unused */) -> std::unique_ptr<at::ObserverContext> {
          auto ctx = std::make_unique<TestContext>();
          ctx->a = 123;
          ctx->b = "test_str";
          ids.push_back(1);
          return ctx;
        },
        [](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
          auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
          TORCH_CHECK(ctx != nullptr);
          TORCH_CHECK(ctx->a == 123);
          TORCH_CHECK(ctx->b == "test_str");
        }));

    { RECORD_USER_SCOPE("test"); }

    TORCH_CHECK(ids.size() == 1);
    TORCH_CHECK(ids[0] == 1);
    ids.clear();
  } // END: global test
  { // START: thread local test
    auto ctx_th = std::thread([]() {
      const std::string test_str = "test thread str";
      addThreadLocalCallback(RecordFunctionCallback(
          [](const RecordFunction&
             /* unused */) -> std::unique_ptr<at::ObserverContext> {
            auto ctx = std::make_unique<TestContext>();
            ctx->a = 234;
            ctx->b = "test_thread_str";
            ids.push_back(2);
            return ctx;
          },
          [](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
            // check the dynamic_cast result rather than the raw pointer,
            // matching the global test above
            auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
            TORCH_CHECK(ctx != nullptr);
            TORCH_CHECK(ctx->a == 234);
            TORCH_CHECK(ctx->b == "test_thread_str");
          }));

      // Will call both global and thread local callbacks.
      { RECORD_USER_SCOPE("test_thread"); }
    });
    ctx_th.join();
    TORCH_CHECK(ids.size() == 2);
    TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
    TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
    ids.clear();
  } // END: thread local test

  clearCallbacks();
}

TEST(RecordFunctionTest, ShouldRun) {
  // disabling the inlining of method calls
  GraphOptimizerEnabledGuard opt_guard(false);

  static bool ran = false;
  auto handle = addGlobalCallback(RecordFunctionCallback(
      [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
        ran = true;
        return nullptr;
      }));

  { RECORD_USER_SCOPE("test"); }

  EXPECT_TRUE(ran) << "first run didn't happen";
  ran = false;

  disableCallback(handle);

  { RECORD_USER_SCOPE("test"); }

  EXPECT_FALSE(ran) << "second run happened but shouldn't have";
  ran = false;

  reenableCallback(handle);

  { RECORD_USER_SCOPE("test"); }

  EXPECT_TRUE(ran) << "run after re-enable didn't happen";
  ran = false;

  clearCallbacks();
}

TEST(RecordFunctionTest, Basic) {
  // disabling the inlining of method calls
  GraphOptimizerEnabledGuard opt_guard(false);

  static std::string recorded_op;
  static bool has_ids = false;

  // test propagation of TLS callbacks
  std::thread t([]() {
    RecordFunctionGuard enable_rec_fn;
    auto handle = addThreadLocalCallback(RecordFunctionCallback(
        [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
          recorded_op = fn.name();
          return nullptr;
        }));
    ThreadLocalState state;
    std::thread t_child([state]() {
      ThreadLocalStateGuard g_tls(state);
      RECORD_USER_SCOPE("test_in_thread");
    });
    t_child.join();
    EXPECT_EQ(recorded_op, "test_in_thread");
    removeCallback(handle);
  });
  t.join();
  clearCallbacks();

  // test set ids
  addGlobalCallback(
      RecordFunctionCallback(
          [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
            has_ids = fn.handle() > 0;
            return nullptr;
          })
          .needsIds(true));
  { RECORD_USER_SCOPE("test"); }
  TORCH_CHECK(has_ids);
  clearCallbacks();
  has_ids = false;
  addGlobalCallback(RecordFunctionCallback(
      [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
        has_ids = fn.handle() > 0;
        return nullptr;
      }));
  { RECORD_USER_SCOPE("test"); }
  TORCH_CHECK(!has_ids);
  clearCallbacks();
}

TEST(RecordFunctionTest, OperatorNameOverload) {
  static std::set<std::string> operator_names;
  at::addGlobalCallback(at::RecordFunctionCallback(
                            [](const at::RecordFunction& fn)
                                -> std::unique_ptr<at::ObserverContext> {
                              std::optional<c10::OperatorName> op_name =
                                  fn.operator_name();
                              if (op_name.has_value()) {
                                operator_names.insert(c10::toString(*op_name));
                              } else {
                                operator_names.insert("No Operator Name");
                              }
                              return nullptr;
                            })
                            .scopes({at::RecordScope::FUNCTION}));
  auto t = torch::randn({1, 2, 3}, at::kCPU);
  t.set_requires_grad(false);
  auto t2 = t.pow(2);

  at::clearCallbacks();
  EXPECT_TRUE(operator_names.count("No Operator Name") == 0)
      << "Expected that all traced operators had an associated OperatorName object";
  EXPECT_TRUE(operator_names.count("aten::randn") == 1)
      << "Expected aten::randn to have been called and recorded, but it was not";
  EXPECT_TRUE(operator_names.count("aten::pow.Tensor_Scalar") == 1)
      << "Expected aten::pow.Tensor_Scalar to have been called and recorded, but it was not";
}

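// ThreadLocalDebugInfo attaches per-thread metadata (here, a model id) that
// is expected to propagate across at::launch tasks, autograd's backward
// threads, and nested DebugInfoGuard scopes.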
1318 class TestThreadLocalDebugInfo : public c10::DebugInfoBase {
1319  public:
getModelId() const1320   int getModelId() const {
1321     return model_id_;
1322   }
1323 
setModelId(int model_id)1324   void setModelId(int model_id) {
1325     model_id_ = model_id;
1326   }
1327 
1328   // NOLINTNEXTLINE(modernize-use-equals-default)
~TestThreadLocalDebugInfo()1329   virtual ~TestThreadLocalDebugInfo() override {}
1330 
1331  private:
1332   int model_id_ = 0;
1333 };
1334 
checkDebugInfo(c10::DebugInfoKind kind,int model_id)1335 void checkDebugInfo(c10::DebugInfoKind kind, int model_id) {
1336   auto* debug_info = c10::ThreadLocalDebugInfo::get(kind);
1337   TORCH_CHECK(debug_info != nullptr);
1338   auto* test_debug_info = dynamic_cast<TestThreadLocalDebugInfo*>(debug_info);
1339   TORCH_CHECK(test_debug_info != nullptr);
1340   TORCH_CHECK(test_debug_info->getModelId() == model_id);
1341 }
1342 
TEST(ThreadLocalDebugInfoTest,Basic)1343 TEST(ThreadLocalDebugInfoTest, Basic) {
1344   static std::atomic<bool> done{false};
1345 
1346   TORCH_CHECK(
1347       c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
1348   auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
1349   debug_info->setModelId(42);
1350   {
1351     c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
1352     checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1353   }
1354 
1355   // check that thread local debug info is propagated through fork calls
1356   TORCH_CHECK(
1357       c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
1358   {
1359     c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
1360     at::launch([]() {
1361       checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1362       done = true;
1363     });
1364   }
1365   while (!done) {
1366   }
1367 
1368   // check that thread local debug info is propagated through backward pass
1369   TORCH_CHECK(
1370       c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
1371   done = false;
1372   auto handle = addGlobalCallback(RecordFunctionCallback(
1373       [](const RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
1374         checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1375         done = true;
1376         return nullptr;
1377       }));
1378   {
1379     c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
1380     auto t = torch::randn({1, 2, 3}, at::kCPU);
1381     t.set_requires_grad(true);
1382     auto t2 = t.pow(2);
1383     t2.backward(torch::ones_like(t2, at::MemoryFormat::Preserve));
1384   }
1385   removeCallback(handle);
1386   TORCH_CHECK(done);
1387 
1388   // check nested debug info
1389   TORCH_CHECK(
1390       c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
1391   {
1392     c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
1393     {
1394       checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1395       {
1396         auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
1397         debug_info->setModelId(314);
1398         c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO_2, debug_info);
1399         {
1400           checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1401           checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
1402           done = false;
1403           at::launch([]() {
1404             checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
1405             checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
1406             done = true;
1407           });
1408           while (!done) {
1409           }
1410         }
1411       }
1412     }
1413   }
1414 }
1415 
TEST(TestSymIntArrayRef,BasicConversion)1416 TEST(TestSymIntArrayRef, BasicConversion) {
1417   const size_t X = 2, Y = 4, Z = 5;
1418   std::vector<int64_t> tgt_size_v{2, 4, 5};
1419   std::vector<c10::SymInt> tgt_size({SymInt(X), SymInt(Y), SymInt(Z)});
1420   auto a = at::randn({1, 4, 1}, at::kCPU);
1421   auto b = a.expand_symint(tgt_size);
1422   auto c = a.expand(tgt_size_v);
1423   ASSERT_TRUE(torch::allclose(b, c));
1424 }
1425 
TEST(TestSymInt,NarrowCopyWithSymbolicInt)1426 TEST(TestSymInt, NarrowCopyWithSymbolicInt) {
1427   static const size_t LENGTH = 5;
1428   auto a = at::randn({10}, at::kCPU);
1429   c10::SymInt si(LENGTH);
1430   auto b = a.narrow_copy_symint(0, 0, si);
1431   auto c = a.narrow(0, 0, LENGTH);
1432   ASSERT_TRUE(torch::allclose(b, c));
1433 }
1434 
TEST(TestSymInt,NarrowCopy)1435 TEST(TestSymInt, NarrowCopy) {
1436   static const size_t LENGTH = 5;
1437   auto a = at::randn({10}, at::kCPU);
1438   auto b = a.narrow_copy(0, 0, LENGTH);
1439   auto c = a.narrow(0, 0, LENGTH);
1440   ASSERT_TRUE(torch::allclose(b, c));
1441 }
1442 
TEST(TestSymInt,AddSymbolicInt)1443 TEST(TestSymInt, AddSymbolicInt) {
1444   c10::SymInt a(5);
1445   c10::SymInt b(3);
1446   ASSERT_TRUE((a + b).expect_int() == 8);
1447 }
1448 
TEST(FallbackGraphsTest,Basic)1449 TEST(FallbackGraphsTest, Basic) {
1450   auto x = at::randn({1}, at::kCPU);
1451   auto y = at::randn({1}, at::kCPU);
1452   auto stack = createStack({x.clone(), y.clone()});
1453 
1454   auto graph_string = R"IR(
1455     graph(%0 : Float(1),
1456           %1 : Float(1)):
1457       %2 : Tensor = aten::mul(%0, %1)
1458       %3 : Tensor = aten::mul(%2, %0)
1459       return (%3))IR";
1460   auto graph = std::make_shared<Graph>();
1461   torch::jit::parseIR(graph_string, graph.get());
1462 
1463   {
1464     Code code(graph, "");
1465     InterpreterState interpreter{code};
1466     interpreter.run(stack);
1467   }
1468   at::Tensor et;
1469   pop(stack, et);
1470   float ef = et.item<float>();
1471   {
1472     EnableProfilingGuard epg;
1473     GraphFunction f("fallbackGraphs", graph, nullptr);
1474     for (size_t i = 0; i < getNumProfiledRuns() + 1; i++) {
1475       stack.emplace_back(x.clone());
1476       stack.emplace_back(y.clone());
1477       if (i == getNumProfiledRuns()) {
1478         // we will be modifying a profiled graph
1479         // before ProfilingGraphExecutor
1480         // will optimize it in the next iteration
1481         auto opt_graph = lastExecutedOptimizedGraph();
1482         // this is safe to do since we are done profiling
1483         ProfilingRecord::removeProfileCounter(opt_graph->block());
1484         replaceBlockWithFallbackGraph(opt_graph->block(), opt_graph->inputs());
1485         auto it = opt_graph->block()->nodes().begin();
1486         ASSERT_EQ(it->kind(), prim::FallbackGraph);
1487         auto fallback = *it++;
1488         ASSERT_EQ(it, opt_graph->block()->nodes().end());
1489         ASSERT_TRUE(fallback->hasAttribute(attr::Subgraph));
1490         testing::FileCheck()
1491             .check("Tensor = aten::mul")
1492             ->check("Tensor = aten::mul")
1493             ->run(*fallback->g(attr::Subgraph));
1494       }
1495       f.run(stack);
1496       at::Tensor at;
1497       pop(stack, at);
1498       float af = at.item<float>();
1499       ASSERT_EQ(af, ef);
1500     }
1501 
1502     auto opt_graph = lastExecutedOptimizedGraph();
1503     testing::FileCheck()
1504         .check("(Tensor) = prim::CallFunction")
1505         ->run(*opt_graph);
1506   }
1507 }
1508 
1509 // TODO this test wasn't running and is broken.
1510 // TEST(AutogradProfilerTest, Basic) {
1511 //   constexpr int batch_size = 4;
1512 //   constexpr int input_size = 256;
1513 //   constexpr int seq_len = 32;
1514 
1515 //   int hidden_size = 2 * input_size;
1516 //   auto input = torch::randn({seq_len, batch_size, input_size}, at::kCPU);
1517 //   auto hx = torch::randn({batch_size, hidden_size}, at::kCPU);
1518 //   auto cx = torch::randn({batch_size, hidden_size}, at::kCPU);
1519 //   auto w_ih = t_def(torch::randn({4 * hidden_size, input_size}, at::kCPU));
1520 //   auto w_hh = t_def(torch::randn({4 * hidden_size, hidden_size}, at::kCPU));
1521 
1522 //   std::stringstream ss;
1523 //   {
1524 //     RecordProfile guard(ss);
1525 //     for (size_t i = 0; i < 100; ++i) {
1526 //       std::tie(hx, cx) = lstm(input[0], hx, cx, w_ih, w_hh);
1527 //     }
1528 //   }
1529 
1530 //   std::string result = ss.str();
1531 //   size_t count = 0;
1532 //   for (size_t pos = 0; (pos = result.find("tanh", pos)) != std::string::npos;
1533 //        count++, pos++) {
1534 //   }
1535 //   ASSERT_EQ(count, 200);
1536 // }
1537 
1538 TEST(NoneSchemaMatchTest, Basic) {
1539   RegisterOperators reg({
1540       Operator(
1541           "prim::test_none() -> int?",
1542           [](Stack& stack) { push(stack, IValue()); },
1543           aliasAnalysisFromSchema()),
1544       Operator(
1545           "prim::is_none(int? a) -> bool",
1546           [](Stack& stack) {
1547             IValue a = pop(stack);
1548             if (a.isNone()) {
1549               push(stack, true);
1550             } else {
1551               push(stack, false);
1552             }
1553           },
1554           aliasAnalysisFromSchema()),
1555   });
1556 
1557   // Constant propagation will run test_none and produce a None,
1558   // testing that its type is set appropriately and that schema matching
1559   // doesn't fail when running is_none.
1560 
1561   auto r = std::make_shared<Graph>();
1562   auto& g = *r;
1563   auto opt_int = g.insert(Symbol::fromQualString("prim::test_none"), {});
1564   auto out_bool = g.insert(Symbol::fromQualString("prim::is_none"), {opt_int});
1565   g.registerOutput(out_bool);
1566   ConstantPropagation(r);
1567 
1568   auto nodes = r->block()->nodes();
1569   // check that constant propagation ran without failure
1570   AT_ASSERT(std::distance(nodes.begin(), nodes.end()) == 1);
1571 }
1572 
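// fakePass only bumps a counter, so PassManagementTest can observe whether
// passes registered via RegisterPass actually ran.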
1573 static int testPassValue = 0;
1574 void fakePass(std::shared_ptr<Graph>& g) {
1575   testPassValue++;
1576   return;
1577 }
1578 
1579 RegisterPass p(fakePass);
1580 
1581 TEST(PassManagementTest, Basic) {
1582   std::shared_ptr<Graph> graph = std::make_shared<Graph>();
1583   parseIR(
1584       R"IR(
1585 graph(%a):
1586   return (%a))IR",
1587       &*graph);
1588 
1589   std::vector<IValue> stack = {IValue(torch::randn({22}, at::kCPU))};
1590   auto run = [&](std::shared_ptr<Graph>& graph, std::vector<IValue> stack) {
1591     GraphExecutor executor(graph, "");
1592     executor.run(stack);
1593     return stack;
1594   };
1595   run(graph, stack);
1596   // we will not run fusion in simple mode
1597   if (!getExecutorMode()) {
1598     AT_ASSERT(testPassValue);
1599   }
1600 }
1601 
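// Test helpers: checkShape asserts that a TensorType carries the expected
// concrete sizes; the Node* overload can look through to the node producing
// input 0 (typically a prim::profile node) when prev is true.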
1602 static void checkShape(TypePtr typ, std::vector<int64_t> expected) {
1603   auto ptp = typ->expect<TensorType>();
1604   ASSERT_EQ(ptp->sizes().concrete_sizes().value(), expected);
1605 }
1606 
1607 static void checkShape(
1608     Node* n,
1609     std::vector<int64_t> expected,
1610     bool prev = true) {
1611   auto profile = (prev) ? n->inputs().at(0)->node() : n;
1612   checkShape(profile->output()->type(), expected);
1613 }
1614 
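// count_/countNodes recursively count the nodes, including those in nested
// blocks, that satisfy a predicate.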
1615 void count_(
1616     Block* block,
1617     const std::function<bool(Node* n)>& pred,
1618     size_t& count) {
1619   for (Node* n : block->nodes()) {
1620     if (pred(n)) {
1621       count++;
1622     }
1623 
1624     for (Block* ib : n->blocks()) {
1625       count_(ib, pred, count);
1626     }
1627   }
1628 }
1629 
1630 size_t countNodes(
1631     const std::shared_ptr<Graph>& graph,
1632     const std::function<bool(Node* n)>& pred) {
1633   size_t count = 0;
1634   count_(graph->block(), pred, count);
1635   return count;
1636 }
1637 
1638 bool true_pred(Node* n) {
1639   return true;
1640 }
1641 
1642 bool is_loop(Node* n) {
1643   return n->kind() == prim::Loop;
1644 }
1645 
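// LoopsPeeler(pred, n) peels up to n iterations off each loop accepted by
// pred; the peeled iterations are emitted as a loop themselves, which is why
// the tests below expect two prim::Loop nodes where the source had one.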
1646 TEST(LoopPeelerTest, NoInductionVariableUse) {
1647   // do not use an induction variable explicitly
1648   static const auto str_func_def = R"JIT(
1649     def test_peel_n_times():
1650       sum = 0
1651       for i in range(10):
1652         sum += 2
1653       return sum
1654     )JIT";
1655 
1656   auto cu = compile(str_func_def);
1657   auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
1658   auto stack = createStack({});
1659   // peeling loop once
1660   {
1661     LoopsPeeler peeler(true_pred, 1);
1662     auto copy = f.graph()->copy();
1663     peeler.run(copy);
1664     int num_loops =
1665         std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1666     ASSERT_EQ(num_loops, 2);
1667     Code code(copy, "");
1668     InterpreterState interpreter{code};
1669     interpreter.run(stack);
1670     ASSERT_EQ(stack.back().toInt(), 20);
1671   }
1672 
1673   // test peeling more than one iteration
1674   {
1675     LoopsPeeler peeler(true_pred, 3);
1676     auto copy = f.graph()->copy();
1677     peeler.run(copy);
1678     int num_loops =
1679         std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1680     ASSERT_EQ(num_loops, 2);
1681     Code code(copy, "");
1682     InterpreterState interpreter{code};
1683     interpreter.run(stack);
1684     ASSERT_EQ(stack.back().toInt(), 20);
1685   }
1686 }
1687 
1688 TEST(LoopPeelerTest, YesInductionVariableUse) {
1689   // uses the induction variable
1690   static const auto str_func_def = R"JIT(
1691     def test_peel_n_times():
1692       sum = 0
1693       for i in range(10):
1694         sum += i
1695       return sum
1696     )JIT";
1697 
1698   auto cu = compile(str_func_def);
1699   auto& f = toGraphFunction(cu->get_function("test_peel_n_times"));
1700   auto stack = createStack({});
1701   // peeling loop once
1702   {
1703     LoopsPeeler peeler(true_pred, 1);
1704     auto copy = f.graph()->copy();
1705     peeler.run(copy);
1706     int num_loops =
1707         std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1708     ASSERT_EQ(num_loops, 2);
1709     Code code(copy, "");
1710     InterpreterState interpreter{code};
1711     interpreter.run(stack);
1712     ASSERT_EQ(stack.back().toInt(), 45);
1713   }
1714 
1715   // test peeling more than one iteration
1716   {
1717     LoopsPeeler peeler(true_pred, 3);
1718     auto copy = f.graph()->copy();
1719     peeler.run(copy);
1720     int num_loops =
1721         std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1722     ASSERT_EQ(num_loops, 2);
1723     Code code(copy, "");
1724     InterpreterState interpreter{code};
1725     interpreter.run(stack);
1726     ASSERT_EQ(stack.back().toInt(), 45);
1727   }
1728 }
1729 
1730 TEST(LoopPeelerTest, LoopWithTerminationCondition) {
1731   // tests with explicit termination conditions
1732   static const auto str_func_def = R"JIT(
1733     def test_with_cond_times():
1734       sum = 0
1735       i = 0
1736       while (sum < 2):
1737         sum += i
1738         i += 1
1739       return sum
1740     )JIT";
1741 
1742   // peeling rewrites the termination condition to false,
1743   // so the original loop doesn't run
1744   auto cu = compile(str_func_def);
1745   auto& f = toGraphFunction(cu->get_function("test_with_cond_times"));
1746   auto stack = createStack({});
1747   // peeling 5 iterations should update the termination
1748   // condition to false
1749   {
1750     LoopsPeeler peeler(true_pred, 5);
1751     auto copy = f.graph()->copy();
1752     peeler.run(copy);
1753     int num_loops =
1754         std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1755     ASSERT_EQ(num_loops, 2);
1756     Code code(copy, "");
1757     InterpreterState interpreter{code};
1758     interpreter.run(stack);
1759     ASSERT_EQ(stack.back().toInt(), 3);
1760   }
1761 
1762   // the termination condition remains true
1763   {
1764     LoopsPeeler peeler(true_pred, 1);
1765     auto copy = f.graph()->copy();
1766     peeler.run(copy);
1767     int num_loops =
1768         std::count_if(copy->nodes().begin(), copy->nodes().end(), is_loop);
1769     ASSERT_EQ(num_loops, 2);
1770     Code code(copy, "");
1771     InterpreterState interpreter{code};
1772     interpreter.run(stack);
1773     ASSERT_EQ(stack.back().toInt(), 3);
1774   }
1775 }
1776 
1777 // tests simple nested loops
1778 TEST(LoopPeelerTest, SimpleNestedLoops) {
1779   static const auto str_func_def = R"JIT(
1780     def test_nested_loops():
1781       sum = 0
1782       i = 0
1783       for i in range(10):
1784         for j in range(10):
1785           sum += i + j
1786       return sum
1787     )JIT";
1788 
1789   auto cu = compile(str_func_def);
1790   auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
1791   auto stack = createStack({});
1792 
1793   {
1794     LoopsPeeler peeler(true_pred, 1);
1795     auto copy = f.graph()->copy();
1796     peeler.run(copy);
1797     ASSERT_EQ(countNodes(copy, is_loop), 5);
1798     Code code(copy, "");
1799     InterpreterState interpreter{code};
1800     interpreter.run(stack);
1801     ASSERT_EQ(stack.back().toInt(), 900);
1802   }
1803 
1804   {
1805     LoopsPeeler peeler(true_pred, 5);
1806     auto copy = f.graph()->copy();
1807     peeler.run(copy);
1808     ASSERT_EQ(countNodes(copy, is_loop), 5);
1809     Code code(copy, "");
1810     InterpreterState interpreter{code};
1811     interpreter.run(stack);
1812     ASSERT_EQ(stack.back().toInt(), 900);
1813   }
1814 }
1815 
1816 TEST(LoopPeelerTest, SimpleNestedLoops2) {
1817   static const auto str_func_def = R"JIT(
1818     def test_nested_loops():
1819       sum = 0
1820       i = 0
1821       for i in range(10):
1822         j = 0
1823         while sum < 2:
1824           sum += i + j
1825           j += 1
1826       return sum
1827     )JIT";
1828 
1829   auto cu = compile(str_func_def);
1830   auto& f = toGraphFunction(cu->get_function("test_nested_loops"));
1831   auto stack = createStack({});
1832   {
1833     LoopsPeeler peeler(true_pred, 1);
1834     auto copy = f.graph()->copy();
1835     peeler.run(copy);
1836     ASSERT_EQ(countNodes(copy, is_loop), 5);
1837     Code code(copy, "");
1838     InterpreterState interpreter{code};
1839     interpreter.run(stack);
1840     ASSERT_EQ(stack.back().toInt(), 3);
1841   }
1842 
1843   {
1844     LoopsPeeler peeler(true_pred, 5);
1845     auto copy = f.graph()->copy();
1846     peeler.run(copy);
1847     ASSERT_EQ(countNodes(copy, is_loop), 5);
1848     Code code(copy, "");
1849     InterpreterState interpreter{code};
1850     interpreter.run(stack);
1851     ASSERT_EQ(stack.back().toInt(), 3);
1852   }
1853 }
1854 
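// TraceGraph executes the LSTM graph on the sample stack and records a
// traced graph specialized to the sample inputs; the traced and scripted
// versions should produce matching outputs.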
1855 TEST(JitTracing, Basic) {
1856   constexpr int batch_size = 4;
1857   constexpr int input_size = 256;
1858 
1859   int hidden_size = 2 * input_size;
1860 
1861   auto input = at::randn({batch_size, input_size}, at::kCPU);
1862   auto hx = at::randn({batch_size, hidden_size}, at::kCPU);
1863   auto cx = at::randn({batch_size, hidden_size}, at::kCPU);
1864   auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU));
1865   auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU));
1866 
1867   auto graph = build_lstm();
1868   auto stack = createStack({input, hx, cx, w_ih, w_hh});
1869   auto traced = TraceGraph(graph, stack);
1870 
1871   // Check that the inputs of traced graph have the same type as the inputs
1872   // specified here.
1873   ASSERT_EQ(*traced->inputs().at(0)->type(), *TensorType::create(input));
1874   ASSERT_EQ(*traced->inputs().at(1)->type(), *TensorType::create(hx));
1875   ASSERT_EQ(*traced->inputs().at(2)->type(), *TensorType::create(cx));
1876   ASSERT_EQ(*traced->inputs().at(3)->type(), *TensorType::create(w_ih));
1877   ASSERT_EQ(*traced->inputs().at(4)->type(), *TensorType::create(w_hh));
1878 
1879   Tensor prof_out;
1880   pop(stack, prof_out);
1881 
1882   {
1883     stack = createStack({input, hx, cx, w_ih, w_hh});
1884     Code cd(traced, "traced");
1885     InterpreterState is{cd};
1886     is.run(stack);
1887     Tensor traced_out;
1888     pop(stack, traced_out);
1889     ASSERT_TRUE(torch::allclose(prof_out, traced_out));
1890   }
1891 
1892   {
1893     stack = createStack({input, hx, cx, w_ih, w_hh});
1894     Code cd(graph, "graph");
1895     InterpreterState is{cd};
1896     is.run(stack);
1897     Tensor scripted_out;
1898     pop(stack, scripted_out);
1899     ASSERT_TRUE(torch::allclose(prof_out, scripted_out));
1900   }
1901 }
1902 
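// InsertGuards materializes the profiled tensor types as prim::Guard nodes;
// EliminateRedundantGuards then removes guards already implied by an
// equivalent earlier guard.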
1903 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
1904 TEST(InsertAndEliminateRedundantGuardsTest, Basic) {
1905   static const auto basic_example = R"JIT(
1906   def basic(x, y):
1907     a = x + y
1908     b = x * y
1909     c = x + 1
1910     d = a - c
1911     e = b - c
1912     return d + e
1913   )JIT";
1914 
1915   auto cu = compile(basic_example);
1916   auto& fun = toGraphFunction(cu->get_function("basic"));
1917   auto pr = ProfilingRecord::instrumentGraph(fun.graph());
1918   auto x = at::randn({2, 3}, at::kCPU);
1919   auto y = at::randn({2, 3}, at::kCPU);
1920   auto stack = createStack({x, y});
1921   // introduce some profiling information
1922   Code cd(pr->profiled_graph_, "");
1923   InterpreterState is{cd};
1924   is.run(stack);
1925   auto copy = pr->profiled_graph_->copy();
1926   ProfilingRecord::removeProfileCounter(copy->block());
1927   InsertGuards(copy);
1928   auto nodes = copy->block()->nodes();
1929   auto guard = std::find_if(nodes.begin(), nodes.end(), [](Node* n) {
1930     return n->kind() == prim::Guard;
1931   });
1932   ASSERT_NE(guard, nodes.end());
1933   ASSERT_EQ(
1934       guard->input()->type()->expectRef<TensorType>().sizes().size(),
1935       std::nullopt);
1936   checkShape(*guard, {2, 3}, false);
1937   auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
1938   int num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
1939   ASSERT_EQ(num_guards, 12);
1940   // now eliminate as many guards as possible
1941   // we should be left with two guards on x and y's defs
1942   EliminateRedundantGuards(copy);
1943   num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
1944   ASSERT_EQ(num_guards, 2);
1945 }
1946 
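// InsertBailOuts converts each remaining prim::Guard into a prim::BailOut
// node; as checked below, every bailout takes the shared
// prim::BailoutTemplate node as its first input.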
1947 TEST(InsertBailOutsTest, Basic) {
1948   static const auto basic_example = R"JIT(
1949   def basic_loop(x, y):
1950 
1951       a = x + 1
1952       b = y + 2
1953       c = x + y + 3
1954 
1955       for i in range(10):
1956           a = a + b
1957           # invariant
1958           d = b * c
1959           #
1960           a = a - d
1961 
1962       e = a + 4
1963       return e
1964   )JIT";
1965 
1966   auto cu = compile(basic_example);
1967   auto& fun = toGraphFunction(cu->get_function("basic_loop"));
1968   auto pr = ProfilingRecord::instrumentGraph(fun.graph());
1969   auto x = at::randn({2, 3}, at::kCPU);
1970   auto y = at::randn({2, 3}, at::kCPU);
1971   auto stack = createStack({x, y});
1972   // introduce some profiling information
1973   Code cd(pr->profiled_graph_, "");
1974   InterpreterState is{cd};
1975   is.run(stack);
1976   auto copy = pr->profiled_graph_->copy();
1977   ProfilingRecord::removeProfileCounter(copy->block());
1978   InsertGuards(copy);
1979   EliminateRedundantGuards(copy);
1980   auto nodes = copy->block()->nodes();
1981   auto is_guard = [](Node* n) { return n->kind() == prim::Guard; };
1982   auto num_guards = std::count_if(nodes.begin(), nodes.end(), is_guard);
1983   ASSERT_EQ(num_guards, 3);
1984   InsertBailOuts(copy);
1985   auto is_bailout = [](Node* n) { return n->kind() == prim::BailOut; };
1986   auto num_bailouts = std::count_if(nodes.begin(), nodes.end(), is_bailout);
1987   ASSERT_EQ(num_guards, num_bailouts);
1988   std::vector<Node*> bailouts(num_bailouts);
1989   std::copy_if(nodes.begin(), nodes.end(), bailouts.begin(), is_bailout);
1990 
1991   for (auto blo : bailouts) {
1992     ASSERT_EQ(blo->inputs().at(0)->node()->kind(), prim::BailoutTemplate);
1993   }
1994 }
1995 
1996 TEST(ProfilerTest, Basic) {
1997   constexpr int batch_size = 4;
1998   constexpr int input_size = 256;
1999 
2000   int hidden_size = 2 * input_size;
2001 
2002   auto input = at::randn({batch_size, input_size}, at::kCPU);
2003   auto hx = at::randn({batch_size, hidden_size}, at::kCPU);
2004   auto cx = at::randn({batch_size, hidden_size}, at::kCPU);
2005   auto w_ih = t_def(at::randn({4 * hidden_size, input_size}, at::kCPU));
2006   auto w_hh = t_def(at::randn({4 * hidden_size, hidden_size}, at::kCPU));
2007 
2008   auto g = build_lstm();
2009   auto stack = createStack({input, hx, cx, w_ih, w_hh});
2010 
2011   auto& opt_graph = *g.get();
2012   ArgumentSpecCreator arg_spec_creator(opt_graph);
2013   ArgumentSpec spec =
2014       arg_spec_creator.create(autograd::GradMode::is_enabled(), stack);
2015   arg_spec_creator.specializeTypes(opt_graph, spec);
2016   auto pr = ProfilingRecord::instrumentGraph(g);
2017   Code cd(pr->profiled_graph_, "");
2018   InterpreterState is{cd};
2019   is.run(stack);
2020 
2021   // profiled types are stored as attributes and show up in the dump, e.g.
2022   // Tensor = prim::profile[profiled_type=Double(4, 256, strides=[256, 1],
2023   // requires_grad=0, device=cpu)]
2024   testing::FileCheck()
2025       .check("Tensor = prim::profile[profiled_type")
2026       ->check_same("256")
2027       ->run(*pr->profiled_graph_);
2028 
2029   auto begin = pr->profiled_graph_->block()->nodes().begin();
2030   auto end = pr->profiled_graph_->block()->nodes().end();
2031   auto mm =
2032       std::find_if(begin, end, [](Node* n) { return n->kind() == aten::add; });
2033   ASSERT_NE(mm, end);
2034   std::vector<int64_t> mm_expected{4, 2048};
2035   std::vector<int64_t> eltwise{4, 512};
2036   checkShape(mm->inputs().at(0)->node()->ty(attr::profiled_type), mm_expected);
2037   auto mul_n =
2038       std::find_if(begin, end, [](Node* n) { return n->kind() == aten::mul; });
2039   ASSERT_NE(mul_n, end);
2040   checkShape(mul_n->inputs().at(0)->node()->ty(attr::profiled_type), eltwise);
2041   auto tanh_n =
2042       std::find_if(begin, end, [](Node* n) { return n->kind() == aten::tanh; });
2043   checkShape(tanh_n->inputs().at(0)->node()->ty(attr::profiled_type), eltwise);
2044 }
2045 
2046 TEST(ProfilerTest, OptionalProfiling) {
2047   auto graph = std::make_shared<Graph>();
2048   std::unordered_map<std::string, Value*> vmap;
2049   parseIR(
2050       R"IR(
2051 graph(%inp : Tensor,
2052       %weight : Tensor,
2053       %bias : Tensor?):
2054   %1 : Tensor = aten::linear(%inp, %weight, %bias)
2055   return (%1))IR",
2056       &*graph,
2057       vmap);
2058 
2059   auto pr = ProfilingRecord::instrumentGraph(graph);
2060   pr->profiling_count_ = 2;
2061 
2062   auto input = torch::randn({1, 2});
2063   auto weight = torch::randn({2, 2});
2064   auto bias = torch::randn({1, 2});
2065 
2066   auto stack = createStack({input, weight, bias});
2067   Code cd(pr->profiled_graph_, "");
2068   InterpreterState is{cd};
2069   is.run(stack);
2070 
2071   testing::FileCheck()
2072       .check_count("Tensor? = prim::profile[profiled_type", 1, true)
2073       ->run(*pr->profiled_graph_);
2074 
2075   // make sure we recorded the shape
2076   auto begin = pr->profiled_graph_->block()->nodes().begin();
2077   auto end = pr->profiled_graph_->block()->nodes().end();
2078   auto linear = std::find_if(
2079       begin, end, [](Node* n) { return n->kind() == aten::linear; });
2080   ASSERT_NE(linear, end);
2081   std::vector<int64_t> bias_expected_shape = {1, 2};
2082   auto profiled_bias = linear->namedInput("bias")->node();
2083   checkShape(profiled_bias->ty(attr::profiled_type), bias_expected_shape);
2084   ASSERT_EQ(0, profiled_bias->i(attr::seen_none));
2085 
2086   auto none_bias = c10::IValue();
2087 
2088   stack.clear();
2089   stack.emplace_back(input);
2090   stack.emplace_back(weight);
2091   stack.emplace_back(none_bias);
2092   is = InterpreterState{cd};
2093   is.run(stack);
2094 
2095   // make sure we recorded that "None" was seen.
2096   begin = pr->profiled_graph_->block()->nodes().begin();
2097   end = pr->profiled_graph_->block()->nodes().end();
2098   linear = std::find_if(
2099       begin, end, [](Node* n) { return n->kind() == aten::linear; });
2100   ASSERT_NE(linear, end);
2101   profiled_bias = linear->namedInput("bias")->node();
2102   checkShape(profiled_bias->ty(attr::profiled_type), bias_expected_shape);
2103   ASSERT_EQ(1, profiled_bias->i(attr::seen_none));
2104 }
2105 
2106 TEST(CallStackTest, Basic) {
2107   const auto text = R"(
2108 def ham(x):
2109     return x/7
2110 
2111 def bar(x):
2112     return x*3
2113 
2114 def baz(x):
2115     return ham(x)*x
2116 
2117 def foo(x):
2118     return bar(x)*baz(x)*11
2119   )";
2120   auto cu = compile(text);
2121   const auto& foo = toGraphFunction(cu->get_function("foo"));
2122   for (Node* n : foo.optimized_graph()->nodes()) {
2123     if (n->kind() == prim::Constant) {
2124       if (!n->hasAttribute(attr::value) ||
2125           n->kindOf(attr::value) != AttributeKind::i) {
2126         continue;
2127       }
2128       int v = n->i(attr::value);
2129       switch (v) {
2130         case 3: {
2131           // Const 3 comes from function 'bar', which gets inlined to 'foo'.
2132           // The callstack for the corresponding node should contain only the
2133           // function 'bar'.
2134           ASSERT_TRUE(n->callstack());
2135           auto callstack_vector = (*n->callstack())->vec();
2136           ASSERT_EQ(callstack_vector.size(), 1);
2137           ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("bar"));
2138           break;
2139         }
2140         case 7: {
2141           // Const 7 comes from function 'ham', which gets inlined to 'baz',
2142           // which is then inlined to 'foo'. The callstack for the corresponding
2143           // node should contain these two functions.
2144           ASSERT_TRUE(n->callstack());
2145           auto callstack_vector = (*n->callstack())->vec();
2146           ASSERT_EQ(callstack_vector.size(), 2);
2147           ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("baz"));
2148           ASSERT_EQ(std::get<0>(callstack_vector[1]), &cu->get_function("ham"));
2149           break;
2150         }
2151         case 11: {
2152           // Const 11 comes from function 'foo', which is not inlined anywhere
2153           // and thus it should not have a callstack.
2154           ASSERT_FALSE(n->callstack());
2155           break;
2156         }
2157       }
2158     }
2159   }
2160 
2161   // Check that inlining doesn't corrupt the callstack of the callee's nodes.
2162   const auto& baz = toGraphFunction(cu->get_function("baz"));
2163   for (Node* n : baz.optimized_graph()->nodes()) {
2164     if (n->kind() == prim::Constant) {
2165       if (!n->hasAttribute(attr::value) ||
2166           n->kindOf(attr::value) != AttributeKind::i) {
2167         continue;
2168       }
2169       int v = n->i(attr::value);
2170       ASSERT_TRUE(v == 7);
2171       // Const 7 comes from function 'ham', which gets inlined to 'baz'. 'baz'
2172       // was also inlined into 'foo', but when looking at the graph of 'baz' we
2173       // should only see a callstack of depth 1 (containing only 'ham').
2174       ASSERT_TRUE(n->callstack());
2175       auto callstack_vector = (*n->callstack())->vec();
2176       ASSERT_EQ(callstack_vector.size(), 1);
2177       ASSERT_EQ(std::get<0>(callstack_vector[0]), &cu->get_function("ham"));
2178     }
2179   }
2180 }
2181 
2182 TEST(CallStackTest, Caching) {
2183   const auto text = R"(
2184 
2185 def a(x):
2186     print("a1")
2187     print("a2")
2188     return x
2189 
2190 def b(x):
2191     print("b1")
2192     print("b2")
2193     a(x)
2194     return x
2195 
2196 def c(x):
2197     print("c1")
2198     print("c2")
2199     b(x)
2200     return x
2201   )";
2202   auto cu = compile(text);
2203   const auto& baz = toGraphFunction(cu->get_function("c"));
2204   std::unordered_map<std::string, InlinedCallStack*> callstack_objects;
2205   for (Node* n : baz.optimized_graph()->nodes()) {
2206     if (n->kind() == prim::Constant) {
2207       if (!n->hasAttribute(attr::value) ||
2208           n->kindOf(attr::value) != AttributeKind::s) {
2209         continue;
2210       }
2211       // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
2212       std::string v = n->s(attr::value);
2213       if (n->callstack()) {
2214         callstack_objects[v] = n->callstack()->get();
2215       }
2216     }
2217   }
2218   // We expect to see nodes prim::Constant[value="a1"] and
2219   // prim::Constant[value="a2"] inlined to function 'c'. Their callstacks are
2220   // the same (a->b->c), so we want to make sure we're not creating different
2221   // callstack entries for them.
2222   ASSERT_TRUE(callstack_objects.count("a1") && callstack_objects.count("a2"));
2223   ASSERT_TRUE(callstack_objects.at("a1") == callstack_objects.at("a2"));
2224 }
2225 
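// Nodes inlined from submodule calls carry an InlinedCallStack; this test
// checks that nodes nested inside if-blocks still report both the call site
// in C.forward and their original source ranges in A.forward.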
2226 TEST(InlinedCallStackTest, BlockAnnotation) {
2227   Module a("A");
2228   a.define(R"(
2229     def forward(self, x, y, z: int):
2230       if (z == 1):
2231         return x + y
2232       else:
2233         return x * y
2234   )");
2235   Module b("B");
2236   b.define(R"(
2237     def forward(self, x):
2238       return x + 2
2239   )");
2240   Module c("C");
2241   c.register_module("A0", a);
2242   c.register_module("B0", b);
2243   c.define(R"(
2244     def forward(self, x, y, z: int):
2245       return self.A0.forward(x, y, z) + self.B0.forward(x)
2246   )");
2247 
2248   auto graph =
2249       toGraphFunction(c.get_method("forward").function()).optimized_graph();
2250   std::stringstream add_ss, mul_ss;
2251   for (Node* n : graph->nodes()) {
2252     if (n->kind() == prim::If) {
2253       for (Block* block : n->blocks()) {
2254         for (Node* if_node : block->nodes()) {
2255           if (if_node->kind() == aten::add) {
2256             for (const auto& e : if_node->callstack().value()->vec()) {
2257               add_ss << std::get<1>(e);
2258             }
2259             add_ss << if_node->sourceRange();
2260           }
2261           if (if_node->kind() == aten::mul) {
2262             for (const auto& e : if_node->callstack().value()->vec()) {
2263               mul_ss << std::get<1>(e);
2264             }
2265             mul_ss << if_node->sourceRange();
2266           }
2267         }
2268       }
2269     }
2270   }
2271   ASSERT_NE(add_ss.str().find("line 3"), std::string::npos);
2272   ASSERT_NE(add_ss.str().find("line 4"), std::string::npos);
2273   ASSERT_NE(
2274       add_ss.str().find("return self.A0.forward(x, y, z)"), std::string::npos);
2275   ASSERT_NE(add_ss.str().find("return x + y"), std::string::npos);
2276   ASSERT_NE(mul_ss.str().find("line 3"), std::string::npos);
2277   ASSERT_NE(mul_ss.str().find("line 6"), std::string::npos);
2278   ASSERT_NE(
2279       mul_ss.str().find("return self.A0.forward(x, y, z)"), std::string::npos);
2280   ASSERT_NE(mul_ss.str().find("return x * y"), std::string::npos);
2281 }
2282 
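// getNodesModuleHierarchy reports the chain of module instances a node was
// inlined from; calls through self show up as SELF(<module type>).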
2283 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
2284 TEST(InlinedCallStackTest, SelfCallMethods) {
2285   Module a("A");
2286   a.define(R"(
2287     def my_new_method(self, x):
2288       return x * 3
2289     def forward_impl_(self, x, y):
2290       return self.my_new_method(x) + y
2291     def forward(self, x, y):
2292       y = y + 2
2293       return self.forward_impl_(x, y)
2294   )");
2295   Module b("B");
2296   b.define(R"(
2297     def forward(self, x):
2298       return x + 2
2299   )");
2300   Module c("C");
2301   c.register_module("A0", a);
2302   c.register_module("B0", b);
2303   c.define(R"(
2304     def call_b(self, x):
2305       return self.B0.forward(x)
2306     def forward(self, x, y):
2307       return self.A0.forward(x, y) + self.call_b(x)
2308   )");
2309 
2310   auto graph =
2311       toGraphFunction(c.get_method("forward").function()).optimized_graph();
2312   std::unordered_map<std::string, size_t> module_hierarchies;
2313   for (Node* n : graph->nodes()) {
2314     auto hierarchy = torch::jit::utils::getNodesModuleHierarchy(*n);
2315     if (module_hierarchies.count(hierarchy) == 0) {
2316       module_hierarchies[hierarchy] = 0;
2317     }
2318     module_hierarchies[hierarchy] += 1;
2319   }
2320   ASSERT_EQ(module_hierarchies["A0(A)"], 2);
2321   ASSERT_EQ(module_hierarchies["A0(A).SELF(A).SELF(A)"], 2);
2322   ASSERT_EQ(module_hierarchies["A0(A).SELF(A)"], 1);
2323   ASSERT_EQ(module_hierarchies["SELF(C)"], 1);
2324   ASSERT_EQ(module_hierarchies["SELF(C).B0(B)"], 1);
2325 }
2326 
2327 TEST(AutogradSymbolsTest, Basic) {
2328   Symbol sym = Symbol::fromQualString("aten::test_symbol");
2329   Graph graph;
2330   auto node = graph.create(sym);
2331   TORCH_CHECK(canRunWithAutograd(node));
2332 
2333   sym = Symbol::fromQualString("prim::test_symbol");
2334   node = graph.create(sym);
2335   TORCH_CHECK(canRunWithAutograd(node));
2336 
2337   sym = Symbol::fromQualString("prim::FusionGroup");
2338   node = graph.create(sym);
2339   TORCH_CHECK(!canRunWithAutograd(node));
2340 
2341   sym = Symbol::fromQualString("custom::test_symbol");
2342   node = graph.create(sym);
2343   TORCH_CHECK(!canRunWithAutograd(node));
2344 }
2345 
2346 TEST(DefaultArgTypeHintingTest, Basic) {
2347   const auto text_non_hinted = R"(
2348 
2349 def a(x, y=1):
2350     print("a1")
2351     print("a2")
2352     return x
2353   )";
2354 
2355   const auto text_hinted = R"(
2356 
2357 def a(x, y:int=1):
2358     print("a1")
2359     print("a2")
2360     return x
2361   )";
2362 
2363   try {
2364     compile(text_non_hinted);
2365     ASSERT_TRUE(0); // compile() is expected to throw for the unhinted default
2366   } catch (const std::exception& c) {
2367   }
2368 
2369   auto cu = compile(text_hinted);
2370 }
2371 
2372 // Basic set case.
2373 TEST(FuturesTest, Basic) {
2374   auto f1 = c10::make_intrusive<Future>(IntType::get());
2375   ASSERT_FALSE(f1->completed());
2376   ASSERT_FALSE(f1->hasValue());
2377   int32_t sat1 = 0;
2378   int32_t sat2 = 0;
2379   f1->addCallback([&](Future& /* unused */) { ++sat1; });
2380   f1->markCompleted(43);
2381   ASSERT_TRUE(f1->completed());
2382   ASSERT_TRUE(f1->hasValue());
2383   ASSERT_FALSE(f1->hasError());
2384   ASSERT_EQ(sat1, 1);
2385   ASSERT_EQ(f1->constValue().toInt(), 43);
2386   ASSERT_EQ(f1->value().toInt(), 43);
2387   f1->addCallback([&](Future& /* unused */) { ++sat2; });
2388   ASSERT_EQ(sat1, 1);
2389   ASSERT_EQ(sat2, 1);
2390 }
2391 
2392 // Sparse CUDA tensor test
2393 TEST(FutureTest, SparseTensor) {
2394   // Skip test if CUDA is not available.
2395   bool has_cuda = at::globalContext().hasCUDA();
2396   if (!has_cuda) {
2397     LOG(INFO) << "CUDA not available, skipping test";
     return; // actually skip; the loop below constructs CUDA tensors
2398   }
2399   for (int i = 0; i < 2; ++i) {
2400     auto f = c10::make_intrusive<Future>(TensorType::get());
2401     at::TensorOptions opts = at::TensorOptions().device(at::DeviceType::CUDA);
2402     auto sparse_tensor = i == 0 ? at::ones(10).to_sparse()
2403                                 : at::sparse_coo_tensor(
2404                                       at::arange(10).unsqueeze(0).to(at::kLong),
2405                                       at::ones({10, 10}),
2406                                       opts);
2407     // Runs storage extraction for sparse CUDA tensors
2408     f->markCompleted(sparse_tensor);
2409     ASSERT_TRUE(f->completed());
2410     ASSERT_FALSE(f->hasError());
2411   }
2412 }
2413 
2414 // Basic error cases.
2415 TEST(FuturesTest, Error) {
2416   auto f1 = c10::make_intrusive<Future>(IntType::get());
2417   int sat1 = 0;
2418   int sat2 = 0;
2419   f1->addCallback([&](Future& /* unused */) { ++sat1; });
2420   f1->setError(
2421       std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
2422   ASSERT_EQ(sat1, 1);
2423   ASSERT_TRUE(f1->completed());
2424   ASSERT_TRUE(f1->hasError());
2425   ASSERT_FALSE(f1->hasValue());
2426   try {
2427     (void)f1->value();
2428     ASSERT_TRUE(false); // Supposed to throw.
2429   } catch (const std::exception& e) {
2430     ASSERT_TRUE(strcmp(e.what(), "Failed") == 0);
2431   }
2432   f1->addCallback([&](Future& /* unused */) { ++sat2; });
2433   ASSERT_EQ(sat1, 1);
2434   ASSERT_EQ(sat2, 1);
2435   f1->setErrorIfNeeded(
2436       std::make_exception_ptr(c10::ivalue::Future::FutureError("Dup")));
2437   ASSERT_TRUE(strcmp(f1->tryRetrieveErrorMessage().c_str(), "Failed") == 0);
2438   ASSERT_EQ(sat1, 1);
2439   ASSERT_EQ(sat2, 1);
2440   try {
2441     (void)f1->constValue();
2442     ASSERT_TRUE(false); // Supposed to throw.
2443   } catch (const std::exception& e) {
2444     // Original error should be logged.
2445     ASSERT_TRUE(std::string(e.what()).find("Failed") != std::string::npos);
2446   }
2447 }
2448 
2449 // then() callback chaining
2450 TEST(FuturesTest, Then) {
2451   auto f1 = c10::make_intrusive<Future>(IntType::get());
2452   auto f2 = f1->then(
2453       [](Future& f1) -> IValue { return f1.constValue().toInt() + 1; },
2454       IntType::get());
2455   auto f3 = f2->then(
2456       [](Future& f2) -> IValue { return f2.constValue().toInt() * 3; },
2457       IntType::get());
2458   bool done = false;
2459   f3->addCallback([&done](Future& f3) {
2460     ASSERT_EQ(f3.constValue().toInt(), (42 + 1) * 3);
2461     done = true;
2462   });
2463   ASSERT_FALSE(done);
2464   f1->markCompleted(42);
2465   ASSERT_TRUE(done);
2466 }
2467 
2468 // collectAll()
2469 TEST(FuturesTest, CollectAll) {
2470   auto s1 = c10::make_intrusive<Future>(IntType::get());
2471   auto s2 = c10::make_intrusive<Future>(IntType::get());
2472   auto s3 = c10::make_intrusive<Future>(IntType::get());
2473 
2474   // Empty case
2475   c10::List<intrusive_ptr<ivalue::Future>> futures(
2476       FutureType::create(IntType::get()));
2477   auto c1 = collectAll(futures);
2478   ASSERT_TRUE(c1->completed());
2479   ASSERT_EQ(c1->value().toList().size(), 0);
2480   ASSERT_TRUE(
2481       *(c1->value().toList().elementType()) ==
2482       *FutureType::create(IntType::get()));
2483 
2484   // 1-element, initially not completed.
2485   futures.push_back(s1);
2486   auto c2 = collectAll(futures);
2487   ASSERT_FALSE(c2->completed());
2488   s1->markCompleted(5);
2489   ASSERT_TRUE(c2->completed());
2490   ASSERT_EQ(c2->value().toList().size(), 1);
2491   ASSERT_TRUE(
2492       *(c2->value().toList().elementType()) ==
2493       *FutureType::create(IntType::get()));
2494   ASSERT_EQ(c2->value().toList().get(0).toFuture()->value().toInt(), 5);
2495 
2496   // 1-element, already completed
2497   auto c3 = collectAll(futures);
2498   ASSERT_TRUE(c3->completed());
2499   ASSERT_EQ(c3->value().toList().size(), 1);
2500   ASSERT_EQ(c3->value().toList().get(0).toFuture()->value().toInt(), 5);
2501 
2502   // 3 elements.
2503   futures.push_back(s2);
2504   futures.push_back(s3);
2505   auto c4 = collectAll(futures);
2506   ASSERT_FALSE(c4->completed());
2507   s3->markCompleted(7);
2508   ASSERT_FALSE(c4->completed());
2509   s2->markCompleted(6);
2510   ASSERT_TRUE(c4->completed());
2511   ASSERT_EQ(c4->value().toList().size(), 3);
2512   ASSERT_EQ(c4->value().toList().get(0).toFuture()->value().toInt(), 5);
2513   ASSERT_EQ(c4->value().toList().get(1).toFuture()->value().toInt(), 6);
2514   ASSERT_EQ(c4->value().toList().get(2).toFuture()->value().toInt(), 7);
2515   ASSERT_TRUE(
2516       *(c4->value().toList().elementType()) ==
2517       *FutureType::create(IntType::get()));
2518 
2519   // Handle exception in the list.
2520   auto s4 = c10::make_intrusive<Future>(IntType::get());
2521   futures.push_back(s4);
2522   auto c5 = collectAll(futures);
2523   ASSERT_FALSE(c5->completed());
2524   s4->setError(
2525       std::make_exception_ptr(c10::ivalue::Future::FutureError("Failed")));
2526   ASSERT_TRUE(c5->completed());
2527   try {
2528     c5->value();
2529     ASSERT_TRUE(false); // supposed to throw
2530   } catch (const std::exception& e) {
2531     ASSERT_EQ(std::string(e.what()), "Failed");
2532   }
2533 }
2534 
2535 // collectAny()
2536 TEST(FuturesTest, CollectAny) {
2537   auto s1 = c10::make_intrusive<Future>(IntType::get());
2538 
2539   // Empty case
2540   c10::List<intrusive_ptr<ivalue::Future>> futures(
2541       FutureType::create(IntType::get()));
2542   auto c1 = collectAny(futures);
2543   ASSERT_TRUE(c1->completed());
2544 
2545   // 1 element, not yet satisfied
2546   futures.push_back(s1);
2547   auto c2 = collectAny(futures);
2548   ASSERT_FALSE(c2->completed());
2549   s1->markCompleted(5);
2550   ASSERT_TRUE(c2->completed());
2551   ASSERT_TRUE(c2->value().isInt());
2552   ASSERT_EQ(c2->value().toInt(), 5);
2553 
2554   // 1 element already satisfied.
2555   auto c3 = collectAny(futures);
2556   ASSERT_TRUE(c3->completed());
2557   ASSERT_TRUE(c3->value().isInt());
2558   ASSERT_EQ(c3->value().toInt(), 5);
2559 
2560   // 2 elements
2561   futures.clear();
2562   auto s2 = c10::make_intrusive<Future>(IntType::get());
2563   auto s3 = c10::make_intrusive<Future>(IntType::get());
2564   futures.push_back(s2);
2565   futures.push_back(s3);
2566   auto c4 = collectAny(futures);
2567   ASSERT_FALSE(c4->completed());
2568   s3->markCompleted(7);
2569   ASSERT_TRUE(c4->completed());
2570   ASSERT_EQ(c4->value().toInt(), 7);
2571   s2->markCompleted(1);
2572   ASSERT_EQ(c4->value().toInt(), 7);
2573 }
2574 
2575 TEST(TLSFutureCallbacksTest, Basic) {
2576   // cb that verifies the profiler is enabled
2577   auto profilerEnabledCb = [](Future& /* unused */) {
2578     ASSERT_TRUE(torch::autograd::profiler::profilerEnabled());
2579   };
2580   // test running callbacks with propagation of TLS state.
2581   {
2582     // Enable the profiler in this thread
2583     torch::autograd::profiler::enableProfilerLegacy(
2584         torch::autograd::profiler::ProfilerConfig(
2585             torch::autograd::profiler::ProfilerState::CPU, false, false));
2586     auto s1 = c10::make_intrusive<Future>(IntType::get());
2587     s1->addCallback(wrapPropagateTLSState(profilerEnabledCb));
2588     std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); });
2589     // Since we join here, we can ensure that all callbacks corresponding to
2590     // markCompleted() have finished.
2591     t.join();
2592     torch::autograd::profiler::disableProfilerLegacy();
2593   }
2594   // then() with TLS State
2595   {
2596     // Enable the profiler in this thread
2597     torch::autograd::profiler::enableProfilerLegacy(
2598         torch::autograd::profiler::ProfilerConfig(
2599             torch::autograd::profiler::ProfilerState::CPU, false, false));
2600     auto s1 = c10::make_intrusive<Future>(IntType::get());
2601     auto s2 = s1->then(
2602         wrapPropagateTLSState([&profilerEnabledCb](Future& s1) {
2603           profilerEnabledCb(s1);
2604           return at::IValue(1);
2605         }),
2606         IntType::get());
2607     std::thread t([s1 = std::move(s1)]() { s1->markCompleted(); });
2608     t.join();
2609     s2->wait();
2610     torch::autograd::profiler::disableProfilerLegacy();
2611   }
2612 }
2613 
2614 TEST(ProfilerDisableInCallbackTest, Basic) {
2615   // cb that verifies the profiler is enabled
2616   auto profilerEnabledCb = []() {
2617     ASSERT_TRUE(torch::autograd::profiler::profilerEnabled());
2618   };
2619   torch::autograd::profiler::enableProfilerLegacy(
2620       torch::autograd::profiler::ProfilerConfig(
2621           torch::autograd::profiler::ProfilerState::CPU, false, false));
2622   auto s1 = c10::make_intrusive<Future>(IntType::get());
2623   auto verifyProfilerCb =
2624       wrapPropagateTLSState([&profilerEnabledCb](Future& /* unused */) {
2625         // Ensure the profiler is still enabled in this thread.
2626         profilerEnabledCb();
2627         auto t1 = torch::ones({2, 2});
2628         auto t2 = torch::ones({2, 2});
2629         torch::add(t1, t2);
2630         // Don't cleanup TLSState, and just consolidate.
2631         auto opts =
2632             torch::autograd::profiler::ProfilerDisableOptions(false, true);
2633         auto thread_event_lists =
2634             // NOLINTNEXTLINE(performance-move-const-arg)
2635             torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
2636         // Ensure that the events from this thread are still profiled and we
2637         // obtain the expected events in our consolidated list when calling
2638         // disableProfilerLegacy().
2639         bool found_ones = false;
2640         bool found_add = false;
2641         for (const auto& li : thread_event_lists) {
2642           for (const auto& evt : li) {
2643             if (strcmp(evt.name(), "aten::add") == 0) {
2644               found_add = true;
2645             } else if (strcmp(evt.name(), "aten::ones") == 0) {
2646               found_ones = true;
2647             }
2648           }
2649           if (found_add && found_ones) {
2650             break;
2651           }
2652         }
2653         ASSERT_TRUE(found_ones);
2654         ASSERT_TRUE(found_add);
2655       });
2656 
2657   s1->addCallback(verifyProfilerCb);
2658   // Disable the profiler, but do not consolidate results in the main thread.
2659   auto opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
2660   // NOLINTNEXTLINE(performance-move-const-arg)
2661   torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
2662   std::thread t([s1 = std::move(s1)]() { s1->markCompleted(at::IValue(1)); });
2663   t.join();
2664 
2665   // Similar to above test, but verifies correctness in the case where
2666   // continuation runs on the main thread.
2667   torch::autograd::profiler::enableProfilerLegacy(
2668       torch::autograd::profiler::ProfilerConfig(
2669           torch::autograd::profiler::ProfilerState::CPU, false, false));
2670   s1 = c10::make_intrusive<Future>(IntType::get());
2671   s1->addCallback(verifyProfilerCb);
2672   // Runs callback inline
2673   s1->markCompleted(at::IValue(1));
2674   opts = torch::autograd::profiler::ProfilerDisableOptions(true, false);
2675   // NOLINTNEXTLINE(performance-move-const-arg)
2676   torch::autograd::profiler::disableProfilerLegacy(std::move(opts));
2677 }
2678 
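// RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS tags the profiler event
// with the given debug handle (42 here); events recorded without a handle
// report -1.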
2679 TEST(RecordDebugHandles, Basic) {
2680   // Enable the profiler in this thread
2681   const std::set<torch::autograd::profiler::ActivityType> activities(
2682       {torch::autograd::profiler::ActivityType::CPU});
2683   torch::autograd::profiler::prepareProfiler(
2684       torch::autograd::profiler::ProfilerConfig(
2685           torch::autograd::profiler::ProfilerState::KINETO, false, false),
2686       activities);
2687   torch::autograd::profiler::enableProfiler(
2688       torch::autograd::profiler::ProfilerConfig(
2689           torch::autograd::profiler::ProfilerState::KINETO, false, false),
2690       activities);
2691   {
2692     RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function", 42, {});
2693     float x{5.9999}, y{2.1212};
2694     float z = x / y;
2695     (void)z;
2696   }
2697   {
2698     RECORD_USER_SCOPE_WITH_INPUTS("not_my_function", {});
2699     float x{5.9999}, y{2.1212};
2700     float z = x / y;
2701     (void)z;
2702   }
2703   auto profiler_results_ptr = torch::autograd::profiler::disableProfiler();
2704   const auto& kineto_events = profiler_results_ptr->events();
2705   size_t my_events{0};
2706   for (const auto& e : kineto_events) {
2707     if (e.name() == "my_function") {
2708       ASSERT_EQ(e.debugHandle(), 42);
2709       my_events++;
2710     } else if (e.name() == "not_my_function") {
2711       ASSERT_EQ(e.debugHandle(), -1);
2712       my_events++;
2713     }
2714   }
2715   ASSERT_EQ(my_events, 2);
2716 }
2717 
2718 TEST(RecordDebugHandles, ScopedCallbacks) {
2719   // Enable the profiler in this thread
2720   torch::autograd::profiler::prepareProfiler(
2721       torch::autograd::profiler::ProfilerConfig(
2722           torch::autograd::profiler::ProfilerState::KINETO, false, false),
2723       {torch::autograd::profiler::ActivityType::CPU});
2724   torch::autograd::profiler::enableProfiler(
2725       torch::autograd::profiler::ProfilerConfig(
2726           torch::autograd::profiler::ProfilerState::KINETO, false, false),
2727       {torch::autograd::profiler::ActivityType::CPU});
2728 
2729   {
2730     auto a = torch::rand({128, 128});
2731     auto b = torch::rand({128, 128});
2732     auto c = a + b;
2733   }
2734   auto profiler_results_ptr = torch::autograd::profiler::disableProfiler();
2735   ASSERT_TRUE(profiler_results_ptr->events().size() > 0);
2736 
2737   // Enable the profiler in this thread
2738   torch::autograd::profiler::prepareProfiler(
2739       torch::autograd::profiler::ProfilerConfig(
2740           torch::autograd::profiler::ProfilerState::KINETO, false, false),
2741       {torch::autograd::profiler::ActivityType::CPU});
2742   torch::autograd::profiler::enableProfiler(
2743       torch::autograd::profiler::ProfilerConfig(
2744           torch::autograd::profiler::ProfilerState::KINETO, false, false),
2745       {torch::autograd::profiler::ActivityType::CPU},
2746       {at::RecordScope::LITE_INTERPRETER});
2747   {
2748     auto a = torch::rand({128, 128});
2749     auto b = torch::rand({128, 128});
2750     auto c = a + b;
2751   }
2752   profiler_results_ptr = torch::autograd::profiler::disableProfiler();
2753   ASSERT_TRUE(profiler_results_ptr->events().size() == 0);
2754 
2755   torch::autograd::profiler::prepareProfiler(
2756       torch::autograd::profiler::ProfilerConfig(
2757           torch::autograd::profiler::ProfilerState::KINETO, false, false),
2758       {torch::autograd::profiler::ActivityType::CPU});
2759   torch::autograd::profiler::enableProfiler(
2760       torch::autograd::profiler::ProfilerConfig(
2761           torch::autograd::profiler::ProfilerState::KINETO, false, false),
2762       {torch::autograd::profiler::ActivityType::CPU},
2763       {at::RecordScope::LITE_INTERPRETER});
2764   {
2765     RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS("my_function", 42, {});
2766     auto a = torch::rand({128, 128});
2767     auto b = torch::rand({128, 128});
2768     auto c = a + b;
2769   }
2770   {
2771     RECORD_USER_SCOPE_WITH_INPUTS("not_my_function", {});
2772     auto a = torch::rand({128, 128});
2773     auto b = torch::rand({128, 128});
2774     auto c = a + b;
2775   }
2776   profiler_results_ptr = torch::autograd::profiler::disableProfiler();
2777   const auto& kineto_events = profiler_results_ptr->events();
2778   for (const auto& e : kineto_events) {
2779     if (e.name() == "my_function") {
2780       ASSERT_EQ(e.debugHandle(), 42);
2781     }
2782   }
2783   ASSERT_TRUE(profiler_results_ptr->events().size() == 1);
2784 }
2785 
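// Calls foo with a=1 positionally and b=3 as a kwarg; c keeps its default
// of 4, so the result is 1 + 2*3 + 3*4 == 19.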
2786 TEST(IValueKWargsTest, Basic) {
2787   const auto text = R"(
2788     def foo(a : int, b : int, c : int = 4):
2789       return a + 2*b + 3*c
2790   )";
2791   auto cu = compile(text);
2792   auto result = cu->get_function("foo")({1}, {{"b", 3}});
2793   ASSERT_EQ(result.toInt(), 19);
2794 }
2795 
2796 TEST(ComputeFlopsTest, Basic) {
2797   uint64_t flops = 0;
2798 
2799   // Test unknown operator
2800   std::unordered_map<std::string, c10::IValue> extra_args;
2801   flops = torch::profiler::impl::computeFlops(
2802       std::string("aten::unknown"), extra_args);
2803   ASSERT_EQ(flops, 0);
2804 
2805   // Test aten::conv2d
2806   extra_args.clear();
2807   std::vector<int64_t> input_size = {4, 5, 6, 7};
2808   std::vector<int64_t> weight_size = {3, 5, 2, 1};
2809   std::vector<int64_t> padding = {1, 0};
2810   std::vector<int64_t> stride = {1, 1};
2811   std::vector<int64_t> dilation = {0, 0};
2812   extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
2813   extra_args["weight_size"] = at::IValue(at::IntArrayRef(weight_size));
2814   extra_args["groups"] = 1;
2815   extra_args["padding"] = at::IValue(at::IntArrayRef(padding));
2816   extra_args["stride"] = at::IValue(at::IntArrayRef(stride));
2817   extra_args["dilation"] = at::IValue(at::IntArrayRef(dilation));
2818   flops = torch::profiler::impl::computeFlops(
2819       std::string("aten::conv2d"), extra_args);
2820   ASSERT_EQ(flops, 13440);
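  // 13440 is consistent with counting a multiply-accumulate as two flops:
  // 2 * N * C_out * H_out * W_out * (C_in / groups) * kH * kW
  //   = 2 * 4 * 3 * 8 * 7 * 5 * 2 * 1,
  // with H_out = 8 and W_out = 7 following from the sizes above.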
2821 
2822   // Test aten::conv2d fail
2823   input_size = {4, 5, 6, 7};
2824   weight_size = {4, 5, 6};
2825   extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
2826   extra_args["weight_size"] = at::IValue(at::IntArrayRef(weight_size));
2827   flops = torch::profiler::impl::computeFlops(
2828       std::string("aten::conv2d"), extra_args);
2829   ASSERT_EQ(flops, 0);
2830 
2831   // Test aten::conv2d fail 2
2832   weight_size = {3, 5, 2, 1};
2833   stride = {0, 0};
2834   extra_args["weight_size"] = at::IValue(at::IntArrayRef(input_size));
2835   extra_args["stride"] = at::IValue(at::IntArrayRef(stride));
2836   flops = torch::profiler::impl::computeFlops(
2837       std::string("aten::conv2d"), extra_args);
2838   ASSERT_EQ(flops, 0);
2839 
2840   // Test aten::conv2d fail 3
2841   extra_args.clear();
2842   input_size = {4, 5, 6, 7};
2843   extra_args["input_size"] = at::IValue(at::IntArrayRef(input_size));
2844   flops = torch::profiler::impl::computeFlops(
2845       std::string("aten::conv2d"), extra_args);
2846   ASSERT_EQ(flops, 0);
2847 
2848   // Test aten::mm
2849   extra_args.clear();
2850   std::vector<int64_t> mat1_sizes = {3, 4, 5, 6};
2851   std::vector<int64_t> mat2_sizes = {6, 5, 4, 3};
2852   extra_args["mat1_size"] = at::IValue(at::IntArrayRef(mat1_sizes));
2853   extra_args["mat2_size"] = at::IValue(at::IntArrayRef(mat2_sizes));
2854   flops =
2855       torch::profiler::impl::computeFlops(std::string("aten::mm"), extra_args);
2856   ASSERT_EQ(flops, 43200);
2857 
2858   // Test aten::addmm
2859   flops = torch::profiler::impl::computeFlops(
2860       std::string("aten::addmm"), extra_args);
2861   ASSERT_EQ(flops, 43200);
2862 
2863   // Test aten::bmm
2864   extra_args.clear();
2865   mat1_sizes = {7, 5, 6};
2866   mat2_sizes = {7, 6, 3};
2867   extra_args["mat1_size"] = at::IValue(at::IntArrayRef(mat1_sizes));
2868   extra_args["mat2_size"] = at::IValue(at::IntArrayRef(mat2_sizes));
2869   flops =
2870       torch::profiler::impl::computeFlops(std::string("aten::bmm"), extra_args);
2871   ASSERT_EQ(flops, 1260);
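  // 1260 = 2 * 7 * 5 * 6 * 3: a batch of 7 [5, 6] x [6, 3] matmuls,
  // counting each multiply-accumulate as two flops.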
2872 
2873   // Test aten::baddbmm
2874   flops = torch::profiler::impl::computeFlops(
2875       std::string("aten::baddbmm"), extra_args);
2876   ASSERT_EQ(flops, 1260);
2877 
2878   // Test mm out of range
2879   extra_args.clear();
2880   flops =
2881       torch::profiler::impl::computeFlops(std::string("aten::mm"), extra_args);
2882   ASSERT_EQ(flops, 0);
2883 
2884   // Test aten::add.Tensor
2885   extra_args.clear();
2886   std::vector<int64_t> mat_sizes = {3, 4, 5, 6};
2887   extra_args["mat_size"] = at::IValue(at::IntArrayRef(mat_sizes));
2888   flops =
2889       torch::profiler::impl::computeFlops(std::string("aten::add"), extra_args);
2890   ASSERT_EQ(flops, 360);
2891 
2892   // Test aten::mul.Tensor
2893   extra_args.clear();
2894   mat_sizes = {3, 4, 5, 6};
2895   extra_args["mat_size"] = at::IValue(at::IntArrayRef(mat_sizes));
2896   flops =
2897       torch::profiler::impl::computeFlops(std::string("aten::mul"), extra_args);
2898   ASSERT_EQ(flops, 360);
2899 }
2900 
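// Tensors that require grad cannot be lifted into graph constants, so
// tryInsertConstant is expected to return nullopt.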
2901 TEST(TestConstant, TensorGrad) {
2902   auto graph = std::make_shared<Graph>();
2903   IValue ten = torch::randn({3, 5}).requires_grad_(true);
2904   auto con = tryInsertConstant(*graph, ten);
2905   ASSERT_TRUE(con == std::nullopt);
2906 }
2907 
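// RemoveTensorMutation rewrites in-place ops such as aten::add_ into their
// functional forms, but only for nodes accepted by the filter callback.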
2908 TEST(TestMutation, Basic) {
2909   auto graph = std::make_shared<Graph>();
2910   std::unordered_map<std::string, Value*> vmap;
2911   parseIR(
2912       R"IR(
2913 graph(%x.1 : Tensor):
2914   %2 : int = prim::Constant[value=1]()
2915   %9 : int = prim::Constant[value=4]()
2916   %x.3 : Tensor = aten::add(%x.1, %2, %2)
2917   %7 : Tensor = aten::add_(%x.3, %2, %2)
2918   %y.1 : Tensor = aten::add(%x.3, %9, %2)
2919   return (%y.1))IR",
2920       &*graph,
2921       vmap);
2922   RemoveTensorMutation(graph, [](Node*) { return false; });
2923   testing::FileCheck().check("aten::add_")->run(*graph);
2924   RemoveTensorMutation(graph, [](Node*) { return true; });
2925   testing::FileCheck().check_not("aten::add_")->run(*graph);
2926 }
2927 
2928 TEST(TestInplaceToFunctionalActivation, Basic) {
2929   auto graph = std::make_shared<Graph>();
2930   std::unordered_map<std::string, Value*> vmap;
2931   parseIR(
2932       R"IR(
2933 graph(%x.1 : Tensor):
2934   %2 : int = prim::Constant[value=1]()
2935   %x.3 : Tensor = aten::add(%x.1, %2, %2)
2936   %y : Tensor = aten::relu_(%x.3)
2937   return (%y))IR",
2938       &*graph,
2939       vmap);
2940   InplaceToFunctionalActivation(graph);
2941   testing::FileCheck().check("aten::relu")->run(*graph);
2942   testing::FileCheck().check_not("aten::relu_")->run(*graph);
2943 }
2944 
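// Registers a custom shape-compute graph for prim::MakeTestTensor's schema
// so PropagateShapesOnGraph can infer its output shape ([5, 5]).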
2945 TEST(TestRegisterShapeOp, Basic) {
2946   auto graph = std::make_shared<Graph>();
2947   std::unordered_map<std::string, Value*> vmap;
2948   parseIR(
2949       R"IR(
2950 graph():
2951   %2 : int = prim::Constant[value=5]()
2952   %3: int[] = prim::ListConstruct(%2, %2)
2953   return (%3))IR",
2954       &*graph,
2955       vmap);
2956 
2957   auto g2 = std::make_shared<Graph>();
2958   parseIR(
2959       R"IR(
2960 graph():
2961   %2 : Tensor = prim::MakeTestTensor()
2962   return (%2))IR",
2963       &*g2,
2964       vmap);
2965 
2966   const FunctionSchema& schema = g2->nodes().begin()->schema();
2967   torch::jit::RegisterShapeComputeGraphForSchema(schema, graph);
2968   PropagateShapesOnGraph(g2);
2969   testing::FileCheck().check("5, 5")->run(*g2);
2970 }
2971 
2972 TEST(TestFunctionalToInplaceActivation, Basic) {
2973   auto graph = std::make_shared<Graph>();
2974   std::unordered_map<std::string, Value*> vmap;
2975   parseIR(
2976       R"IR(
2977 graph(%x.1 : Tensor):
2978   %2 : int = prim::Constant[value=1]()
2979   %x.3 : Tensor = aten::add(%x.1, %2, %2)
2980   %y : Tensor = aten::relu(%x.3)
2981   return (%y))IR",
2982       &*graph,
2983       vmap);
2984   FunctionalToInplaceActivation(graph);
2985   testing::FileCheck().check("aten::relu_")->run(*graph);
2986   testing::FileCheck().check_not("aten::relu(")->run(*graph);
2987 }
2988 
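// The PROFILING executor instruments the graph with prim::profile nodes,
// while the SIMPLE executor runs the graph unmodified.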
2989 TEST(TestFunctionExecutor, SimpleExecutorTest) {
2990   auto graph = std::make_shared<Graph>();
2991   parseIR(
2992       R"IR(
2993 graph(%x.1 : Tensor):
2994   %2 : int = prim::Constant[value=1]()
2995   %x.3 : Tensor = aten::add(%x.1, %2, %2)
2996   %y : Tensor = aten::relu(%x.3)
2997   return (%y))IR",
2998       &*graph);
2999   {
3000     auto func = std::make_unique<GraphFunction>(
3001         "name", graph, [](GraphFunction&) {}, ExecutorExecutionMode::PROFILING);
3002     auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
3003     Stack stack = {a};
3004     func->run(stack);
3005     auto g = lastExecutedOptimizedGraph();
3006     testing::FileCheck()
3007         .check("prim::profile")
3008         ->check("aten::add")
3009         ->check("aten::relu")
3010         ->run(*g);
3011   }
3012   {
3013     auto func = std::make_unique<GraphFunction>(
3014         "name", graph, [](GraphFunction&) {}, ExecutorExecutionMode::SIMPLE);
3015     auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
3016     Stack stack = {a};
3017     func->run(stack);
3018     auto g = func->getDebugState().graph;
3019     testing::FileCheck()
3020         .check_not("prim::profile")
3021         ->check("aten::add")
3022         ->check("aten::relu")
3023         ->run(*g);
3024   }
3025 }
3026 
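// Runs the registered decomposition for aten::var and checks the result
// against the eager-mode var for both unbiased settings.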
TEST(TestFunctionExecutor, RunDecompositionTest) {
  static auto* func = torch::jit::GetDecompositionExecutor(
      "aten::var(Tensor self, bool unbiased=True) -> Tensor");
  for (bool unbiased : {true, false}) {
    auto input = at::rand({4, 4});
    Stack stack = {input, unbiased};
    func->run(stack);
    at::Tensor out = pop(stack).toTensor();
    ASSERT_TRUE(at::allclose(out, input.var(unbiased)));
  }
}

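// Every registered shape-compute schema (minus the aten::arange special
// case skipped below) should resolve to a graph and pass the linter.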
TEST(TestShapeGraphLinting, Basic) {
  auto schemas = RegisteredShapeComputeSchemas();
  for (const auto& schema : schemas) {
    // arange does not actually support complex; leave it as
    // union[int, float] for now
    if (schema->name() == "aten::arange") {
      continue;
    }
    auto g = shapeComputeGraphForSchema(*schema);
    TORCH_INTERNAL_ASSERT(g);
    LintShapeComputeGraph(schema, *g);
  }
}

// TODO: move to test_kernel once the global fusion-parameter settings are
// explicit
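// Fixture that clears the "TE must use LLVM on CPU" flag before each test.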
class Composed : public ::testing::Test {
 public:
  void SetUp() override {
    torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = false;
  }
};

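// Fuses a chain of aten::mul ops into a single composed TE op with dynamic
// shapes, then checks both the fallback path (inputs whose sizes do not
// match the profiled types) and the path that writes into preallocated
// output buffers.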
TEST_F(Composed, ComposedOp) {
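  // RAII guard: enables the CPU fuser for the duration of the test and
  // restores the previous setting on destruction.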
  struct WithCPUFuser {
    WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
      overrideCanFuseOnCPU(val);
    }

    ~WithCPUFuser() {
      overrideCanFuseOnCPU(cpuFuserEnabled);
    }

    bool cpuFuserEnabled;
  };

#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[1, 5], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1)
        %3 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %2)
        %4 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %3)
        return (%3, %4))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  // wrong input sizes so we hit the fallback path
  auto a = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({2, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat))
               .transpose(0, 1);
  auto ref1 = a * (a * b);
  auto ref2 = a * ref1;
  WithCPUFuser g(true);
  bool fusable_on_device = torch::jit::tensorexpr::getTEMustUseLLVMOnCPU();
  torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = false;
  FuseTensorExprs(
      graph,
      /*min_group_size*/ 2,
      /*add_composed_op*/ true,
      /*fuse_to_dynamic_shapes*/ true);
  Code code(graph, "");
  InterpreterState interpreter{code};
  std::vector<IValue> stack = {a, b};
  interpreter.run(stack);
  at::Tensor out2 = pop(stack).toTensor();
  at::Tensor out1 = pop(stack).toTensor();
  ASSERT_TRUE(at::allclose(ref1, out1));
  ASSERT_TRUE(at::allclose(ref2, out2));

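  // Re-run with preallocated output buffers prepended to the stack; the
  // composed op should write the results into them as well.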
  auto inp_1 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
  auto inp_2 = at::ones({4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
  stack = {inp_1, inp_2, a, b};
  InterpreterState interpreter2{code};
  interpreter2.run(stack);
  out2 = pop(stack).toTensor();
  out1 = pop(stack).toTensor();
  ASSERT_TRUE(at::allclose(ref1, out1));
  ASSERT_TRUE(at::allclose(ref2, out2));
  // inp_1 is at the bottom of the stack and corresponds to the second
  // output; inp_2 is on top and corresponds to the first output.
  ASSERT_TRUE(at::allclose(inp_1, ref2));
  ASSERT_TRUE(at::allclose(inp_2, ref1));
  torch::jit::tensorexpr::getTEMustUseLLVMOnCPU() = fusable_on_device;
#endif
}

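// Constant propagation should fold the quantized::linear_prepack call on
// constant inputs into a single constant packed-params object.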
TEST(ConstantPropagation, CustomClassesCanBePropagated) {
#ifdef USE_PYTORCH_QNNPACK
  const auto src = R"IR(
    graph():
        %none: NoneType = prim::Constant()
        %dim: int = prim::Constant[value=3]()
        %shape: int[] = prim::ListConstruct(%dim, %dim)
        %weight: Tensor = aten::ones(%shape, %none, %none, %none, %none)
        %scale: float = prim::Constant[value=1.]()
        %zero_point: int = prim::Constant[value=0]()
        %dtype: int = prim::Constant[value=12]()
        %weight_q: Tensor = aten::quantize_per_tensor(%weight, %scale, %zero_point, %dtype)
        %params: __torch__.torch.classes.quantized.LinearPackedParamsBase = quantized::linear_prepack(%weight_q, %none)
        return (%params)
  )IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(src, graph.get(), vmap);

  ConstantPropagation(graph);

  testing::FileCheck().check_not("quantized::linear_prepack")->run(*graph);
#endif
}

} // namespace jit
} // namespace torch