xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_alias_analysis.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <torch/csrc/autograd/generated/variable_factories.h>
4 #include <torch/csrc/jit/frontend/ir_emitter.h>
5 #include <torch/csrc/jit/ir/alias_analysis.h>
6 #include <torch/csrc/jit/ir/irparser.h>
7 #include <torch/csrc/jit/ir/type_hashing.h>
8 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
9 #include <torch/csrc/jit/runtime/custom_operator.h>
10 #include <torch/csrc/jit/runtime/graph_iterator.h>
11 
12 #include <ATen/TensorOperators.h>
13 
14 namespace torch {
15 namespace jit {
16 
aliasAnalysisFromSchema()17 inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
18   return c10::AliasAnalysisKind::FROM_SCHEMA;
19 }
20 
21 // Fixture to set up a graph and make assertions clearer
22 class TopologicalMoveTest : public ::testing::Test {
23  protected:
TopologicalMoveTest()24   TopologicalMoveTest() {
25     createGraph();
26     aliasDb = std::make_unique<AliasDb>(graph);
27   }
28 
29   // Nodes are named after their output.
30   // e.g. "a" is an alias for "the node that outputs the value `a`"
createGraph()31   void createGraph() {
32     graph = std::make_shared<Graph>();
33     createNode("a", {});
34     createNode("b", {"a"});
35     createNode("c", {});
36     createNode("d", {"a", "b"});
37     createNode("e", {"c", "b"});
38     createNode("f", {"e"});
39     createNode("g", {"e"});
40     createNode("h", {"g"});
41     createNode("i", {"g"});
42     createNode("j", {"i"});
43     createNode("k", {"i"});
44     createNode("l", {"a"});
45     createNode("m", {}, {"l"}); // block depends on l
46     createNode("n", {"m"});
47     createNode("o", {"n"});
48     createNode("p", {});
49     createNode("q", {});
50     createNode("r", {"q"});
51     createNode("s", {"q"});
52 
53     graph->lint();
54   }
55 
createNode(const std::string & name,const std::vector<std::string> & inputNames,const std::vector<std::string> & blockInputNames={})56   void createNode(
57       const std::string& name,
58       const std::vector<std::string>& inputNames,
59       const std::vector<std::string>& blockInputNames = {}) {
60     std::vector<Value*> inputs;
61     for (const auto& name_ : inputNames) {
62       // NOLINTNEXTLINE(performance-inefficient-vector-operation)
63       inputs.push_back(nodes.at(name_)->output());
64     }
65     auto node = graph->appendNode(graph->create(prim::AutogradZero, inputs));
66     node->output()->setDebugName(name);
67     nodes[name] = node;
68 
69     if (blockInputNames.size() != 0) {
70       node->addBlock();
71       std::vector<Value*> blockDeps;
72       for (const auto& name_ : blockInputNames) {
73         // NOLINTNEXTLINE(performance-inefficient-vector-operation)
74         blockDeps.push_back(nodes.at(name_)->output());
75       }
76 
77       auto block = node->blocks().at(0);
78       block->appendNode(graph->create(prim::AutogradZero, blockDeps));
79     }
80   }
81 
moveBeforeTopologicallyValid(const std::string & toInsert,const std::string & insertPoint)82   bool moveBeforeTopologicallyValid(
83       const std::string& toInsert,
84       const std::string& insertPoint) {
85     std::function<bool(Node*, Node*)> func =
86         [this](Node* toInsert, Node* insertPoint) {
87           return aliasDb->moveBeforeTopologicallyValid(toInsert, insertPoint);
88         };
89     return moveWithChecks(toInsert, insertPoint, func);
90   }
91 
moveAfterTopologicallyValid(const std::string & toInsert,const std::string & insertPoint)92   bool moveAfterTopologicallyValid(
93       const std::string& toInsert,
94       const std::string& insertPoint) {
95     std::function<bool(Node*, Node*)> func =
96         [this](Node* toInsert, Node* insertPoint) {
97           return aliasDb->moveAfterTopologicallyValid(toInsert, insertPoint);
98         };
99     return moveWithChecks(toInsert, insertPoint, func);
100   }
101 
moveWithChecks(const std::string & toInsert,const std::string & insertPoint,std::function<bool (Node *,Node *)> func)102   bool moveWithChecks(
103       const std::string& toInsert,
104       const std::string& insertPoint,
105       std::function<bool(Node*, Node*)> func) {
106     auto n = nodes.at(toInsert);
107     auto insert = nodes.at(insertPoint);
108     bool isAfter = n->isAfter(insert);
109 
110     std::vector<Node*> originalOrdering;
111     Node* original = isAfter ? n->next() : n->prev();
112 
113     auto curNode = original;
114     while (curNode != n->owningBlock()->return_node()) {
115       originalOrdering.push_back(curNode);
116       if (isAfter) {
117         curNode = curNode->next();
118       } else {
119         curNode = curNode->prev();
120       }
121     }
122 
123     const auto couldMove = func(n, insert);
124     // Check the graph is okay
125     graph->lint();
126 
127     // If this is the picture of nodes
128     // <some nodes> ... toInsert ... <some more nodes> ... insertPoint
129     // ^----------^ check that these nodes haven't moved
130     curNode = original;
131     size_t idx = 0;
132     while (curNode != n->owningBlock()->return_node()) {
133       EXPECT_TRUE(originalOrdering[idx] == curNode);
134       if (isAfter) {
135         curNode = curNode->next();
136       } else {
137         curNode = curNode->prev();
138       }
139       idx++;
140     }
141 
142     return couldMove;
143   }
144 
checkPostCondition(const std::string & toInsert,const std::string & insertPoint,bool after)145   void checkPostCondition(
146       const std::string& toInsert,
147       const std::string& insertPoint,
148       bool after) {
149     if (after) {
150       EXPECT_EQ(nodes.at(toInsert)->prev(), nodes.at(insertPoint));
151     } else {
152       EXPECT_EQ(nodes.at(toInsert)->next(), nodes.at(insertPoint));
153     }
154   }
155 
156   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
157   std::shared_ptr<Graph> graph;
158   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
159   std::unique_ptr<AliasDb> aliasDb;
160   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
161   std::unordered_map<std::string, Node*> nodes;
162 };
163 
TEST_F(TopologicalMoveTest, SplitsDeps) {
  // Both `r` and `s` consume `q`. Moving `q` directly before `s` checks that
  // the moved node's dependents are handled properly when the node has to be
  // split from them.
  const std::string mover = "q";
  const std::string anchor = "s";
  EXPECT_TRUE(moveBeforeTopologicallyValid(mover, anchor));
  checkPostCondition(mover, anchor, /*after=*/false);
}
170 
171 // Move after
TEST_F(TopologicalMoveTest, MoveAfterBackwardSimple) {
  // `c` has no inputs, so it can be moved backward to sit right after `a`.
  const std::string mover = "c";
  const std::string anchor = "a";
  EXPECT_TRUE(moveAfterTopologicallyValid(mover, anchor));
  checkPostCondition(mover, anchor, /*after=*/true);
}
TEST_F(TopologicalMoveTest, MoveAfterBackwardInvalid) {
  // Invalid backward move: `d` consumes `b`, which is created after `a`.
  const bool moved = moveAfterTopologicallyValid("d", "a");
  EXPECT_FALSE(moved);
}
181 
TEST_F(TopologicalMoveTest, MoveAfterNoOp) {
  // `f` is created right after `e`, so nothing actually moves.
  const std::string mover = "f";
  const std::string anchor = "e";
  EXPECT_TRUE(moveAfterTopologicallyValid(mover, anchor));
  checkPostCondition(mover, anchor, /*after=*/true);
}
187 
TEST_F(TopologicalMoveTest, MoveAfterBackwardMultipleDeps) {
  // Backward move of a node with several inputs: `e` consumes `c` and `b`.
  const std::string mover = "e";
  const std::string anchor = "c";
  EXPECT_TRUE(moveAfterTopologicallyValid(mover, anchor));
  checkPostCondition(mover, anchor, /*after=*/true);
}
193 
TEST_F(TopologicalMoveTest, MoveAfterBackwardNonZeroWorkingSet) {
  // Backward move with a non-zero working set: place `k` right after `f`.
  const std::string mover = "k";
  const std::string anchor = "f";
  EXPECT_TRUE(moveAfterTopologicallyValid(mover, anchor));
  checkPostCondition(mover, anchor, /*after=*/true);
}
199 
TEST_F(TopologicalMoveTest, MoveAfterForwardSimple) {
  // Simple forward move: place `c` right after `d`.
  const std::string mover = "c";
  const std::string anchor = "d";
  EXPECT_TRUE(moveAfterTopologicallyValid(mover, anchor));
  checkPostCondition(mover, anchor, /*after=*/true);
}
205 
TEST_F(TopologicalMoveTest, MoveAfterForwardNonZeroWorkingSet) {
  // Forward move with a non-zero working set: place `f` right after `l`.
  const std::string mover = "f";
  const std::string anchor = "l";
  EXPECT_TRUE(moveAfterTopologicallyValid(mover, anchor));
  checkPostCondition(mover, anchor, /*after=*/true);
}
211 
212 // Move before
TEST_F(TopologicalMoveTest, MoveBeforeForwardSimple) {
  // Simple forward move: place `b` right before `d`.
  const std::string mover = "b";
  const std::string anchor = "d";
  EXPECT_TRUE(moveBeforeTopologicallyValid(mover, anchor));
  checkPostCondition(mover, anchor, /*after=*/false);
}
218 
TEST_F(TopologicalMoveTest, MoveBeforeBackwardSimple) {
  // Simple backward move: place `c` right before `a`.
  const std::string mover = "c";
  const std::string anchor = "a";
  EXPECT_TRUE(moveBeforeTopologicallyValid(mover, anchor));
  checkPostCondition(mover, anchor, /*after=*/false);
}
224 
TEST_F(TopologicalMoveTest, MoveBeforeNoOp) {
  // `a` is created right before `b`, so nothing actually moves.
  const std::string mover = "a";
  const std::string anchor = "b";
  EXPECT_TRUE(moveBeforeTopologicallyValid(mover, anchor));
  checkPostCondition(mover, anchor, /*after=*/false);
}
230 
TEST_F(TopologicalMoveTest, MoveBeforeForwardWithDeps) {
  // Forward move with dependencies: place `f` right before `m`.
  const std::string mover = "f";
  const std::string anchor = "m";
  EXPECT_TRUE(moveBeforeTopologicallyValid(mover, anchor));
  checkPostCondition(mover, anchor, /*after=*/false);
}
236 
TEST_F(TopologicalMoveTest, MoveBeforeBackwardWithDeps) {
  // Backward move with dependencies: place `l` right before `f`.
  const std::string mover = "l";
  const std::string anchor = "f";
  EXPECT_TRUE(moveBeforeTopologicallyValid(mover, anchor));
  checkPostCondition(mover, anchor, /*after=*/false);
}
242 
243 // check that dependencies in blocks are recognized
TEST_F(TopologicalMoveTest,DepsDisallowMove)244 TEST_F(TopologicalMoveTest, DepsDisallowMove) {
245   EXPECT_FALSE(moveAfterTopologicallyValid("l", "m"));
246   EXPECT_FALSE(moveBeforeTopologicallyValid("m", "l"));
247   EXPECT_FALSE(moveAfterTopologicallyValid("n", "l"));
248   EXPECT_FALSE(moveBeforeTopologicallyValid("l", "n"));
249 }
250 
251 // Test that moveAfter(n) and moveBefore(n->next()) are not necessarily
252 // equivalent. Here, the dependency ordering is n -> o -> p.  So we can't
253 // move `n` after `o`, but we can move `n` before `p` (which pushes `o` after
254 // `p`)
TEST_F(TopologicalMoveTest,MoveAfterBeforeWithDeps)255 TEST_F(TopologicalMoveTest, MoveAfterBeforeWithDeps) {
256   EXPECT_FALSE(moveAfterTopologicallyValid("n", "o"));
257   EXPECT_TRUE(moveBeforeTopologicallyValid("o", "p"));
258   checkPostCondition("o", "p", false);
259 }
260 
261 namespace {
insertIf(Graph & g,Value * condValue,std::function<std::vector<Value * > ()> trueInst,std::function<std::vector<Value * > ()> falseInst)262 Node* insertIf(
263     Graph& g,
264     Value* condValue,
265     std::function<std::vector<Value*>()> trueInst,
266     std::function<std::vector<Value*>()> falseInst) {
267   auto if_ = g.insertNode(g.create(prim::If, 0));
268   if_->addInput(condValue); // condition value
269   auto trueBlock = if_->addBlock();
270   auto falseBlock = if_->addBlock();
271   {
272     // Mutate in true block
273     WithInsertPoint g(trueBlock);
274     auto outputs = trueInst();
275     for (auto output : outputs) {
276       trueBlock->registerOutput(output);
277     }
278   }
279   {
280     WithInsertPoint g(falseBlock);
281     auto outputs = falseInst();
282     for (auto output : outputs) {
283       falseBlock->registerOutput(output);
284     }
285   }
286 
287   EXPECT_TRUE(trueBlock->outputs().size() == falseBlock->outputs().size());
288   for (auto output : trueBlock->outputs()) {
289     if_->addOutput()->setType(output->type());
290   }
291   return if_;
292 }
293 
294 template <class Exception, class Functor>
expectThrows(Functor && functor,const char * expectMessageContains)295 inline void expectThrows(Functor&& functor, const char* expectMessageContains) {
296   try {
297     std::forward<Functor>(functor)();
298   } catch (const Exception& e) {
299     if (std::string(e.what()).find(expectMessageContains) ==
300         std::string::npos) {
301       AT_ERROR(
302           "Expected error message to contain \"",
303           expectMessageContains,
304           "\" but error message was: ",
305           e.what());
306     }
307     return;
308   }
309   AT_ERROR(
310       "Expected to throw exception containing \"",
311       expectMessageContains,
312       "\" but didn't throw");
313 }
314 
315 } // namespace
316 
TEST(AliasAnalysisTest, AliasingMutationBlocksMoves) {
  auto graph = std::make_shared<Graph>();
  Value* a = graph->addInput();
  Value* b = graph->addInput();

  // Program under test:
  //   addsB = b + b
  //   c = a + b
  //   a += b
  //   d = c + c
  Node* addsBNode = graph->insert(aten::add, {b, b})->node();
  Value* c = graph->insert(aten::add, {a, b});
  Node* cNode = c->node();
  Node* mutationNode = graph->insert(aten::add_, {a, b})->node();
  Node* dNode = graph->insert(aten::add, {c, c})->node();

  graph->lint();

  AliasDb aliasDb(graph);

  // A reader of `a` cannot be moved past the in-place write to `a`.
  EXPECT_FALSE(aliasDb.moveAfterTopologicallyValid(cNode, mutationNode));
  EXPECT_TRUE(aliasDb.moveAfterTopologicallyValid(dNode, cNode));

  // `b` should alias `a` (both are graph inputs), so readers of `b` are
  // blocked by the mutation as well.
  EXPECT_FALSE(aliasDb.moveAfterTopologicallyValid(addsBNode, mutationNode));
  EXPECT_TRUE(aliasDb.moveAfterTopologicallyValid(addsBNode, cNode));

  graph->lint();
}
345 
TEST(AliasAnalysisTest, AliasingMutationBlocksMoves2) {
  auto graph = std::make_shared<Graph>();
  Value* a = graph->addInput();
  Value* b = graph->addInput();

  Value* constant = graph->insertConstant(1);
  Value* fresh = graph->insert(aten::rand, {constant});
  Node* usesBNode = graph->insert(aten::add, {b, fresh})->node();
  Value* aliasesB = graph->insert(aten::select, {a, constant, constant});
  Node* mutationNode = graph->insert(aten::add_, {aliasesB, fresh})->node();
  graph->insert(aten::add, {fresh, aliasesB});
  graph->lint();

  AliasDb aliasDb(graph);

  // Neither the node producing the alias nor the reader of `b` may be moved
  // after the in-place write to the alias.
  Node* aliasNode = aliasesB->node();
  EXPECT_FALSE(aliasDb.moveAfterTopologicallyValid(aliasNode, mutationNode));
  EXPECT_FALSE(aliasDb.moveAfterTopologicallyValid(usesBNode, mutationNode));
}
365 
TEST(AliasAnalysisTest, SideEffectsBlockMoves) {
  // Moves across nodes that have side effects.
  auto graph = std::make_shared<Graph>();
  Value* a = graph->addInput();
  Node* print1 = graph->insertNode(graph->create(prim::Print, {a}, 0));
  // From here on new nodes are inserted in front of `print1`.
  WithInsertPoint guard(print1);
  Node* print2 = graph->insertNode(graph->create(prim::Print, {a, a}, 0));

  // def foo(a):
  //  print2(a, a)
  //  print1(a)
  {
    AliasDb aliasDb(graph);

    // The two prints may not be reordered relative to each other.
    EXPECT_FALSE(aliasDb.moveAfterTopologicallyValid(print2, print1));
    EXPECT_FALSE(aliasDb.moveBeforeTopologicallyValid(print1, print2));

    // "Moving" them to where they already are is fine.
    EXPECT_TRUE(aliasDb.moveBeforeTopologicallyValid(print2, print1));
    EXPECT_TRUE(aliasDb.moveAfterTopologicallyValid(print1, print2));
  }

  graph->insertNode(graph->create(prim::MakeTestTensor, {}, 1));

  // def foo(a):
  //  print2(a, a)
  //  non_side_effectful = makeTestTensor()
  //  print1(a)
  {
    AliasDb aliasDb2(graph);

    // With the makeTestTensor node sitting between the prints, no move of
    // one print relative to the other is allowed.
    EXPECT_FALSE(aliasDb2.moveAfterTopologicallyValid(print2, print1));
    EXPECT_FALSE(aliasDb2.moveBeforeTopologicallyValid(print2, print1));
    EXPECT_FALSE(aliasDb2.moveAfterTopologicallyValid(print1, print2));
    EXPECT_FALSE(aliasDb2.moveBeforeTopologicallyValid(print1, print2));
  }
}
401 
TEST(AliasAnalysisTest, MovingAcrossInnerBlocks) {
  // Moves across nodes that own inner blocks.
  //
  // a = rand(1)
  // b = rand(1)
  // if True:
  //   a.add_(b)
  // c = a + b
  auto graph = std::make_shared<Graph>();
  Value* constant = graph->insertConstant(1);
  Value* a = graph->insert(aten::rand, {constant});
  Value* b = graph->insert(aten::rand, {constant});

  // True branch writes to `a` in place; false branch just forwards `a`.
  const auto mutateA = [&]() -> std::vector<Value*> {
    Value* aMut = graph->insert(aten::add_, {a, b});
    return {aMut};
  };
  const auto forwardA = [&]() -> std::vector<Value*> { return {a}; };
  Node* ifNode = insertIf(*graph, constant, mutateA, forwardA);

  Node* cNode = graph->insert(aten::add, {a, b})->node();

  graph->lint();

  // `c` reads `a`, and the if statement may write to `a`, so `c` must not be
  // hoisted above it.
  AliasDb aliasDb(graph);
  EXPECT_FALSE(aliasDb.moveBeforeTopologicallyValid(cNode, ifNode));
}
433 
TEST(AliasAnalysisTest, NoneHasNoWriters) {
  // The node producing the `Tensor?` constant must not be reported as having
  // writers, even though its unwrapped value is passed to the three-argument
  // `aten::div` call below.
  const std::string ir =
      R"IR(
    graph():
      %opt : Tensor? = prim::Constant()
      %out : Tensor = prim::unchecked_unwrap_optional(%opt)
      %ret.2 : Tensor = aten::div(%out, %out, %out)
      return (%opt, %out, %ret.2)
      )IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);

  AliasDb aliasDb(graph);
  Node* optNode = vmap["opt"]->node();
  EXPECT_FALSE(aliasDb.hasWriters(optNode));
}
451 
TEST(AliasAnalysisTest, SafeToChangeAliasingRelationship) {
  const std::string ir =
      R"IR(
  graph(%x : Tensor):
      %3 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=0]()
      %b : Tensor = aten::add(%x, %2, %3)
      %c : Tensor = aten::add(%x, %2, %3)
      %d : Tensor = aten::add(%x, %2, %3)
      %e : Tensor = aten::add(%x, %2, %3)
      %f : Tensor[] = prim::ListConstruct(%e)
      %14 : (Tensor, Tensor) = prim::TupleConstruct(%b, %c)
      return (%14)
    )IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);

  AliasDb aliasDb(graph);
  Value* x = vmap["x"];
  Value* b = vmap["b"];
  Value* c = vmap["c"];
  Value* d = vmap["d"];
  Value* e = vmap["e"];

  // x, b and c escape the graph's scope, so no aliasing relationship may be
  // introduced between them.
  EXPECT_FALSE(aliasDb.safeToChangeAliasingRelationship(x, b));
  EXPECT_FALSE(aliasDb.safeToChangeAliasingRelationship(b, x));
  EXPECT_FALSE(aliasDb.safeToChangeAliasingRelationship(b, c));
  EXPECT_FALSE(aliasDb.safeToChangeAliasingRelationship(c, b));

  // e is stored in a list and therefore aliases the wildcard set.
  EXPECT_FALSE(aliasDb.safeToChangeAliasingRelationship(e, x));
  EXPECT_FALSE(aliasDb.safeToChangeAliasingRelationship(x, e));

  // d is a temporary without writers, so changing its aliasing relationship
  // is safe.
  EXPECT_TRUE(aliasDb.safeToChangeAliasingRelationship(c, d));
  EXPECT_TRUE(aliasDb.safeToChangeAliasingRelationship(d, c));
}
487 
488 class BatchAndInstanceNormFixture
489     : public ::testing::TestWithParam<std::tuple<std::string, NodeKind, bool>> {
490 };
491 
// With running stats supplied and a constant training flag, the norm node is
// expected to have writers exactly when training is enabled.
TEST_P(BatchAndInstanceNormFixture, BatchAndInstanceNorm) {
  const auto param = GetParam();
  const auto fnName = std::get<0>(param);
  const auto nodeKind = std::get<1>(param);
  const auto isTraining = std::get<2>(param);
  // The IR parser receives the bool constant as 0 / 1.
  const std::string isTrainingStr =
      std::to_string(static_cast<int>(isTraining));

  auto graph = std::make_shared<Graph>();

  parseIR(
      R"IR(
  graph(%input : Tensor, %running_mean : Tensor, %running_var : Tensor):
      %none : NoneType = prim::Constant()
      %training : bool = prim::Constant[value=)IR" +
          isTrainingStr + R"IR(]()
      %momentum : float = prim::Constant[value=1.0]()
      %eps : float = prim::Constant[value=1.0e-9]()
      %cudnn_enabled : bool = prim::Constant[value=0]()
      %res : Tensor = )IR" +
          fnName +
          R"IR((%input, %none, %none, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
      return (%res)
    )IR",
      graph.get());

  graph->lint();

  // Find the first node of the requested kind.
  DepthFirstGraphNodeIterator it(graph);
  Node* n = it.next();
  while (n != nullptr && n->kind() != nodeKind) {
    n = it.next();
  }
  // Fatal on purpose: the node is used below, so continuing with a null
  // pointer would crash the test binary instead of reporting a failure.
  ASSERT_TRUE(n != nullptr);

  AliasDb aliasDb(graph);
  EXPECT_EQ(aliasDb.hasWriters(n), isTraining);
}
531 
// When the training flag is a graph input rather than a constant, the norm
// node is expected to be reported as having writers.
TEST_P(BatchAndInstanceNormFixture, BatchAndInstanceNormTrainingUnknown) {
  const auto param = GetParam();
  const auto fnName = std::get<0>(param);
  const auto nodeKind = std::get<1>(param);

  auto graph = std::make_shared<Graph>();

  parseIR(
      R"IR(
  graph(%input : Tensor, %running_mean : Tensor, %running_var : Tensor, %training : bool):
      %none : NoneType = prim::Constant()
      %momentum : float = prim::Constant[value=1.0]()
      %eps : float = prim::Constant[value=1.0e-9]()
      %cudnn_enabled : bool = prim::Constant[value=0]()
      %res : Tensor = )IR" +
          fnName +
          R"IR((%input, %none, %none, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
      return (%res)
    )IR",
      graph.get());

  graph->lint();

  // Find the first node of the requested kind.
  DepthFirstGraphNodeIterator it(graph);
  Node* n = it.next();
  while (n != nullptr && n->kind() != nodeKind) {
    n = it.next();
  }
  // Fatal on purpose: the node is used below, so continuing with a null
  // pointer would crash the test binary instead of reporting a failure.
  ASSERT_TRUE(n != nullptr);

  AliasDb aliasDb(graph);
  EXPECT_TRUE(aliasDb.hasWriters(n));
}
567 
// With `None` passed for both running stats, the norm node is expected to
// have no writers for either value of the training flag.
TEST_P(BatchAndInstanceNormFixture, BatchNormTrainingWithNoMeanOrVar) {
  const auto param = GetParam();
  const auto fnName = std::get<0>(param);
  const auto nodeKind = std::get<1>(param);
  const auto isTraining = std::get<2>(param);
  // The IR parser receives the bool constant as 0 / 1.
  const std::string isTrainingStr =
      std::to_string(static_cast<int>(isTraining));

  auto graph = std::make_shared<Graph>();

  parseIR(
      R"IR(
  graph(%input : Tensor):
      %none : NoneType = prim::Constant()
      %training : bool = prim::Constant[value=)IR" +
          isTrainingStr + R"IR(]()
      %momentum : float = prim::Constant[value=1.0]()
      %eps : float = prim::Constant[value=1.0e-9]()
      %cudnn_enabled : bool = prim::Constant[value=0]()
      %res : Tensor = )IR" +
          fnName +
          R"IR((%input, %none, %none, %none, %none, %training, %momentum, %eps, %cudnn_enabled)
      return (%res)
    )IR",
      graph.get());

  graph->lint();

  // Find the first node of the requested kind.
  DepthFirstGraphNodeIterator it(graph);
  Node* n = it.next();
  while (n != nullptr && n->kind() != nodeKind) {
    n = it.next();
  }
  // Fatal on purpose: the node is used below, so continuing with a null
  // pointer would crash the test binary instead of reporting a failure.
  ASSERT_TRUE(n != nullptr);

  AliasDb aliasDb(graph);
  EXPECT_FALSE(aliasDb.hasWriters(n));
}
607 
608 INSTANTIATE_TEST_SUITE_P(
609     AliasAnalysisTest,
610     BatchAndInstanceNormFixture,
611     ::testing::Values(
612         std::make_tuple("aten::batch_norm", aten::batch_norm, false),
613         std::make_tuple("aten::instance_norm", aten::instance_norm, false),
614         std::make_tuple("aten::batch_norm", aten::batch_norm, true),
615         std::make_tuple("aten::instance_norm", aten::instance_norm, true)));
616 
TEST(WriteTrackingTest, Basic) {
  // Register a no-op operator whose schema says its output aliases its input.
  RegisterOperators reg({Operator(
      "prim::creates_alias(Tensor(a) x) -> Tensor(a)",
      [](Stack&) {},
      aliasAnalysisFromSchema())});
  const auto creates_alias = Symbol::fromQualString("prim::creates_alias");

  auto graph = std::make_shared<Graph>();
  Value* a = graph->addInput();
  Value* b = graph->addInput();

  // aten::add(%b, %b)
  // aten::add_(%a, %b)
  // prim::creates_alias(%a)
  Node* pureNode = graph->insert(aten::add, {b, b})->node();
  Node* writingNode = graph->insert(aten::add_, {a, b})->node();
  Value* aAlias = graph->insert(creates_alias, {a})->node()->output();

  graph->lint();

  using ValueSet = std::unordered_set<const Value*>;
  AliasDb aliasDb(graph);

  EXPECT_TRUE(aliasDb.mayAlias(aAlias, a));
  EXPECT_TRUE(aliasDb.mayAlias(a, b));

  // The pure add writes to nothing.
  EXPECT_FALSE(aliasDb.writesToAlias(pureNode, ValueSet{a}));
  EXPECT_FALSE(aliasDb.writesToAlias(pureNode, ValueSet{b}));

  // The in-place add writes to `a` and to anything that aliases it.
  EXPECT_TRUE(aliasDb.writesToAlias(writingNode, ValueSet{a}));
  EXPECT_TRUE(aliasDb.writesToAlias(writingNode, ValueSet{a, b}));
  EXPECT_TRUE(aliasDb.writesToAlias(writingNode, ValueSet{aAlias}));
}
651 
TEST(WriteTrackingTest, IsMutable) {
  // An in-place op (`aten::relu_`) must be reported as mutable.
  const std::string ir =
      R"IR(
  graph(%x: Tensor):
    %b : Tensor = aten::relu_(%x)
    return (%b)
    )IR";
  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get());

  // The relu_ call is the first (and only) node of the top-level block.
  Node* relu = *graph->block()->nodes().begin();
  AliasDb aliasDb(graph);
  EXPECT_TRUE(aliasDb.isMutable(relu));
}
666 
TEST(WriteTrackingTest, IsImmutable) {
  // An out-of-place op (`aten::mul`) must not be reported as mutable.
  const std::string ir =
      R"IR(
  graph(%x: Tensor, %y : Tensor):
    %b : Tensor = aten::mul(%x, %y)
    return (%b)
    )IR";
  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get());

  // The mul call is the first (and only) node of the top-level block.
  Node* mul = *graph->block()->nodes().begin();
  AliasDb aliasDb(graph);
  EXPECT_FALSE(aliasDb.isMutable(mul));
}
681 
TEST(WriteTrackingTest, HasWriters) {
  // The in-place `aten::add_` node must both have writers and be mutable.
  const std::string ir =
      R"IR(
  graph(%x: Tensor, %y : Tensor):
    %c1 : int = prim::Constant[value=1]()
    %b : Tensor = aten::add_(%x, %y, %c1)
    return (%b)
    )IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);

  Node* add = vmap["b"]->node();
  AliasDb aliasDb(graph);
  EXPECT_TRUE(aliasDb.hasWriters(add));
  EXPECT_TRUE(aliasDb.isMutable(add));
}
699 
TEST(ContainerAliasingTest, MayContainAlias) {
  // `%y` is placed into a tuple, a dict and a list that are all returned;
  // `%z` is never stored anywhere.
  const std::string ir =
      R"IR(
  graph(%inp: Tensor[]):
    %x : str = prim::Constant[value="a"]()
    %y : Tensor = prim::Constant()
    %z : Tensor = prim::Constant()
    %a : (Tensor) = prim::TupleConstruct(%y)
    %b : Dict(str, Tensor) = prim::DictConstruct(%x, %y)
    %c : Tensor[] = prim::ListConstruct(%y)
    return (%a, %b, %c)
    )IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);

  Value* str_output = vmap["x"];
  Value* ten_output = vmap["y"];
  Value* local_var = vmap["z"];
  AliasDb aliasDb(graph);

  // Per-output checks: every returned container may hold `%y`, none `%z`.
  EXPECT_TRUE(graph->outputs().size() == 3);
  for (Value* out : graph->outputs()) {
    EXPECT_TRUE(aliasDb.mayContainAlias(ten_output, out));
    EXPECT_FALSE(aliasDb.mayContainAlias(local_var, out));
  }

  // Against the graph inputs as a whole.
  EXPECT_TRUE(aliasDb.mayContainAlias(ten_output, graph->inputs()));
  EXPECT_FALSE(aliasDb.mayContainAlias(local_var, graph->inputs()));

  // Against the graph outputs as a whole, via both overloads.
  EXPECT_TRUE(aliasDb.mayContainAlias(ten_output, graph->outputs()));
  EXPECT_TRUE(aliasDb.mayContainAlias(
      at::ArrayRef<Value*>{ten_output}, graph->outputs()));
  EXPECT_FALSE(aliasDb.mayContainAlias(str_output, graph->outputs()));
}
736 
TEST(ContainerAliasingTest, MayContainAlias_cast) {
  // a = input + input; b = a.to(...); c = b * b; only c is returned.
  const std::string ir =
      R"IR(
  graph(%input.1 : Tensor):
    %2 : NoneType = prim::Constant()
    %3 : bool = prim::Constant[value=0]()
    %4 : int = prim::Constant[value=6]()
    %5 : int = prim::Constant[value=1]()
    %a.1 : Tensor = aten::add(%input.1, %input.1, %5)
    %b.1 : Tensor = aten::to(%a.1, %4, %3, %3, %2)
    %c.1 : Tensor = aten::mul(%b.1, %b.1)
    return (%c.1)
    )IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);

  Value* a = vmap["a.1"];
  Value* b = vmap["b.1"];
  Value* c = vmap["c.1"];
  AliasDb aliasDb(graph);

  EXPECT_TRUE(graph->outputs().size() == 1);
  for (Value* out : graph->outputs()) {
    EXPECT_TRUE(aliasDb.mayContainAlias(c, out));
  }

  // The result of `aten::to` may alias its source, but not the graph inputs.
  EXPECT_TRUE(aliasDb.mayContainAlias(a, b));
  EXPECT_FALSE(aliasDb.mayContainAlias(b, graph->inputs()));

  // Only `c` is related to the graph outputs.
  EXPECT_TRUE(aliasDb.mayContainAlias(c, graph->outputs()));
  EXPECT_TRUE(
      aliasDb.mayContainAlias(at::ArrayRef<Value*>{c}, graph->outputs()));
  EXPECT_FALSE(aliasDb.mayContainAlias(b, graph->outputs()));
}
773 
TEST(ContainerAliasingTest, PrimitveValuesDontAliasContainers) {
  const std::string ir =
      R"IR(
  graph():
    %x : str = prim::Constant[value="a"]()
    %y : int = prim::Constant[value=1]()
    %a : (int) = prim::TupleConstruct(%y)
    %b : Dict(str, int) = prim::DictConstruct(%x, %y)
    %c : int[] = prim::ListConstruct(%y)
    return (%a, %b, %c)
    )IR";
  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get());

  // Skip the string constant; the second node is the int constant.
  auto node_iter = graph->block()->nodes().begin();
  ++node_iter;
  Node* int_node = *node_iter;
  Value* int_value = int_node->output();
  AliasDb aliasDb(graph);

  EXPECT_TRUE(graph->outputs().size() == 3);
  // A primitive value never needs to alias the container it was put into.
  for (Value* out : graph->outputs()) {
    EXPECT_FALSE(aliasDb.mayContainAlias(int_value, out));
  }
}
799 
TEST(ContainerAliasingTest, UnionAliasing) {
  // `%c` is a Union of the types of `%a` and `%b`; all three are returned
  // unchanged.
  const std::string ir =
      R"IR(
  graph(%a : Dict(str, Tensor),
        %b : Tensor[],
        %c : Union(Dict(str, Tensor), Tensor[])):
    return (%a, %b, %c)
    )IR";
  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get());

  AliasDb aliasDb(graph);
  const auto outputs = graph->outputs();
  Value* a = outputs.at(0);
  Value* b = outputs.at(1);
  Value* c = outputs.at(2);

  // Direct aliasing: the union may alias either member type (and itself),
  // while the dict and the list do not alias each other.
  EXPECT_TRUE(aliasDb.mayAlias(a, c));
  EXPECT_TRUE(aliasDb.mayAlias(b, c));
  EXPECT_TRUE(aliasDb.mayAlias(c, c));
  EXPECT_FALSE(aliasDb.mayAlias(a, b));

  // Contained elements may alias across all three.
  EXPECT_TRUE(aliasDb.mayContainAlias(a, b));
  EXPECT_TRUE(aliasDb.mayContainAlias(a, c));
  EXPECT_TRUE(aliasDb.mayContainAlias(b, c));
}
824 
TEST(ContainerAliasingTest, InputsCanAliasOutputs) {
  // Input aliasing: `%x` is wrapped in a tuple that is returned.
  const std::string ir =
      R"IR(
  graph(%x: Tensor, %y: Tensor):
    %a : (Tensor) = prim::TupleConstruct(%x)
    return (%a)
    )IR";
  auto graph = std::make_shared<Graph>();
  parseIR(ir, graph.get());

  // The tuple construction is the first node of the top-level block.
  Node* tuple_node = *graph->block()->nodes().begin();
  AliasDb aliasDb(graph);

  // Every graph input (including `%y`) is expected to possibly be contained
  // in the tuple.
  for (Value* input : graph->inputs()) {
    EXPECT_TRUE(aliasDb.mayContainAlias(input, tuple_node->output()));
  }
  EXPECT_TRUE(aliasDb.mayContainAlias(graph->inputs(), graph->outputs()));
}
845 
846 // Test tuple that doesn't come from construct
TEST(ContainerAliasingTest,NestedTupleConstruct)847 TEST(ContainerAliasingTest, NestedTupleConstruct) {
848   auto graph = std::make_shared<Graph>();
849   parseIR(
850       R"IR(
851 graph(%x : int,
852       %y : Tensor,
853       %z : Tensor):
854   %3 : int = prim::Constant[value=1]()
855   %4 : bool = aten::eq(%x, %3)
856   %a : (Tensor) = prim::If(%4)
857     block0():
858       %a.1 : (Tensor) = prim::TupleConstruct(%y)
859       -> (%a.1)
860     block1():
861       %a.2 : (Tensor) = prim::TupleConstruct(%z)
862       -> (%a.2)
863   return (%a)
864  )IR",
865       &*graph);
866 
867   AliasDb aliasDb(graph);
868 
869   for (auto input : graph->inputs()) {
870     if (input->type() == IntType::get()) {
871       continue;
872     }
873 
874     EXPECT_TRUE(aliasDb.mayContainAlias(input, graph->outputs().at(0)));
875   }
876 }
877 
878 // test nested types
TEST(ContainerAliasingTest,NestedTypes)879 TEST(ContainerAliasingTest, NestedTypes) {
880   auto graph = std::make_shared<Graph>();
881   parseIR(
882       R"IR(
883 graph():
884   %a : Tensor = prim::MakeTestTensor()
885   %a_list : Tensor[] = prim::ListConstruct(%a)
886   %b : Tensor = prim::MakeTestTensor()
887   %b_list : Tensor[] = prim::ListConstruct(%b)
888   %13 : (Tensor[], Tensor[]) = prim::TupleConstruct(%a_list, %b_list)
889   return (%13)
890 )IR",
891       &*graph);
892   AliasDb aliasDb(graph);
893   auto g_output = graph->outputs().at(0);
894   auto list_2 = g_output->node()->inputs().at(0);
895   auto list_1 = g_output->node()->inputs().at(1);
896 
897   // TODO FIX assume conservatively for now
898   EXPECT_TRUE(aliasDb.mayContainAlias(list_1, list_2));
899   EXPECT_TRUE(aliasDb.mayContainAlias(list_2, list_1));
900 
901   EXPECT_TRUE(aliasDb.mayContainAlias(list_1, g_output));
902   EXPECT_TRUE(aliasDb.mayContainAlias(list_2, g_output));
903 }
904 
905 // simple example
TEST(ContainerAliasingTest,Simple)906 TEST(ContainerAliasingTest, Simple) {
907   auto graph = std::make_shared<Graph>();
908   parseIR(
909       R"IR(
910 graph():
911   %0 : Tensor = prim::Constant()
912   %1 : Tensor = prim::Constant()
913   %13 : (Tensor) = prim::TupleConstruct(%0)
914   return (%13)
915 )IR",
916       &*graph);
917   AliasDb aliasDb(graph);
918 
919   auto node_iter = graph->block()->nodes().begin();
920   auto first_ten = *node_iter++;
921   auto second_ten = *node_iter++;
922   auto tup_node = *node_iter;
923 
924   EXPECT_TRUE(aliasDb.mayContainAlias(first_ten->output(), tup_node->output()));
925   EXPECT_TRUE(
926       !aliasDb.mayContainAlias(second_ten->output(), tup_node->output()));
927 
928   std::vector<Value*> first_st = {first_ten->output()};
929   std::vector<Value*> second_st = {second_ten->output()};
930   std::vector<Value*> tup_st = {tup_node->output()};
931   EXPECT_TRUE(aliasDb.mayContainAlias(first_st, tup_st));
932   EXPECT_FALSE(aliasDb.mayContainAlias(first_st, second_st));
933   EXPECT_FALSE(aliasDb.mayContainAlias(second_st, tup_st));
934 }
935 
// Two lists built from the same tensor are expected to possibly contain
// aliases of each other; an unrelated string constant is expected not to be
// involved in either direction.
TEST(ContainerAliasingTest, Lists) {
  const std::string ir_text = R"IR(
  graph():
    %x : str = prim::Constant[value="a"]()
    %y : Tensor = prim::Constant()
    %c : Tensor[] = prim::ListConstruct(%y)
    %d : Tensor[] = prim::ListConstruct(%y)
    return (%c, %d)
    )IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir_text, graph.get(), vmap);

  AliasDb aliasDb(graph);
  Value* str_const = vmap["x"];
  Value* first_list = vmap["c"];
  Value* second_list = vmap["d"];

  EXPECT_FALSE(aliasDb.mayContainAlias(str_const, first_list));
  EXPECT_FALSE(aliasDb.mayContainAlias(first_list, str_const));

  EXPECT_TRUE(aliasDb.mayContainAlias(second_list, first_list));
  EXPECT_TRUE(aliasDb.mayContainAlias(first_list, second_list));
}
962 
// List container aliasing: %x and %y are put into a list by ListConstruct and
// %z is appended to it, while %fresh never enters any list.
TEST(ContainerAliasingTest, Lists2) {
  const std::string ir_text = R"IR(
graph():
  %0 : int = prim::Constant[value=2]()
  %1 : int = prim::Constant[value=3]()
  %2 : int[] = prim::ListConstruct(%0, %1)
  %x : Tensor = prim::MakeTestTensor()
  %12 : int[] = prim::ListConstruct(%0, %1)
  %y : Tensor = prim::MakeTestTensor()
  %22 : int[] = prim::ListConstruct(%0, %1)
  %z : Tensor = prim::MakeTestTensor()
  %32 : int[] = prim::ListConstruct(%0, %1)
  %fresh : Tensor = prim::MakeTestTensor()
  %foo : Tensor[] = prim::ListConstruct(%x, %y)
  %43 : Tensor[] = aten::append(%foo, %z)
  return ()
)IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir_text, &*graph, vmap);

  AliasDb aliasDb(graph);
  Value* x = vmap["x"];
  Value* y = vmap["y"];
  Value* z = vmap["z"];

  // Everything that went into the list may alias everything else in it.
  EXPECT_TRUE(aliasDb.mayAlias(x, y));
  EXPECT_TRUE(aliasDb.mayAlias(y, z));
  EXPECT_TRUE(aliasDb.mayAlias(x, z));

  // `fresh` stayed outside the list, so no list member should alias it.
  Value* fresh = vmap["fresh"];
  for (Value* list_member : {x, y, z}) {
    EXPECT_FALSE(aliasDb.mayAlias(list_member, fresh));
  }
}
1002 
// An op registered without an explicit alias-analysis kind receives a list;
// the test expects it to be treated as writing to the tensor held inside that
// list.
TEST(ContainerAliasingTest, Conservative) {
  auto registration = torch::RegisterOperators(
      "custom::conservative",
      [](torch::List<at::Tensor> tensors) { return tensors; });

  const std::string ir_text = R"IR(
graph():
  %0 : int = prim::Constant[value=2]()
  %1 : int = prim::Constant[value=3]()
  %2 : int[] = prim::ListConstruct(%0, %1)
  %11 : Tensor = prim::MakeTestTensor()
  %12 : Tensor[] = prim::ListConstruct(%11)
  %out : Tensor[] = custom::conservative(%12)
  %ret.2 : Tensor = aten::div(%11, %11)
  return ()
)IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir_text, &*graph, vmap);

  AliasDb aliasDb(graph);
  Node* conservative_node = vmap["out"]->node();
  Value* contained_tensor = vmap["11"];
  EXPECT_TRUE(
      aliasDb.writesToAlias(conservative_node, ValueSet{contained_tensor}));
}
1029 
// aten::add_ writes in place to a tensor selected out of the list %l. The
// test expects that uses::list(%l), registered as a pure function, cannot be
// moved in front of that write.
TEST(ContainerAliasingTest, MovesAcrossContainedWrites) {
  auto registration = torch::RegisterOperators().op(
      "uses::list",
      torch::RegisterOperators::options()
          .catchAllKernel([](torch::List<at::Tensor> /*unused*/) {
            return torch::rand({2, 3});
          })
          .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));

  const std::string ir_text = R"IR(
graph():
  %35 : int = prim::Constant[value=1]()
  %0 : int = prim::Constant[value=2]()
  %1 : int = prim::Constant[value=3]()
  %23 : int = prim::Constant[value=0]()
  %2 : int[] = prim::ListConstruct(%0, %1)
  %11 : Tensor = prim::MakeTestTensor()
  %12 : int[] = prim::ListConstruct(%0, %1)
  %21 : Tensor = prim::MakeTestTensor()
  %l : Tensor[] = prim::ListConstruct(%11, %21)
  %24 : Tensor = aten::select(%l, %23)
  %25 : int[] = prim::ListConstruct(%0, %1)
  %34 : Tensor = prim::MakeTestTensor()
  %36 : Tensor = aten::add_(%24, %34, %35)
  %37 : Tensor = uses::list(%l)
  return (%37)
)IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir_text, &*graph, vmap);

  AliasDb aliasDb(graph);
  Node* list_reader = vmap["37"]->node();
  Node* element_writer = vmap["36"]->node();
  EXPECT_FALSE(
      aliasDb.moveBeforeTopologicallyValid(list_reader, element_writer));
}
1068 
// Variant of MovesAcrossContainedWrites in which the in-place write goes
// through one more aten::select applied to the element taken from %l. The
// test still expects that uses::list(%l) cannot be moved before the write.
TEST(ContainerAliasingTest, MovesAcrossContainedWritesNested) {
  auto registration = torch::RegisterOperators().op(
      "uses::list",
      torch::RegisterOperators::options()
          .catchAllKernel([](torch::List<at::Tensor> /*unused*/) {
            return torch::rand({2, 3});
          })
          .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));

  const std::string ir_text = R"IR(
graph():
  %38 : int = prim::Constant[value=1]()
  %0 : int = prim::Constant[value=2]()
  %1 : int = prim::Constant[value=3]()
  %24 : int = prim::Constant[value=0]()
  %2 : int[] = prim::ListConstruct(%0, %1)
  %11 : Tensor = prim::MakeTestTensor()
  %12 : int[] = prim::ListConstruct(%0, %1)
  %21 : Tensor = prim::MakeTestTensor()
  %l : Tensor[] = prim::ListConstruct(%11, %21)
  %25 : Tensor = aten::select(%l, %24)
  %27 : Tensor = aten::select(%25, %24, %24)
  %28 : int[] = prim::ListConstruct(%0, %1)
  %37 : Tensor = prim::MakeTestTensor()
  %39 : Tensor = aten::add_(%27, %37, %38)
  %40 : Tensor = uses::list(%l)
  return (%40)
)IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir_text, &*graph, vmap);

  AliasDb aliasDb(graph);
  Node* list_reader = vmap["40"]->node();
  Node* element_writer = vmap["39"]->node();
  EXPECT_FALSE(
      aliasDb.moveBeforeTopologicallyValid(list_reader, element_writer));
}
1109 
// Wildcard handling with two custom ops registered from schema: one whose
// result is annotated as a wildcard, and one annotated as writing to its
// argument. The AliasDb is rebuilt after each graph mutation because each
// instance is constructed from the graph as it is at that point.
TEST(WildcardsTest, Basic) {
  RegisterOperators reg(
      {Operator(
           "prim::returns_wildcard(Tensor a) -> Tensor(*)",
           [](Stack&) {},
           aliasAnalysisFromSchema()),
       Operator(
           "prim::writes(Tensor(z!) a) -> Tensor(a)",
           [](Stack&) {},
           aliasAnalysisFromSchema())});
  const Symbol wildcard_op = Symbol::fromQualString("prim::returns_wildcard");
  const Symbol write_op = Symbol::fromQualString("prim::writes");

  auto graph = std::make_shared<Graph>();
  Value* const graph_input = graph->addInput();

  Value* const one = graph->insertConstant(1);
  Value* const fresh = graph->insert(aten::rand, {one});
  Value* const fresh2 = graph->insert(aten::rand, {one});
  Value* const wildcard = graph->insert(wildcard_op, {fresh});

  // Stage 1: no writes anywhere in the graph yet.
  {
    graph->lint();
    AliasDb aliasDb(graph);

    EXPECT_FALSE(aliasDb.mayAlias(graph_input, fresh));
    EXPECT_FALSE(aliasDb.mayAlias(wildcard, fresh));
    EXPECT_TRUE(aliasDb.mayAlias(wildcard, graph_input));
    EXPECT_FALSE(aliasDb.mayAlias(ValueSet{wildcard}, ValueSet{}));
    EXPECT_FALSE(aliasDb.hasWriters(wildcard->node()));
  }

  // Stage 2: a write to an unrelated fresh tensor.
  graph->insert(write_op, {fresh2});
  {
    graph->lint();
    AliasDb aliasDb(graph);
    EXPECT_FALSE(aliasDb.hasWriters(wildcard->node()));
  }

  // Stage 3: a write to the wildcard value itself.
  Node* const wildcard_write = graph->insert(write_op, {wildcard})->node();
  {
    graph->lint();
    AliasDb aliasDb(graph);
    EXPECT_FALSE(aliasDb.writesToAlias(wildcard_write, ValueSet{fresh}));
    EXPECT_FALSE(aliasDb.writesToAlias(wildcard_write, ValueSet{fresh2}));
    EXPECT_TRUE(aliasDb.writesToAlias(wildcard_write, ValueSet{graph_input}));
    EXPECT_TRUE(aliasDb.hasWriters(wildcard->node()));
  }
}
1164 
1165 // test that wildcards are correctly divided by type
TEST(WildcardsTest,TypeIsolation)1166 TEST(WildcardsTest, TypeIsolation) {
1167   auto graph = std::make_shared<Graph>();
1168   std::unordered_map<std::string, Value*> vmap;
1169   parseIR(
1170       R"IR(
1171   graph(%ten_list : Tensor[], %int_list : int[], %opt_ten_list : Tensor[]?):
1172     %ten : Tensor = prim::Constant()
1173     %4 : Tensor[] = aten::append(%ten_list, %ten)
1174     %ten_ten_list : Tensor[][] = prim::Constant()
1175     %int_int_list : int[][] = prim::Constant()
1176     return ()
1177     )IR",
1178       &*graph,
1179       vmap);
1180   AliasDb aliasDb(graph);
1181   auto opt_ten_list = vmap["opt_ten_list"];
1182   auto ten_list = vmap["ten_list"];
1183   auto int_list = vmap["int_list"];
1184   EXPECT_FALSE(aliasDb.hasWriters(int_list));
1185   EXPECT_TRUE(aliasDb.hasWriters(opt_ten_list));
1186   EXPECT_TRUE(aliasDb.hasWriters(ten_list));
1187   EXPECT_FALSE(aliasDb.mayContainAlias(int_list, opt_ten_list));
1188   EXPECT_TRUE(aliasDb.mayContainAlias(ten_list, opt_ten_list));
1189   EXPECT_TRUE(aliasDb.mayAlias(ten_list, opt_ten_list));
1190 
1191   auto list_of_tensor_lists = vmap["ten_ten_list"];
1192   EXPECT_TRUE(aliasDb.mayContainAlias(ten_list, list_of_tensor_lists));
1193   EXPECT_TRUE(aliasDb.mayContainAlias(ten_list, vmap["ten"]));
1194 
1195   EXPECT_TRUE(
1196       !aliasDb.mayContainAlias(vmap["int_int_list"], list_of_tensor_lists));
1197 }
1198 
1199 // test invariant container aliasing
1200 // the containers of different type cannot alias each other,
1201 // however they may contain elements which alias each other
TEST(WildcardsTest,InvariantContainerAliasing)1202 TEST(WildcardsTest, InvariantContainerAliasing) {
1203   {
1204     auto graph = std::make_shared<Graph>();
1205     std::unordered_map<std::string, Value*> vmap;
1206     parseIR(
1207         R"IR(
1208   graph(%ten_list : Tensor[], %ten_opt_list : Tensor?[]):
1209     %ten : Tensor = prim::Constant()
1210     %4 : Tensor[] = aten::append(%ten_list, %ten)
1211     return ()
1212     )IR",
1213         &*graph,
1214         vmap);
1215     AliasDb aliasDb(graph);
1216     auto ten_opt_list = vmap["ten_opt_list"];
1217     auto ten_list = vmap["ten_list"];
1218     EXPECT_FALSE(aliasDb.hasWriters(ten_opt_list));
1219     EXPECT_TRUE(aliasDb.hasWriters(ten_list));
1220     EXPECT_TRUE(aliasDb.mayContainAlias(ten_list, ten_opt_list));
1221     EXPECT_FALSE(aliasDb.mayAlias(ten_list, ten_opt_list));
1222   }
1223   {
1224     auto graph = std::make_shared<Graph>();
1225     std::unordered_map<std::string, Value*> vmap;
1226     parseIR(
1227         R"IR(
1228   graph(%float_3D : Float(*, *, *), %float_2D : Float(*, *)):
1229     return ()
1230     )IR",
1231         &*graph,
1232         vmap);
1233     AliasDb aliasDb(graph);
1234     EXPECT_TRUE(aliasDb.mayAlias(vmap["float_3D"], vmap["float_2D"]));
1235   }
1236 
1237   {
1238     auto graph = std::make_shared<Graph>();
1239     std::unordered_map<std::string, Value*> vmap;
1240     parseIR(
1241         R"IR(
1242   graph(%float_3D_list : Float(*, *, *)[], %float_2D_list : Float(*, *)[], %ten: Tensor):
1243     return ()
1244     )IR",
1245         &*graph,
1246         vmap);
1247     AliasDb aliasDb(graph);
1248     EXPECT_TRUE(aliasDb.mayAlias(vmap["float_3D_list"], vmap["float_2D_list"]));
1249     EXPECT_TRUE(aliasDb.mayContainAlias(vmap["float_3D_list"], vmap["ten"]));
1250     EXPECT_TRUE(aliasDb.mayContainAlias(vmap["float_2D_list"], vmap["ten"]));
1251   }
1252 }
1253 
// CONSERVATIVE alias analysis on an op registered by name only (schema is
// inferred from the kernel): input and output are expected to possibly alias.
TEST(AliasRegistrationTest, ConservativeWithInferredSchema) {
  auto registration = torch::RegisterOperators().op(
      "foo::rand1",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));

  auto graph = std::make_shared<Graph>();
  Value* op_input = graph->addInput();
  Value* op_output =
      graph->insert(Symbol::fromQualString("foo::rand1"), {op_input});

  AliasDb aliasDb(graph);
  // Conservative analysis assumes the output may reference the input.
  EXPECT_TRUE(aliasDb.mayAlias(op_input, op_output));
}
1270 
// CONSERVATIVE alias analysis on an op registered with an explicit schema
// that carries no alias annotations: input and output are expected to
// possibly alias.
TEST(AliasRegistrationTest, ConservativeWithSpecifiedSchema) {
  auto registration = torch::RegisterOperators().op(
      "foo::rand2(Tensor arg1) -> Tensor",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));

  auto graph = std::make_shared<Graph>();
  Value* op_input = graph->addInput();
  Value* op_output =
      graph->insert(Symbol::fromQualString("foo::rand2"), {op_input});

  AliasDb aliasDb(graph);
  // Conservative analysis assumes the output may reference the input.
  EXPECT_TRUE(aliasDb.mayAlias(op_input, op_output));
}
1287 
// A schema with alias annotations (distinct sets a / b) combined with
// CONSERVATIVE analysis: registration and graph insertion go through, and the
// error is expected when AliasDb is constructed over the graph.
TEST(AliasRegistrationTest, ConservativeWithAliasingAnnotationsShouldError) {
  auto registration = torch::RegisterOperators().op(
      "foo::rand3(Tensor(a) arg1) -> Tensor(b)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));

  auto graph = std::make_shared<Graph>();
  Value* op_input = graph->addInput();
  graph->insert(Symbol::fromQualString("foo::rand3"), {op_input});

  const char* expected_message =
      "Tried to register operator foo::rand3(Tensor(a) arg1) -> Tensor(b) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA";
  auto build_alias_db = [&graph] { AliasDb aliasDb(graph); };
  expectThrows<c10::Error>(build_alias_db, expected_message);
}
1308 
// Same as ConservativeWithAliasingAnnotationsShouldError, except the schema
// annotates input and output with the same alias set (a). The error is again
// expected when AliasDb is constructed, not at registration.
TEST(AliasRegistrationTest, ConservativeWithAliasingAnnotationsShouldError2) {
  auto registration = torch::RegisterOperators().op(
      "foo::rand4(Tensor(a) arg1) -> Tensor(a)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));

  auto graph = std::make_shared<Graph>();
  Value* op_input = graph->addInput();
  graph->insert(Symbol::fromQualString("foo::rand4"), {op_input});

  const char* expected_message =
      "Tried to register operator foo::rand4(Tensor(a) arg1) -> Tensor(a) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA";
  auto build_alias_db = [&graph] { AliasDb aliasDb(graph); };
  expectThrows<c10::Error>(build_alias_db, expected_message);
}
1328 
// FROM_SCHEMA analysis on an op registered by name only: here the error is
// expected from the registration call itself.
TEST(AliasRegistrationTest, FromSchemaWithInferredSchemaShouldError) {
  const char* expected_message =
      "Tried to register operator foo::rand5(Tensor _0) -> Tensor _0 with AliasAnalysisKind::FROM_SCHEMA, but the schema is inferred";
  auto register_op = [] {
    torch::RegisterOperators().op(
        "foo::rand5",
        torch::RegisterOperators::options()
            .catchAllKernel([](at::Tensor) -> at::Tensor {
              return at::rand({2, 2});
            })
            .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));
  };
  expectThrows<c10::Error>(register_op, expected_message);
}
1342 
// FROM_SCHEMA analysis with an explicit schema that has no alias annotations:
// the op is treated as pure, so input and output are expected not to alias.
TEST(AliasRegistrationTest, FromSchemaInferredPure) {
  auto registration = torch::RegisterOperators().op(
      "foo::rand6(Tensor arg1) -> Tensor",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));

  auto graph = std::make_shared<Graph>();
  Value* op_input = graph->addInput();
  Value* op_output =
      graph->insert(Symbol::fromQualString("foo::rand6"), {op_input});

  AliasDb aliasDb(graph);
  // No alias information in the schema means "pure" (meh!).
  EXPECT_FALSE(aliasDb.mayAlias(op_input, op_output));
}
1360 
// FROM_SCHEMA analysis where the schema puts input and output in the same
// alias set (a): they are expected to possibly alias.
TEST(AliasRegistrationTest, FromSchemaAliased) {
  auto registration = torch::RegisterOperators().op(
      "foo::rand7(Tensor(a) arg1) -> Tensor(a)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor t) -> at::Tensor { return t * 2; })
          .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));

  auto graph = std::make_shared<Graph>();
  Value* op_input = graph->addInput();
  Value* op_output =
      graph->insert(Symbol::fromQualString("foo::rand7"), {op_input});

  AliasDb aliasDb(graph);
  // The schema declares the output as an alias of the input.
  EXPECT_TRUE(aliasDb.mayAlias(op_input, op_output));
}
1376 
// FROM_SCHEMA analysis where the schema puts input and output in different
// alias sets (a vs. b): they are expected not to alias.
TEST(AliasRegistrationTest, FromSchemaPure) {
  auto registration = torch::RegisterOperators().op(
      "foo::rand8(Tensor(a) arg1) -> Tensor(b)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor t) -> at::Tensor { return t * 2; })
          .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));

  auto graph = std::make_shared<Graph>();
  Value* op_input = graph->addInput();
  Value* op_output =
      graph->insert(Symbol::fromQualString("foo::rand8"), {op_input});

  AliasDb aliasDb(graph);
  // The schema does not tie the output to the input.
  EXPECT_FALSE(aliasDb.mayAlias(op_input, op_output));
}
1391 
// PURE_FUNCTION analysis on an op registered by name only: input and output
// are expected not to alias.
TEST(AliasRegistrationTest, PureNoSchema) {
  auto registration = torch::RegisterOperators().op(
      "foo::rand9",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));

  auto graph = std::make_shared<Graph>();
  Value* op_input = graph->addInput();
  Value* op_output =
      graph->insert(Symbol::fromQualString("foo::rand9"), {op_input});

  AliasDb aliasDb(graph);
  // A pure function cannot introduce an alias.
  EXPECT_FALSE(aliasDb.mayAlias(op_input, op_output));
}
1408 
// PURE_FUNCTION analysis on an op registered with an explicit, unannotated
// schema: input and output are expected not to alias.
TEST(AliasRegistrationTest, PureWithSchema) {
  auto registration = torch::RegisterOperators().op(
      "foo::rand10(Tensor arg1) -> Tensor",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));

  auto graph = std::make_shared<Graph>();
  Value* op_input = graph->addInput();
  Value* op_output =
      graph->insert(Symbol::fromQualString("foo::rand10"), {op_input});

  AliasDb aliasDb(graph);
  // A pure function cannot introduce an alias.
  EXPECT_FALSE(aliasDb.mayAlias(op_input, op_output));
}
1425 
// A schema with alias annotations (same set a on input and output) combined
// with PURE_FUNCTION analysis: registration and graph insertion go through,
// and the error is expected when AliasDb is constructed over the graph.
TEST(AliasRegistrationTest, PureWithAnnotationsShouldError) {
  auto registration = torch::RegisterOperators().op(
      "foo::rand11(Tensor(a) arg1) -> Tensor(a)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor t) -> at::Tensor { return t * 2; })
          .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));

  auto graph = std::make_shared<Graph>();
  Value* op_input = graph->addInput();
  graph->insert(Symbol::fromQualString("foo::rand11"), {op_input});

  const char* expected_message =
      "Tried to register operator foo::rand11(Tensor(a) arg1) -> Tensor(a) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA";
  auto build_alias_db = [&graph] { AliasDb aliasDb(graph); };
  expectThrows<c10::Error>(build_alias_db, expected_message);
}
1443 
// A list with a single use in aten::cat, whose element %x is also written in
// place by aten::add_ between the list's construction and its use.
TEST(AliasRegistrationTest, AliasMoveAtenListOp) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(
      R"IR(
  graph():
    %x : Tensor = prim::MakeTestTensor()
    %8 : int = prim::Constant[value=0]()
    %5 : int = prim::Constant[value=1]()
    %4 : int = prim::Constant[value=2]()
    %y : Tensor[] = prim::ListConstruct(%x)
    %6 : Tensor = aten::add_(%x, %4, %5)
    %9 : Tensor = aten::cat(%y, %8)
    return (%9))IR",
      graph.get(),
      vmap);
  AliasDb aliasDb(graph);

  // Because %y has a single use in a single non-aliasing aten op, %x is
  // added to %y's contained elements instead of the wildcard set, so the cat
  // result is expected not to alias %x.
  Value* cat_result = vmap["9"];
  EXPECT_FALSE(aliasDb.mayAlias(vmap["x"], cat_result));

  // The write to the contained element should prevent moving the list
  // construction next to its use.
  Node* list_construct = vmap["y"]->node();
  EXPECT_FALSE(
      aliasDb.moveBeforeTopologicallyValid(list_construct, cat_result->node()));
}
1469 
// A tuple built from two fresh tensors whose only use is as the graph output.
// The two tensors are expected not to alias each other, while the tuple is
// expected to possibly contain each of them.
TEST(
    AliasRegistrationTest,
    AliasMoveForTupleConstructWithSingleUseAsGraphOutput) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  // %z is constructed from two tensors, so its annotation is a two-element
  // tuple type (the previous one-element `(Tensor)` annotation did not match
  // the TupleConstruct operands).
  auto graph_string = R"IR(
  graph():
    %x : Tensor = prim::MakeTestTensor()
    %y : Tensor = prim::MakeTestTensor()
    %z : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
    return (%z))IR";

  torch::jit::parseIR(graph_string, graph.get(), vmap);
  AliasDb aliasDb(graph, /*isFrozen=*/false);

  EXPECT_FALSE(aliasDb.mayAlias(vmap["x"], vmap["y"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["z"], vmap["x"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["z"], vmap["y"]));
}
1489 
// The TupleConstruct is wrapped into a singleton prim::FunctionalGraph
// subgraph before analysis. The subgraph output is expected to possibly
// contain both tensors, and the two tensors are expected to possibly alias.
TEST(AliasRegistrationTest, RecursiveSubgraphTupleContainment) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(
      R"IR(
  graph():
    %x : Tensor = prim::MakeTestTensor()
    %y : Tensor = prim::MakeTestTensor()
    %z : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
    return (%z))IR",
      graph.get(),
      vmap);

  // The subgraph must be created before AliasDb is built over the graph.
  Node* tuple_node = vmap["z"]->node();
  Node* functional_graph = SubgraphUtils::createSingletonSubgraph(
      tuple_node, prim::FunctionalGraph);
  AliasDb aliasDb(graph);

  Value* x = vmap["x"];
  Value* y = vmap["y"];
  EXPECT_TRUE(aliasDb.mayContainAlias(functional_graph->output(), x));
  EXPECT_TRUE(aliasDb.mayContainAlias(functional_graph->output(), y));
  EXPECT_TRUE(aliasDb.mayAlias(x, y));
}
1510 
// Two tuples that are each consumed by prim::TupleIndex. All three element
// tensors are expected to possibly alias one another, and each tuple is
// expected to possibly contain every one of them.
TEST(AliasRegistrationTest, WildcardAliasForTupleConstructWithUses) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  // %a is constructed from two tensors, so its annotation is a two-element
  // tuple type (the previous one-element `(Tensor)` annotation did not match
  // the TupleConstruct operands).
  auto graph_string = R"IR(
  graph():
    %x : Tensor = prim::MakeTestTensor()
    %y : Tensor = prim::MakeTestTensor()
    %z : Tensor = prim::MakeTestTensor()
    %0 : int = prim::Constant[value=0]()
    %a : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
    %b : (Tensor) = prim::TupleConstruct(%z)
    %c : Tensor = prim::TupleIndex(%a, %0)
    %d : Tensor = prim::TupleIndex(%b, %0)
    return (%c, %d))IR";

  torch::jit::parseIR(graph_string, graph.get(), vmap);
  AliasDb aliasDb(graph, /*isFrozen=*/false);

  Value* x = vmap["x"];
  Value* y = vmap["y"];
  Value* z = vmap["z"];

  EXPECT_TRUE(aliasDb.mayAlias(x, y));
  EXPECT_TRUE(aliasDb.mayAlias(x, z));
  EXPECT_TRUE(aliasDb.mayAlias(y, z));

  // Every (tuple, tensor) pair is checked, including tensors that were not
  // operands of that particular tuple.
  for (Value* tuple_value : {vmap["a"], vmap["b"]}) {
    for (Value* tensor_value : {x, y, z}) {
      EXPECT_TRUE(aliasDb.mayContainAlias(tuple_value, tensor_value));
    }
  }
}
1539 
// aten::split called with a list of lengths: the unpacked pieces, and the
// flattened results derived from them, are expected to possibly alias the
// tensor that was split.
TEST(AliasRegistrationTest, ATenSplitIntListAliasCheck) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(
      R"IR(
  graph():
    %x : Tensor = prim::MakeTestTensor()
    %0 : int = prim::Constant[value=0]()
    %1 : int = prim::Constant[value=1]()
    %2 : int = prim::Constant[value=2]()
    %y : Tensor = aten::add(%x, %x, %0)
    %lengths_list : int[] = prim::tolist(%1, %2)
    %a : Tensor[] = aten::split(%y, %lengths_list, %0)
    %b : Tensor, %c : Tensor = prim::ListUnpack(%a)
    %b1 : Tensor = aten::flatten(%b, %0, %1)
    %c1 : Tensor = aten::flatten(%c, %0, %1)
    %d : Tensor = aten::add(%b1, %c1, %0)
    return (%d))IR",
      graph.get(),
      vmap);
  AliasDb aliasDb(graph, /*isFrozen=*/false);

  Value* split_source = vmap["y"];
  for (const char* derived_name : {"b", "c", "b1", "c1"}) {
    EXPECT_TRUE(aliasDb.mayAlias(split_source, vmap[derived_name]));
  }
}
1566 
// aten::split called with a single int size: the unpacked pieces, and the
// flattened results derived from them, are expected to possibly alias the
// tensor that was split.
TEST(AliasRegistrationTest, ATenSplitIntAliasCheck) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(
      R"IR(
  graph():
    %x : Tensor = prim::MakeTestTensor()
    %0 : int = prim::Constant[value=0]()
    %1 : int = prim::Constant[value=1]()
    %2 : int = prim::Constant[value=2]()
    %y : Tensor = aten::add(%x, %x, %0)
    %a : Tensor[] = aten::split(%y, %2, %0)
    %b : Tensor, %c : Tensor = prim::ListUnpack(%a)
    %b1 : Tensor = aten::flatten(%b, %0, %1)
    %c1 : Tensor = aten::flatten(%c, %0, %1)
    %d : Tensor = aten::add(%b1, %c1, %0)
    return (%d))IR",
      graph.get(),
      vmap);
  AliasDb aliasDb(graph, /*isFrozen=*/false);

  Value* split_source = vmap["y"];
  for (const char* derived_name : {"b", "c", "b1", "c1"}) {
    EXPECT_TRUE(aliasDb.mayAlias(split_source, vmap[derived_name]));
  }
}
1592 
// A schema with alias annotations (distinct sets a / b) combined with
// PURE_FUNCTION analysis: registration and graph insertion go through, and
// the error is expected when AliasDb is constructed over the graph.
TEST(AliasRegistrationTest, PureWithAnnotationsShouldError2) {
  auto registration = torch::RegisterOperators().op(
      "foo::rand12(Tensor(a) arg1) -> Tensor(b)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor t) -> at::Tensor { return t * 2; })
          .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));

  auto graph = std::make_shared<Graph>();
  Value* op_input = graph->addInput();
  graph->insert(Symbol::fromQualString("foo::rand12"), {op_input});

  const char* expected_message =
      "Tried to register operator foo::rand12(Tensor(a) arg1) -> Tensor(b) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA";
  auto build_alias_db = [&graph] { AliasDb aliasDb(graph); };
  expectThrows<c10::Error>(build_alias_db, expected_message);
}
1610 
// Parses a small graph containing one aten::bernoulli node and verifies that
// it is the only node reported as nondeterministic.
TEST(IRNonDeterminismTest, Basic) {
  const auto ir_text = R"IR(
  graph():
    %x : Tensor = prim::MakeTestTensor()
    %0 : int = prim::Constant[value=0]()
    %1 : NoneType = prim::Constant()
    %2 : Tensor = aten::bernoulli(%x, %1)
    %3 : Tensor = aten::add(%x, %2, %0)
    return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(ir_text, graph.get());

  for (Node* node : graph->nodes()) {
    const bool is_bernoulli = node->kind() == aten::bernoulli;
    if (!is_bernoulli) {
      // Every other node in the graph must be deterministic.
      ASSERT_FALSE(node->isNondeterministic());
      continue;
    }
    ASSERT_TRUE(node->isNondeterministic());
  }
}
1631 
// aten::dropout is only nondeterministic when its `train` argument is true.
// The graph contains two dropout nodes: the first is called with train=false
// (%0) and the second with train=true (%1). Every other node must be
// deterministic.
TEST(IRNonDeterminismTest, DropoutSpecialCase) {
  auto graph = std::make_shared<Graph>();
  // Each value has a unique name: %2 is the integer alpha passed to aten::add,
  // %3 is the float dropout probability. (Previously both constants were named
  // %3, so the int constant was shadowed and aten::add received the float.)
  auto graph_string = R"IR(
  graph():
    %x : Tensor = prim::MakeTestTensor()
    %0 : bool = prim::Constant[value=0]()
    %1 : bool = prim::Constant[value=1]()
    %2 : int = prim::Constant[value=1]()
    %3 : float = prim::Constant[value=1.0]()
    %4 : Tensor = aten::dropout(%x, %3, %0)
    %5 : Tensor = aten::dropout(%x, %3, %1)
    %6 : Tensor = aten::add(%4, %5, %2)
    return (%6))IR";
  parseIR(graph_string, graph.get());

  // Tracks which dropout node is being visited: false until the first
  // (train=false) dropout has been checked, true afterwards. This relies on
  // graph->nodes() yielding the dropout nodes in the order written above.
  bool train = false;
  for (Node* n : graph->nodes()) {
    if (n->kind() == aten::dropout) {
      if (!train) {
        // First dropout: train=false, so it must be deterministic.
        ASSERT_FALSE(n->isNondeterministic());
        train = true;
      } else {
        // Second dropout: train=true, so it must be nondeterministic.
        ASSERT_TRUE(n->isNondeterministic());
      }
    } else {
      ASSERT_FALSE(n->isNondeterministic());
    }
  }
}
1661 
// Checks that every operator schema in the list below resolves to a registered
// operator carrying the at::Tag::nondeterministic_seeded tag.
TEST(NonDeterminismBackwardsCompatibility, BackwardsCompatibility) {
  static const std::vector<std::string> nondeterministic_ops = {
      "aten::dropout(Tensor input, float p, bool train) -> Tensor",
      "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
      "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
      "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
      "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
      "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
      "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)",
      "aten::normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
      "aten::normal.float_Tensor(float mean, Tensor std, *, Generator? generator) -> Tensor",
      "aten::normal.Tensor_float(Tensor mean, float std, *, Generator? generator) -> Tensor",
      "aten::poisson(Tensor self, Generator? generator) -> Tensor",
      "aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor",
      "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
      "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
      "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"};
  for (const std::string& op : nondeterministic_ops) {
    // Only the operator name and overload name of the parsed schema are used
    // for the dispatcher lookup.
    const c10::FunctionSchema& schema = torch::jit::parseSchema(op);
    const auto& op_handle = c10::Dispatcher::singleton().findOp(
        c10::OperatorName(schema.name(), schema.overload_name()));
    // findOp returns an optional handle; fail the test cleanly instead of
    // dereferencing an empty optional when the operator is not registered.
    ASSERT_TRUE(op_handle.has_value())
        << "operator not found in dispatcher: " << op;
    ASSERT_TRUE(op_handle->hasTag(at::Tag::nondeterministic_seeded))
        << "operator is missing the nondeterministic_seeded tag: " << op;
  }
}
1694 
// HashType must produce different hashes for different scalar types and for
// tuples that differ only in their number of elements.
TEST(TypeHashing, HashTypes) {
  // Scalar types under test.
  const TypePtr int_ty = IntType::get();
  const TypePtr float_ty = FloatType::get();

  // Tuples of ints with two and three elements respectively.
  const TypePtr int_pair_ty = TupleType::create({int_ty, int_ty});
  const TypePtr int_triple_ty = TupleType::create({int_ty, int_ty, int_ty});

  HashType hash_of;
  ASSERT_NE(hash_of(int_ty), hash_of(float_ty));
  ASSERT_NE(hash_of(int_pair_ty), hash_of(int_triple_ty));
}
1706 
1707 } // namespace jit
1708 } // namespace torch
1709