1 #include <gtest/gtest.h>
2
3 #include <torch/csrc/autograd/generated/variable_factories.h>
4 #include <torch/csrc/jit/frontend/ir_emitter.h>
5 #include <torch/csrc/jit/ir/alias_analysis.h>
6 #include <torch/csrc/jit/ir/irparser.h>
7 #include <torch/csrc/jit/ir/type_hashing.h>
8 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
9 #include <torch/csrc/jit/runtime/custom_operator.h>
10 #include <torch/csrc/jit/runtime/graph_iterator.h>
11
12 #include <ATen/TensorOperators.h>
13
14 namespace torch {
15 namespace jit {
16
aliasAnalysisFromSchema()17 inline c10::AliasAnalysisKind aliasAnalysisFromSchema() {
18 return c10::AliasAnalysisKind::FROM_SCHEMA;
19 }
20
21 // Fixture to set up a graph and make assertions clearer
22 class TopologicalMoveTest : public ::testing::Test {
23 protected:
TopologicalMoveTest()24 TopologicalMoveTest() {
25 createGraph();
26 aliasDb = std::make_unique<AliasDb>(graph);
27 }
28
29 // Nodes are named after their output.
30 // e.g. "a" is an alias for "the node that outputs the value `a`"
createGraph()31 void createGraph() {
32 graph = std::make_shared<Graph>();
33 createNode("a", {});
34 createNode("b", {"a"});
35 createNode("c", {});
36 createNode("d", {"a", "b"});
37 createNode("e", {"c", "b"});
38 createNode("f", {"e"});
39 createNode("g", {"e"});
40 createNode("h", {"g"});
41 createNode("i", {"g"});
42 createNode("j", {"i"});
43 createNode("k", {"i"});
44 createNode("l", {"a"});
45 createNode("m", {}, {"l"}); // block depends on l
46 createNode("n", {"m"});
47 createNode("o", {"n"});
48 createNode("p", {});
49 createNode("q", {});
50 createNode("r", {"q"});
51 createNode("s", {"q"});
52
53 graph->lint();
54 }
55
createNode(const std::string & name,const std::vector<std::string> & inputNames,const std::vector<std::string> & blockInputNames={})56 void createNode(
57 const std::string& name,
58 const std::vector<std::string>& inputNames,
59 const std::vector<std::string>& blockInputNames = {}) {
60 std::vector<Value*> inputs;
61 for (const auto& name_ : inputNames) {
62 // NOLINTNEXTLINE(performance-inefficient-vector-operation)
63 inputs.push_back(nodes.at(name_)->output());
64 }
65 auto node = graph->appendNode(graph->create(prim::AutogradZero, inputs));
66 node->output()->setDebugName(name);
67 nodes[name] = node;
68
69 if (blockInputNames.size() != 0) {
70 node->addBlock();
71 std::vector<Value*> blockDeps;
72 for (const auto& name_ : blockInputNames) {
73 // NOLINTNEXTLINE(performance-inefficient-vector-operation)
74 blockDeps.push_back(nodes.at(name_)->output());
75 }
76
77 auto block = node->blocks().at(0);
78 block->appendNode(graph->create(prim::AutogradZero, blockDeps));
79 }
80 }
81
moveBeforeTopologicallyValid(const std::string & toInsert,const std::string & insertPoint)82 bool moveBeforeTopologicallyValid(
83 const std::string& toInsert,
84 const std::string& insertPoint) {
85 std::function<bool(Node*, Node*)> func =
86 [this](Node* toInsert, Node* insertPoint) {
87 return aliasDb->moveBeforeTopologicallyValid(toInsert, insertPoint);
88 };
89 return moveWithChecks(toInsert, insertPoint, func);
90 }
91
moveAfterTopologicallyValid(const std::string & toInsert,const std::string & insertPoint)92 bool moveAfterTopologicallyValid(
93 const std::string& toInsert,
94 const std::string& insertPoint) {
95 std::function<bool(Node*, Node*)> func =
96 [this](Node* toInsert, Node* insertPoint) {
97 return aliasDb->moveAfterTopologicallyValid(toInsert, insertPoint);
98 };
99 return moveWithChecks(toInsert, insertPoint, func);
100 }
101
moveWithChecks(const std::string & toInsert,const std::string & insertPoint,std::function<bool (Node *,Node *)> func)102 bool moveWithChecks(
103 const std::string& toInsert,
104 const std::string& insertPoint,
105 std::function<bool(Node*, Node*)> func) {
106 auto n = nodes.at(toInsert);
107 auto insert = nodes.at(insertPoint);
108 bool isAfter = n->isAfter(insert);
109
110 std::vector<Node*> originalOrdering;
111 Node* original = isAfter ? n->next() : n->prev();
112
113 auto curNode = original;
114 while (curNode != n->owningBlock()->return_node()) {
115 originalOrdering.push_back(curNode);
116 if (isAfter) {
117 curNode = curNode->next();
118 } else {
119 curNode = curNode->prev();
120 }
121 }
122
123 const auto couldMove = func(n, insert);
124 // Check the graph is okay
125 graph->lint();
126
127 // If this is the picture of nodes
128 // <some nodes> ... toInsert ... <some more nodes> ... insertPoint
129 // ^----------^ check that these nodes haven't moved
130 curNode = original;
131 size_t idx = 0;
132 while (curNode != n->owningBlock()->return_node()) {
133 EXPECT_TRUE(originalOrdering[idx] == curNode);
134 if (isAfter) {
135 curNode = curNode->next();
136 } else {
137 curNode = curNode->prev();
138 }
139 idx++;
140 }
141
142 return couldMove;
143 }
144
checkPostCondition(const std::string & toInsert,const std::string & insertPoint,bool after)145 void checkPostCondition(
146 const std::string& toInsert,
147 const std::string& insertPoint,
148 bool after) {
149 if (after) {
150 EXPECT_EQ(nodes.at(toInsert)->prev(), nodes.at(insertPoint));
151 } else {
152 EXPECT_EQ(nodes.at(toInsert)->next(), nodes.at(insertPoint));
153 }
154 }
155
156 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
157 std::shared_ptr<Graph> graph;
158 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
159 std::unique_ptr<AliasDb> aliasDb;
160 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
161 std::unordered_map<std::string, Node*> nodes;
162 };
163
TEST_F(TopologicalMoveTest, SplitsDeps) {
  // `r` consumes `q` and sits between `q` and `s`. Moving `q` immediately
  // before `s` therefore requires splitting `q` from the node that depends on
  // it; check that this split is handled correctly.
  const bool moved = moveBeforeTopologicallyValid("q", "s");
  EXPECT_TRUE(moved);
  checkPostCondition("q", "s", /*after=*/false);
}
170
171 // Move after
TEST_F(TopologicalMoveTest, MoveAfterBackwardSimple) {
  // `c` has no inputs, so it can be pulled back to sit right after `a`.
  const bool moved = moveAfterTopologicallyValid("c", "a");
  EXPECT_TRUE(moved);
  checkPostCondition("c", "a", /*after=*/true);
}
TEST_F(TopologicalMoveTest, MoveAfterBackwardInvalid) {
  // `d` consumes `b`, which comes after `a`; placing `d` right after `a`
  // would put it ahead of its input.
  const bool moved = moveAfterTopologicallyValid("d", "a");
  EXPECT_FALSE(moved);
}
181
TEST_F(TopologicalMoveTest, MoveAfterNoOp) {
  // `f` already directly follows `e`, so nothing needs to move.
  const bool moved = moveAfterTopologicallyValid("f", "e");
  EXPECT_TRUE(moved);
  checkPostCondition("f", "e", /*after=*/true);
}
187
TEST_F(TopologicalMoveTest, MoveAfterBackwardMultipleDeps) {
  // Backward move of `e`, which has two inputs (`c` and `b`).
  const bool moved = moveAfterTopologicallyValid("e", "c");
  EXPECT_TRUE(moved);
  checkPostCondition("e", "c", /*after=*/true);
}
193
TEST_F(TopologicalMoveTest, MoveAfterBackwardNonZeroWorkingSet) {
  // Backward move where `k`'s producers (`i`, and transitively `g`) lie
  // between `f` and `k`, so the move has to account for more than `k`.
  const bool moved = moveAfterTopologicallyValid("k", "f");
  EXPECT_TRUE(moved);
  checkPostCondition("k", "f", /*after=*/true);
}
199
TEST_F(TopologicalMoveTest, MoveAfterForwardSimple) {
  // Push `c` forward past `d`; `c`'s consumer `e` is still after `d`.
  const bool moved = moveAfterTopologicallyValid("c", "d");
  EXPECT_TRUE(moved);
  checkPostCondition("c", "d", /*after=*/true);
}
205
TEST_F(TopologicalMoveTest, MoveAfterForwardNonZeroWorkingSet) {
  // Forward move of `f` across several unrelated nodes up to `l`.
  const bool moved = moveAfterTopologicallyValid("f", "l");
  EXPECT_TRUE(moved);
  checkPostCondition("f", "l", /*after=*/true);
}
211
212 // Move before
TEST_F(TopologicalMoveTest, MoveBeforeForwardSimple) {
  // Push `b` forward to sit directly before its consumer `d`.
  const bool moved = moveBeforeTopologicallyValid("b", "d");
  EXPECT_TRUE(moved);
  checkPostCondition("b", "d", /*after=*/false);
}
218
TEST_F(TopologicalMoveTest, MoveBeforeBackwardSimple) {
  // `c` has no inputs, so it can move all the way to the front, before `a`.
  const bool moved = moveBeforeTopologicallyValid("c", "a");
  EXPECT_TRUE(moved);
  checkPostCondition("c", "a", /*after=*/false);
}
224
TEST_F(TopologicalMoveTest, MoveBeforeNoOp) {
  // `a` already directly precedes `b`, so nothing needs to move.
  const bool moved = moveBeforeTopologicallyValid("a", "b");
  EXPECT_TRUE(moved);
  checkPostCondition("a", "b", /*after=*/false);
}
230
TEST_F(TopologicalMoveTest, MoveBeforeForwardWithDeps) {
  // Forward move of `f` (which depends on `e`) to sit right before `m`.
  const bool moved = moveBeforeTopologicallyValid("f", "m");
  EXPECT_TRUE(moved);
  checkPostCondition("f", "m", /*after=*/false);
}
236
TEST_F(TopologicalMoveTest, MoveBeforeBackwardWithDeps) {
  // Backward move of `l` (which depends on `a`) to sit right before `f`.
  const bool moved = moveBeforeTopologicallyValid("l", "f");
  EXPECT_TRUE(moved);
  checkPostCondition("l", "f", /*after=*/false);
}
242
243 // check that dependencies in blocks are recognized
TEST_F(TopologicalMoveTest, DepsDisallowMove) {
  // `m`'s sub-block consumes `l`, so `l` must stay ahead of `m`.
  const bool lAfterM = moveAfterTopologicallyValid("l", "m");
  const bool mBeforeL = moveBeforeTopologicallyValid("m", "l");
  // `n` consumes `m`, which transitively depends on `l` through its block.
  const bool nAfterL = moveAfterTopologicallyValid("n", "l");
  const bool lBeforeN = moveBeforeTopologicallyValid("l", "n");

  EXPECT_FALSE(lAfterM);
  EXPECT_FALSE(mBeforeL);
  EXPECT_FALSE(nAfterL);
  EXPECT_FALSE(lBeforeN);
}
250
251 // Test that moveAfter(n) and moveBefore(n->next()) are not necessarily
252 // equivalent. Here, the dependency ordering is n -> o -> p. So we can't
253 // move `n` after `o`, but we can move `n` before `p` (which pushes `o` after
254 // `p`)
TEST_F(TopologicalMoveTest,MoveAfterBeforeWithDeps)255 TEST_F(TopologicalMoveTest, MoveAfterBeforeWithDeps) {
256 EXPECT_FALSE(moveAfterTopologicallyValid("n", "o"));
257 EXPECT_TRUE(moveBeforeTopologicallyValid("o", "p"));
258 checkPostCondition("o", "p", false);
259 }
260
261 namespace {
insertIf(Graph & g,Value * condValue,std::function<std::vector<Value * > ()> trueInst,std::function<std::vector<Value * > ()> falseInst)262 Node* insertIf(
263 Graph& g,
264 Value* condValue,
265 std::function<std::vector<Value*>()> trueInst,
266 std::function<std::vector<Value*>()> falseInst) {
267 auto if_ = g.insertNode(g.create(prim::If, 0));
268 if_->addInput(condValue); // condition value
269 auto trueBlock = if_->addBlock();
270 auto falseBlock = if_->addBlock();
271 {
272 // Mutate in true block
273 WithInsertPoint g(trueBlock);
274 auto outputs = trueInst();
275 for (auto output : outputs) {
276 trueBlock->registerOutput(output);
277 }
278 }
279 {
280 WithInsertPoint g(falseBlock);
281 auto outputs = falseInst();
282 for (auto output : outputs) {
283 falseBlock->registerOutput(output);
284 }
285 }
286
287 EXPECT_TRUE(trueBlock->outputs().size() == falseBlock->outputs().size());
288 for (auto output : trueBlock->outputs()) {
289 if_->addOutput()->setType(output->type());
290 }
291 return if_;
292 }
293
294 template <class Exception, class Functor>
expectThrows(Functor && functor,const char * expectMessageContains)295 inline void expectThrows(Functor&& functor, const char* expectMessageContains) {
296 try {
297 std::forward<Functor>(functor)();
298 } catch (const Exception& e) {
299 if (std::string(e.what()).find(expectMessageContains) ==
300 std::string::npos) {
301 AT_ERROR(
302 "Expected error message to contain \"",
303 expectMessageContains,
304 "\" but error message was: ",
305 e.what());
306 }
307 return;
308 }
309 AT_ERROR(
310 "Expected to throw exception containing \"",
311 expectMessageContains,
312 "\" but didn't throw");
313 }
314
315 } // namespace
316
TEST(AliasAnalysisTest, AliasingMutationBlocksMoves) {
  auto graph = std::make_shared<Graph>();
  Value* a = graph->addInput();
  Value* b = graph->addInput();

  // Graph under test:
  //   addsB = b + b
  //   c     = a + b
  //   a    += b
  //   d     = c + c
  Node* addsBNode = graph->insert(aten::add, {b, b})->node();
  Value* c = graph->insert(aten::add, {a, b});
  Node* cNode = c->node();
  Node* aMutNode = graph->insert(aten::add_, {a, b})->node();
  Node* dNode = graph->insert(aten::add, {c, c})->node();

  graph->lint();

  AliasDb aliasDb(graph);

  // `c` reads `a`, and `a += b` writes it: `c` may not be moved past it.
  EXPECT_FALSE(aliasDb.moveAfterTopologicallyValid(cNode, aMutNode));
  EXPECT_TRUE(aliasDb.moveAfterTopologicallyValid(dNode, cNode));

  // Graph inputs are treated as possibly aliasing each other, so reading `b`
  // conflicts with the in-place write to `a` as well.
  EXPECT_FALSE(aliasDb.moveAfterTopologicallyValid(addsBNode, aMutNode));
  EXPECT_TRUE(aliasDb.moveAfterTopologicallyValid(addsBNode, cNode));

  graph->lint();
}
345
TEST(AliasAnalysisTest, AliasingMutationBlocksMoves2) {
  auto graph = std::make_shared<Graph>();
  Value* a = graph->addInput();
  Value* b = graph->addInput();

  Value* one = graph->insertConstant(1);
  Value* fresh = graph->insert(aten::rand, {one});
  Node* usesBNode = graph->insert(aten::add, {b, fresh})->node();
  // A view of input `a`; since `a` and `b` are both graph inputs, this is
  // treated as possibly aliasing `b`.
  Value* aliasesB = graph->insert(aten::select, {a, one, one});
  Node* mutatorNode = graph->insert(aten::add_, {aliasesB, fresh})->node();
  graph->insert(aten::add, {fresh, aliasesB});
  graph->lint();

  AliasDb aliasDb(graph);
  // Neither the view itself nor a reader of `b` may move past the in-place
  // write to the view.
  EXPECT_FALSE(
      aliasDb.moveAfterTopologicallyValid(aliasesB->node(), mutatorNode));
  EXPECT_FALSE(aliasDb.moveAfterTopologicallyValid(usesBNode, mutatorNode));
}
365
TEST(AliasAnalysisTest, SideEffectsBlockMoves) {
  // Test moves across side effectful nodes (prim::Print).
  auto graph = std::make_shared<Graph>();
  Value* a = graph->addInput();
  Node* print1 = graph->insertNode(graph->create(prim::Print, {a}, 0));
  // From here on, nodes are inserted immediately before print1.
  WithInsertPoint guard(print1);
  Node* print2 = graph->insertNode(graph->create(prim::Print, {a, a}, 0));
  AliasDb aliasDb(graph);

  // Current order:
  //   print2(a, a)
  //   print1(a)

  // Swapping the two prints is rejected, whichever node is moved.
  EXPECT_FALSE(aliasDb.moveAfterTopologicallyValid(print2, print1));
  EXPECT_FALSE(aliasDb.moveBeforeTopologicallyValid(print1, print2));

  // "Moving" each print to where it already is succeeds.
  EXPECT_TRUE(aliasDb.moveBeforeTopologicallyValid(print2, print1));
  EXPECT_TRUE(aliasDb.moveAfterTopologicallyValid(print1, print2));

  graph->insertNode(graph->create(prim::MakeTestTensor, {}, 1));
  AliasDb aliasDb2(graph);

  // Current order:
  //   print2(a, a)
  //   makeTestTensor()
  //   print1(a)

  // With another node between the prints, none of these moves is valid.
  EXPECT_FALSE(aliasDb2.moveAfterTopologicallyValid(print2, print1));
  EXPECT_FALSE(aliasDb2.moveBeforeTopologicallyValid(print2, print1));
  EXPECT_FALSE(aliasDb2.moveAfterTopologicallyValid(print1, print2));
  EXPECT_FALSE(aliasDb2.moveBeforeTopologicallyValid(print1, print2));
}
401
TEST(AliasAnalysisTest, MovingAcrossInnerBlocks) {
  // Test moves across inner blocks:
  //
  //   a = rand(1)
  //   b = rand(1)
  //   if True:
  //     a.add_(b)
  //   c = a + b
  auto graph = std::make_shared<Graph>();
  Value* constant = graph->insertConstant(1);
  Value* a = graph->insert(aten::rand, {constant});
  Value* b = graph->insert(aten::rand, {constant});

  auto mutateA = [&]() -> std::vector<Value*> {
    Value* aMut = graph->insert(aten::add_, {a, b});
    return {aMut};
  };
  auto passThroughA = [&]() -> std::vector<Value*> { return {a}; };
  Node* if_ = insertIf(*graph, constant, mutateA, passThroughA);

  Value* c = graph->insert(aten::add, {a, b});

  graph->lint();

  // `c` reads `a`, and the if's true branch may write to `a`, so `c` must
  // not be hoisted above the if.
  AliasDb aliasDb(graph);
  EXPECT_FALSE(aliasDb.moveBeforeTopologicallyValid(c->node(), if_));
}
433
TEST(AliasAnalysisTest, NoneHasNoWriters) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(
      R"IR(
  graph():
    %opt : Tensor? = prim::Constant()
    %out : Tensor = prim::unchecked_unwrap_optional(%opt)
    %ret.2 : Tensor = aten::div(%out, %out, %out)
    return (%opt, %out, %ret.2)
  )IR",
      &*graph,
      vmap);

  AliasDb aliasDb(graph);
  // The None constant must not be considered written to.
  Node* noneNode = vmap["opt"]->node();
  EXPECT_FALSE(aliasDb.hasWriters(noneNode));
}
451
TEST(AliasAnalysisTest, SafeToChangeAliasingRelationship) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(
      R"IR(
  graph(%x : Tensor):
      %3 : int = prim::Constant[value=1]()
      %2 : int = prim::Constant[value=0]()
      %b : Tensor = aten::add(%x, %2, %3)
      %c : Tensor = aten::add(%x, %2, %3)
      %d : Tensor = aten::add(%x, %2, %3)
      %e : Tensor = aten::add(%x, %2, %3)
      %f : Tensor[] = prim::ListConstruct(%e)
      %14 : (Tensor, Tensor) = prim::TupleConstruct(%b, %c)
      return (%14)
  )IR",
      &*graph,
      vmap);

  AliasDb aliasDb(graph);
  auto safeToChange = [&](const char* lhs, const char* rhs) {
    return aliasDb.safeToChangeAliasingRelationship(vmap[lhs], vmap[rhs]);
  };

  // x is a graph input and b, c are returned: all escape the graph, so no
  // new aliasing relationship may be introduced between them.
  EXPECT_FALSE(safeToChange("x", "b"));
  EXPECT_FALSE(safeToChange("b", "x"));
  EXPECT_FALSE(safeToChange("b", "c"));
  EXPECT_FALSE(safeToChange("c", "b"));

  // e was put into a list, which places it in the wildcard set.
  EXPECT_FALSE(safeToChange("e", "x"));
  EXPECT_FALSE(safeToChange("x", "e"));

  // d is an unwritten temporary, so changing its aliasing is fine.
  EXPECT_TRUE(safeToChange("c", "d"));
  EXPECT_TRUE(safeToChange("d", "c"));
}
487
488 class BatchAndInstanceNormFixture
489 : public ::testing::TestWithParam<std::tuple<std::string, NodeKind, bool>> {
490 };
491
// With a constant `training` flag, the norm op should be reported as having
// writers (to running_mean / running_var) exactly when training is true.
TEST_P(BatchAndInstanceNormFixture, BatchAndInstanceNorm) {
  const auto& param = GetParam();
  const auto fnName = std::get<0>(param);
  const auto nodeKind = std::get<1>(param);
  const auto isTraining = std::get<2>(param);
  const std::string isTrainingStr =
      std::to_string(static_cast<int>(isTraining));

  auto graph = std::make_shared<Graph>();

  parseIR(
      R"IR(
  graph(%input : Tensor, %running_mean : Tensor, %running_var : Tensor):
      %none : NoneType = prim::Constant()
      %training : bool = prim::Constant[value=)IR" +
          isTrainingStr + R"IR(]()
      %momentum : float = prim::Constant[value=1.0]()
      %eps : float = prim::Constant[value=1.0e-9]()
      %cudnn_enabled : bool = prim::Constant[value=0]()
      %res : Tensor = )IR" +
          fnName +
          R"IR((%input, %none, %none, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
      return (%res)
  )IR",
      &*graph);

  graph->lint();
  DepthFirstGraphNodeIterator it(graph);

  Node* n = nullptr;
  while ((n = it.next()) != nullptr) {
    if (n->kind() == nodeKind) {
      break;
    }
  }
  // Fatal: the check below dereferences `n`, so stop here if it wasn't found.
  ASSERT_NE(n, nullptr);

  AliasDb aliasDb(graph);
  EXPECT_EQ(aliasDb.hasWriters(n), isTraining);
}
531
// When `training` is a graph input (not a constant), the norm op must be
// conservatively reported as having writers.
TEST_P(BatchAndInstanceNormFixture, BatchAndInstanceNormTrainingUnknown) {
  const auto& param = GetParam();
  const auto fnName = std::get<0>(param);
  const auto nodeKind = std::get<1>(param);

  auto graph = std::make_shared<Graph>();

  parseIR(
      R"IR(
  graph(%input : Tensor, %running_mean : Tensor, %running_var : Tensor, %training : bool):
      %none : NoneType = prim::Constant()
      %momentum : float = prim::Constant[value=1.0]()
      %eps : float = prim::Constant[value=1.0e-9]()
      %cudnn_enabled : bool = prim::Constant[value=0]()
      %res : Tensor = )IR" +
          fnName +
          R"IR((%input, %none, %none, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
      return (%res)
  )IR",
      &*graph);

  graph->lint();
  DepthFirstGraphNodeIterator it(graph);

  Node* n = nullptr;
  while ((n = it.next()) != nullptr) {
    if (n->kind() == nodeKind) {
      break;
    }
  }
  // Fatal: the check below dereferences `n`, so stop here if it wasn't found.
  ASSERT_NE(n, nullptr);

  AliasDb aliasDb(graph);
  EXPECT_TRUE(aliasDb.hasWriters(n));
}
567
// With running_mean / running_var passed as None there is nothing to write,
// so the norm op must report no writers regardless of `training`.
TEST_P(BatchAndInstanceNormFixture, BatchNormTrainingWithNoMeanOrVar) {
  const auto& param = GetParam();
  const auto fnName = std::get<0>(param);
  const auto nodeKind = std::get<1>(param);
  const auto isTraining = std::get<2>(param);
  const std::string isTrainingStr =
      std::to_string(static_cast<int>(isTraining));

  auto graph = std::make_shared<Graph>();

  parseIR(
      R"IR(
  graph(%input : Tensor):
      %none : NoneType = prim::Constant()
      %training : bool = prim::Constant[value=)IR" +
          isTrainingStr + R"IR(]()
      %momentum : float = prim::Constant[value=1.0]()
      %eps : float = prim::Constant[value=1.0e-9]()
      %cudnn_enabled : bool = prim::Constant[value=0]()
      %res : Tensor = )IR" +
          fnName +
          R"IR((%input, %none, %none, %none, %none, %training, %momentum, %eps, %cudnn_enabled)
      return (%res)
  )IR",
      &*graph);

  graph->lint();
  DepthFirstGraphNodeIterator it(graph);

  Node* n = nullptr;
  while ((n = it.next()) != nullptr) {
    if (n->kind() == nodeKind) {
      break;
    }
  }
  // Fatal: the check below dereferences `n`, so stop here if it wasn't found.
  ASSERT_NE(n, nullptr);

  AliasDb aliasDb(graph);
  EXPECT_FALSE(aliasDb.hasWriters(n));
}
607
// Runs every BatchAndInstanceNormFixture test for batch_norm and
// instance_norm, each with training = false and training = true.
// Tuple fields: (op name spliced into the IR text, NodeKind searched for,
// training flag). The order here determines the generated test indices.
INSTANTIATE_TEST_SUITE_P(
    AliasAnalysisTest,
    BatchAndInstanceNormFixture,
    ::testing::Values(
        std::make_tuple("aten::batch_norm", aten::batch_norm, false),
        std::make_tuple("aten::instance_norm", aten::instance_norm, false),
        std::make_tuple("aten::batch_norm", aten::batch_norm, true),
        std::make_tuple("aten::instance_norm", aten::instance_norm, true)));
616
TEST(WriteTrackingTest, Basic) {
  // Register a test-only no-op operator whose schema annotates its output as
  // belonging to the same alias set (`a`) as its input.
  RegisterOperators reg({Operator(
      "prim::creates_alias(Tensor(a) x) -> Tensor(a)",
      [](Stack&) {},
      aliasAnalysisFromSchema())});
  const auto creates_alias = Symbol::fromQualString("prim::creates_alias");
  auto graph = std::make_shared<Graph>();
  auto a = graph->addInput();
  auto b = graph->addInput();

  // aten::add(%b, %b)
  // aten::add_(%a, %b)
  // prim::creates_alias(%a)
  auto pureNode = graph->insert(aten::add, {b, b})->node();
  auto writingNode = graph->insert(aten::add_, {a, b})->node();
  auto node3 = graph->insert(creates_alias, {a})->node();
  auto aAlias = node3->output();

  graph->lint();

  AliasDb aliasDb(graph);
  // The schema annotation makes the output alias `a`; the two graph inputs
  // are expected to may-alias each other.
  EXPECT_TRUE(aliasDb.mayAlias(aAlias, a));
  EXPECT_TRUE(aliasDb.mayAlias(a, b));
  // An out-of-place add writes nothing.
  EXPECT_FALSE(
      aliasDb.writesToAlias(pureNode, std::unordered_set<const Value*>{a}));
  EXPECT_FALSE(
      aliasDb.writesToAlias(pureNode, std::unordered_set<const Value*>{b}));
  // add_ writes `a`, and therefore anything that may alias it.
  EXPECT_TRUE(
      aliasDb.writesToAlias(writingNode, std::unordered_set<const Value*>{a}));
  EXPECT_TRUE(aliasDb.writesToAlias(
      writingNode, std::unordered_set<const Value*>{a, b}));
  EXPECT_TRUE(aliasDb.writesToAlias(
      writingNode, std::unordered_set<const Value*>{aAlias}));
}
651
TEST(WriteTrackingTest, IsMutable) {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
  graph(%x: Tensor):
    %b : Tensor = aten::relu_(%x)
    return (%b)
  )IR",
      &*graph);
  // The only node in the graph is the in-place relu_.
  Node* relu = *graph->block()->nodes().begin();
  AliasDb aliasDb(graph);
  EXPECT_TRUE(aliasDb.isMutable(relu));
}
666
TEST(WriteTrackingTest, IsImmutable) {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
  graph(%x: Tensor, %y : Tensor):
    %b : Tensor = aten::mul(%x, %y)
    return (%b)
  )IR",
      &*graph);
  // The only node in the graph is the out-of-place mul.
  Node* mul = *graph->block()->nodes().begin();
  AliasDb aliasDb(graph);
  EXPECT_FALSE(aliasDb.isMutable(mul));
}
681
TEST(WriteTrackingTest, HasWriters) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(
      R"IR(
  graph(%x: Tensor, %y : Tensor):
    %c1 : int = prim::Constant[value=1]()
    %b : Tensor = aten::add_(%x, %y, %c1)
    return (%b)
  )IR",
      &*graph,
      vmap);
  Node* inPlaceAdd = vmap["b"]->node();
  AliasDb aliasDb(graph);
  // add_ both writes its input and is itself considered mutable.
  EXPECT_TRUE(aliasDb.hasWriters(inPlaceAdd));
  EXPECT_TRUE(aliasDb.isMutable(inPlaceAdd));
}
699
TEST(ContainerAliasingTest, MayContainAlias) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(
      R"IR(
  graph(%inp: Tensor[]):
    %x : str = prim::Constant[value="a"]()
    %y : Tensor = prim::Constant()
    %z : Tensor = prim::Constant()
    %a : (Tensor) = prim::TupleConstruct(%y)
    %b : Dict(str, Tensor) = prim::DictConstruct(%x, %y)
    %c : Tensor[] = prim::ListConstruct(%y)
    return (%a, %b, %c)
  )IR",
      &*graph,
      vmap);

  Value* str_output = vmap["x"];
  Value* ten_output = vmap["y"];
  Value* local_var = vmap["z"];
  AliasDb aliasDb(graph);

  // Every output (tuple, dict, list) holds %y; none of them holds %z.
  EXPECT_EQ(graph->outputs().size(), 3u);
  for (Value* container : graph->outputs()) {
    EXPECT_TRUE(aliasDb.mayContainAlias(ten_output, container));
    EXPECT_FALSE(aliasDb.mayContainAlias(local_var, container));
  }

  EXPECT_TRUE(aliasDb.mayContainAlias(ten_output, graph->inputs()));
  EXPECT_FALSE(aliasDb.mayContainAlias(local_var, graph->inputs()));

  EXPECT_TRUE(aliasDb.mayContainAlias(ten_output, graph->outputs()));
  EXPECT_TRUE(aliasDb.mayContainAlias(
      at::ArrayRef<Value*>{ten_output}, graph->outputs()));
  EXPECT_FALSE(aliasDb.mayContainAlias(str_output, graph->outputs()));
}
736
TEST(ContainerAliasingTest, MayContainAlias_cast) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(
      R"IR(
  graph(%input.1 : Tensor):
    %2 : NoneType = prim::Constant()
    %3 : bool = prim::Constant[value=0]()
    %4 : int = prim::Constant[value=6]()
    %5 : int = prim::Constant[value=1]()
    %a.1 : Tensor = aten::add(%input.1, %input.1, %5)
    %b.1 : Tensor = aten::to(%a.1, %4, %3, %3, %2)
    %c.1 : Tensor = aten::mul(%b.1, %b.1)
    return (%c.1)
  )IR",
      &*graph,
      vmap);

  Value* a = vmap["a.1"];
  Value* b = vmap["b.1"];
  Value* c = vmap["c.1"];
  AliasDb aliasDb(graph);

  EXPECT_EQ(graph->outputs().size(), 1u);
  for (Value* out : graph->outputs()) {
    EXPECT_TRUE(aliasDb.mayContainAlias(c, out));
  }

  // aten::to may return its input, so b may alias a; neither comes from the
  // graph input directly.
  EXPECT_TRUE(aliasDb.mayContainAlias(a, b));
  EXPECT_FALSE(aliasDb.mayContainAlias(b, graph->inputs()));

  EXPECT_TRUE(aliasDb.mayContainAlias(c, graph->outputs()));
  EXPECT_TRUE(
      aliasDb.mayContainAlias(at::ArrayRef<Value*>{c}, graph->outputs()));
  EXPECT_FALSE(aliasDb.mayContainAlias(b, graph->outputs()));
}
773
TEST(ContainerAliasingTest, PrimitveValuesDontAliasContainers) {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
  graph():
    %x : str = prim::Constant[value="a"]()
    %y : int = prim::Constant[value=1]()
    %a : (int) = prim::TupleConstruct(%y)
    %b : Dict(str, int) = prim::DictConstruct(%x, %y)
    %c : int[] = prim::ListConstruct(%y)
    return (%a, %b, %c)
  )IR",
      &*graph);

  // Skip the string constant; the second node is the int constant %y.
  auto nodeIt = graph->block()->nodes().begin();
  ++nodeIt;
  Node* int_node = *nodeIt;
  AliasDb aliasDb(graph);

  EXPECT_EQ(graph->outputs().size(), 3u);
  // Primitive values need not be tracked as contained in any container.
  for (Value* container : graph->outputs()) {
    EXPECT_FALSE(aliasDb.mayContainAlias(int_node->output(), container));
  }
}
799
TEST(ContainerAliasingTest, UnionAliasing) {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
  graph(%a : Dict(str, Tensor),
        %b : Tensor[],
        %c : Union(Dict(str, Tensor), Tensor[])):
    return (%a, %b, %c)
  )IR",
      &*graph);

  AliasDb aliasDb(graph);
  const auto outs = graph->outputs();
  Value* dictValue = outs.at(0);
  Value* listValue = outs.at(1);
  Value* unionValue = outs.at(2);

  // The union may hold either the dict or the list, but the dict and the list
  // themselves are distinct values.
  EXPECT_TRUE(aliasDb.mayAlias(dictValue, unionValue));
  EXPECT_TRUE(aliasDb.mayAlias(listValue, unionValue));
  EXPECT_TRUE(aliasDb.mayAlias(unionValue, unionValue));
  EXPECT_FALSE(aliasDb.mayAlias(dictValue, listValue));
  // ...though their contents may overlap.
  EXPECT_TRUE(aliasDb.mayContainAlias(dictValue, listValue));
  EXPECT_TRUE(aliasDb.mayContainAlias(dictValue, unionValue));
  EXPECT_TRUE(aliasDb.mayContainAlias(listValue, unionValue));
}
824
TEST(ContainerAliasingTest, InputsCanAliasOutputs) {
  // A tuple built from one graph input may contain an alias of any input,
  // since graph inputs may alias each other.
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
  graph(%x: Tensor, %y: Tensor):
    %a : (Tensor) = prim::TupleConstruct(%x)
    return (%a)
  )IR",
      &*graph);

  Node* tuple_node = *graph->block()->nodes().begin();
  AliasDb aliasDb(graph);

  Value* tupleValue = tuple_node->output();
  for (Value* input : graph->inputs()) {
    EXPECT_TRUE(aliasDb.mayContainAlias(input, tupleValue));
  }
  EXPECT_TRUE(aliasDb.mayContainAlias(graph->inputs(), graph->outputs()));
}
845
846 // Test tuple that doesn't come from construct
TEST(ContainerAliasingTest, NestedTupleConstruct) {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
  graph(%x : int,
        %y : Tensor,
        %z : Tensor):
    %3 : int = prim::Constant[value=1]()
    %4 : bool = aten::eq(%x, %3)
    %a : (Tensor) = prim::If(%4)
      block0():
        %a.1 : (Tensor) = prim::TupleConstruct(%y)
        -> (%a.1)
      block1():
        %a.2 : (Tensor) = prim::TupleConstruct(%z)
        -> (%a.2)
    return (%a)
  )IR",
      &*graph);

  AliasDb aliasDb(graph);

  // The returned tuple comes out of a prim::If, not directly from a
  // TupleConstruct; it may contain either tensor input. The int input is
  // skipped.
  Value* result = graph->outputs().at(0);
  for (Value* input : graph->inputs()) {
    if (input->type() == IntType::get()) {
      continue;
    }
    EXPECT_TRUE(aliasDb.mayContainAlias(input, result));
  }
}
877
878 // test nested types
TEST(ContainerAliasingTest, NestedTypes) {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
  graph():
    %a : Tensor = prim::MakeTestTensor()
    %a_list : Tensor[] = prim::ListConstruct(%a)
    %b : Tensor = prim::MakeTestTensor()
    %b_list : Tensor[] = prim::ListConstruct(%b)
    %13 : (Tensor[], Tensor[]) = prim::TupleConstruct(%a_list, %b_list)
    return (%13)
  )IR",
      &*graph);
  AliasDb aliasDb(graph);
  Value* g_output = graph->outputs().at(0);
  const auto tupleInputs = g_output->node()->inputs();
  Value* aList = tupleInputs.at(0);
  Value* bList = tupleInputs.at(1);

  // TODO FIX: the two lists are assumed (conservatively) to share contents.
  EXPECT_TRUE(aliasDb.mayContainAlias(bList, aList));
  EXPECT_TRUE(aliasDb.mayContainAlias(aList, bList));

  EXPECT_TRUE(aliasDb.mayContainAlias(bList, g_output));
  EXPECT_TRUE(aliasDb.mayContainAlias(aList, g_output));
}
904
905 // simple example
TEST(ContainerAliasingTest, Simple) {
  auto graph = std::make_shared<Graph>();
  parseIR(
      R"IR(
  graph():
    %0 : Tensor = prim::Constant()
    %1 : Tensor = prim::Constant()
    %13 : (Tensor) = prim::TupleConstruct(%0)
    return (%13)
  )IR",
      &*graph);
  AliasDb aliasDb(graph);

  auto nodeIt = graph->block()->nodes().begin();
  Node* first_ten = *nodeIt;
  ++nodeIt;
  Node* second_ten = *nodeIt;
  ++nodeIt;
  Node* tup_node = *nodeIt;

  // Only the first tensor was packed into the tuple.
  EXPECT_TRUE(aliasDb.mayContainAlias(first_ten->output(), tup_node->output()));
  EXPECT_FALSE(
      aliasDb.mayContainAlias(second_ten->output(), tup_node->output()));

  // Same checks through the vector-of-values overload.
  const std::vector<Value*> first_st{first_ten->output()};
  const std::vector<Value*> second_st{second_ten->output()};
  const std::vector<Value*> tup_st{tup_node->output()};
  EXPECT_TRUE(aliasDb.mayContainAlias(first_st, tup_st));
  EXPECT_FALSE(aliasDb.mayContainAlias(first_st, second_st));
  EXPECT_FALSE(aliasDb.mayContainAlias(second_st, tup_st));
}
935
// Two tensor lists built from the same tensor may share contents; a string
// constant cannot share contents with a tensor list.
TEST(ContainerAliasingTest, Lists) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> named;
  const auto ir = R"IR(
  graph():
    %x : str = prim::Constant[value="a"]()
    %y : Tensor = prim::Constant()
    %c : Tensor[] = prim::ListConstruct(%y)
    %d : Tensor[] = prim::ListConstruct(%y)
    return (%c, %d)
)IR";
  parseIR(ir, graph.get(), named);

  AliasDb aliasDb(graph);
  Value* str_const = named["x"];
  Value* list_c = named["c"];
  EXPECT_FALSE(aliasDb.mayContainAlias(str_const, list_c));
  EXPECT_FALSE(aliasDb.mayContainAlias(list_c, str_const));

  Value* list_d = named["d"];
  EXPECT_TRUE(aliasDb.mayContainAlias(list_d, list_c));
  EXPECT_TRUE(aliasDb.mayContainAlias(list_c, list_d));
}
962
TEST(ContainerAliasingTest, Lists2) {
  // Tensors that are put into a list (via ListConstruct or aten::append) are
  // expected to lose their distinct identity; a tensor kept out of any list
  // is not.
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> named;
  const std::string ir = R"IR(
  graph():
    %0 : int = prim::Constant[value=2]()
    %1 : int = prim::Constant[value=3]()
    %2 : int[] = prim::ListConstruct(%0, %1)
    %x : Tensor = prim::MakeTestTensor()
    %12 : int[] = prim::ListConstruct(%0, %1)
    %y : Tensor = prim::MakeTestTensor()
    %22 : int[] = prim::ListConstruct(%0, %1)
    %z : Tensor = prim::MakeTestTensor()
    %32 : int[] = prim::ListConstruct(%0, %1)
    %fresh : Tensor = prim::MakeTestTensor()
    %foo : Tensor[] = prim::ListConstruct(%x, %y)
    %43 : Tensor[] = aten::append(%foo, %z)
    return ()
)IR";
  parseIR(ir, &*graph, named);
  AliasDb aliasDb(graph);

  Value* x = named["x"];
  Value* y = named["y"];
  Value* z = named["z"];
  Value* fresh = named["fresh"];

  // x, y (ListConstruct) and z (append) all ended up in %foo.
  EXPECT_TRUE(aliasDb.mayAlias(x, y));
  EXPECT_TRUE(aliasDb.mayAlias(y, z));
  EXPECT_TRUE(aliasDb.mayAlias(x, z));

  // %fresh never entered a list, so none of them may alias it.
  EXPECT_FALSE(aliasDb.mayAlias(x, fresh));
  EXPECT_FALSE(aliasDb.mayAlias(y, fresh));
  EXPECT_FALSE(aliasDb.mayAlias(z, fresh));
}
1002
TEST(ContainerAliasingTest, Conservative) {
  // custom::conservative is registered from a bare lambda (no explicit alias
  // annotations). The test expects the analysis to treat its call as writing
  // to the elements of the list passed in.
  auto registration = torch::RegisterOperators(
      "custom::conservative", [](torch::List<at::Tensor> in) { return in; });

  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> named;
  parseIR(
      R"IR(
  graph():
    %0 : int = prim::Constant[value=2]()
    %1 : int = prim::Constant[value=3]()
    %2 : int[] = prim::ListConstruct(%0, %1)
    %11 : Tensor = prim::MakeTestTensor()
    %12 : Tensor[] = prim::ListConstruct(%11)
    %out : Tensor[] = custom::conservative(%12)
    %ret.2 : Tensor = aten::div(%11, %11)
    return ()
)IR",
      &*graph,
      named);
  AliasDb aliasDb(graph);

  Node* opaque_call = named["out"]->node();
  Value* element = named["11"];
  EXPECT_TRUE(aliasDb.writesToAlias(opaque_call, ValueSet{element}));
}
1029
TEST(ContainerAliasingTest, MovesAcrossContainedWrites) {
  auto registry = torch::RegisterOperators().op(
      "uses::list",
      torch::RegisterOperators::options()
          .catchAllKernel([](torch::List<at::Tensor> in) {
            return torch::rand({2, 3});
          })
          .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));

  // %36 writes in place (aten::add_) to an element selected out of %l. The
  // later read of the whole list (%37 = uses::list(%l)) must therefore not be
  // hoisted above that write.
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> named;
  const auto ir = R"IR(
  graph():
    %35 : int = prim::Constant[value=1]()
    %0 : int = prim::Constant[value=2]()
    %1 : int = prim::Constant[value=3]()
    %23 : int = prim::Constant[value=0]()
    %2 : int[] = prim::ListConstruct(%0, %1)
    %11 : Tensor = prim::MakeTestTensor()
    %12 : int[] = prim::ListConstruct(%0, %1)
    %21 : Tensor = prim::MakeTestTensor()
    %l : Tensor[] = prim::ListConstruct(%11, %21)
    %24 : Tensor = aten::select(%l, %23)
    %25 : int[] = prim::ListConstruct(%0, %1)
    %34 : Tensor = prim::MakeTestTensor()
    %36 : Tensor = aten::add_(%24, %34, %35)
    %37 : Tensor = uses::list(%l)
    return (%37)
)IR";
  parseIR(ir, graph.get(), named);
  AliasDb aliasDb(graph);

  Node* list_reader = named["37"]->node();
  Node* element_write = named["36"]->node();
  EXPECT_FALSE(aliasDb.moveBeforeTopologicallyValid(list_reader, element_write));
}
1068
TEST(ContainerAliasingTest, MovesAcrossContainedWritesNested) {
  // Like MovesAcrossContainedWrites, but the in-place write goes through a
  // second aten::select taken on the tensor that was pulled out of the list.
  auto registry = torch::RegisterOperators().op(
      "uses::list",
      torch::RegisterOperators::options()
          .catchAllKernel([](torch::List<at::Tensor> in) {
            return torch::rand({2, 3});
          })
          .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));

  // %39 writes to a view (%27) of an element (%25) of %l; the later read of
  // the whole list (%40) must not move before it.
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> named;
  const auto ir = R"IR(
  graph():
    %38 : int = prim::Constant[value=1]()
    %0 : int = prim::Constant[value=2]()
    %1 : int = prim::Constant[value=3]()
    %24 : int = prim::Constant[value=0]()
    %2 : int[] = prim::ListConstruct(%0, %1)
    %11 : Tensor = prim::MakeTestTensor()
    %12 : int[] = prim::ListConstruct(%0, %1)
    %21 : Tensor = prim::MakeTestTensor()
    %l : Tensor[] = prim::ListConstruct(%11, %21)
    %25 : Tensor = aten::select(%l, %24)
    %27 : Tensor = aten::select(%25, %24, %24)
    %28 : int[] = prim::ListConstruct(%0, %1)
    %37 : Tensor = prim::MakeTestTensor()
    %39 : Tensor = aten::add_(%27, %37, %38)
    %40 : Tensor = uses::list(%l)
    return (%40)
)IR";
  parseIR(ir, graph.get(), named);
  AliasDb aliasDb(graph);

  Node* list_reader = named["40"]->node();
  Node* nested_write = named["39"]->node();
  EXPECT_FALSE(aliasDb.moveBeforeTopologicallyValid(list_reader, nested_write));
}
1109
TEST(WildcardsTest, Basic) {
  // Two helper ops: one whose output is a wildcard (`Tensor(*)`), and one
  // that writes to its argument and returns it.
  RegisterOperators reg({
      Operator(
          "prim::returns_wildcard(Tensor a) -> Tensor(*)",
          [](Stack&) {},
          aliasAnalysisFromSchema()),
      Operator(
          "prim::writes(Tensor(z!) a) -> Tensor(a)",
          [](Stack&) {},
          aliasAnalysisFromSchema()),
  });
  const Symbol wildcard_op = Symbol::fromQualString("prim::returns_wildcard");
  const Symbol write_op = Symbol::fromQualString("prim::writes");

  auto graph = std::make_shared<Graph>();
  Value* const graph_input = graph->addInput();
  Value* const one = graph->insertConstant(1);
  Value* const rand_a = graph->insert(aten::rand, {one});
  Value* const rand_b = graph->insert(aten::rand, {one});
  Value* const wild = graph->insert(wildcard_op, {rand_a});

  // Stage 1: no writers yet.
  {
    graph->lint();
    AliasDb db(graph);
    EXPECT_FALSE(db.mayAlias(graph_input, rand_a));
    EXPECT_FALSE(db.mayAlias(wild, rand_a));
    EXPECT_TRUE(db.mayAlias(wild, graph_input));
    EXPECT_FALSE(db.mayAlias(ValueSet{wild}, ValueSet{}));
    EXPECT_FALSE(db.hasWriters(wild->node()));
  }

  // Stage 2: a write to an unrelated fresh tensor does not touch the wildcard.
  graph->insert(write_op, {rand_b});
  {
    graph->lint();
    AliasDb db(graph);
    EXPECT_FALSE(db.hasWriters(wild->node()));
  }

  // Stage 3: a write through the wildcard itself.
  Node* const write_to_wild = graph->insert(write_op, {wild})->node();
  {
    graph->lint();
    AliasDb db(graph);
    EXPECT_FALSE(db.writesToAlias(
        write_to_wild, std::unordered_set<const Value*>{rand_a}));
    EXPECT_FALSE(db.writesToAlias(
        write_to_wild, std::unordered_set<const Value*>{rand_b}));
    EXPECT_TRUE(db.writesToAlias(
        write_to_wild, std::unordered_set<const Value*>{graph_input}));
    EXPECT_TRUE(db.hasWriters(wild->node()));
  }
}
1164
1165 // test that wildcards are correctly divided by type
TEST(WildcardsTest,TypeIsolation)1166 TEST(WildcardsTest, TypeIsolation) {
1167 auto graph = std::make_shared<Graph>();
1168 std::unordered_map<std::string, Value*> vmap;
1169 parseIR(
1170 R"IR(
1171 graph(%ten_list : Tensor[], %int_list : int[], %opt_ten_list : Tensor[]?):
1172 %ten : Tensor = prim::Constant()
1173 %4 : Tensor[] = aten::append(%ten_list, %ten)
1174 %ten_ten_list : Tensor[][] = prim::Constant()
1175 %int_int_list : int[][] = prim::Constant()
1176 return ()
1177 )IR",
1178 &*graph,
1179 vmap);
1180 AliasDb aliasDb(graph);
1181 auto opt_ten_list = vmap["opt_ten_list"];
1182 auto ten_list = vmap["ten_list"];
1183 auto int_list = vmap["int_list"];
1184 EXPECT_FALSE(aliasDb.hasWriters(int_list));
1185 EXPECT_TRUE(aliasDb.hasWriters(opt_ten_list));
1186 EXPECT_TRUE(aliasDb.hasWriters(ten_list));
1187 EXPECT_FALSE(aliasDb.mayContainAlias(int_list, opt_ten_list));
1188 EXPECT_TRUE(aliasDb.mayContainAlias(ten_list, opt_ten_list));
1189 EXPECT_TRUE(aliasDb.mayAlias(ten_list, opt_ten_list));
1190
1191 auto list_of_tensor_lists = vmap["ten_ten_list"];
1192 EXPECT_TRUE(aliasDb.mayContainAlias(ten_list, list_of_tensor_lists));
1193 EXPECT_TRUE(aliasDb.mayContainAlias(ten_list, vmap["ten"]));
1194
1195 EXPECT_TRUE(
1196 !aliasDb.mayContainAlias(vmap["int_int_list"], list_of_tensor_lists));
1197 }
1198
1199 // test invariant container aliasing
1200 // the containers of different type cannot alias each other,
1201 // however they may contain elements which alias each other
TEST(WildcardsTest,InvariantContainerAliasing)1202 TEST(WildcardsTest, InvariantContainerAliasing) {
1203 {
1204 auto graph = std::make_shared<Graph>();
1205 std::unordered_map<std::string, Value*> vmap;
1206 parseIR(
1207 R"IR(
1208 graph(%ten_list : Tensor[], %ten_opt_list : Tensor?[]):
1209 %ten : Tensor = prim::Constant()
1210 %4 : Tensor[] = aten::append(%ten_list, %ten)
1211 return ()
1212 )IR",
1213 &*graph,
1214 vmap);
1215 AliasDb aliasDb(graph);
1216 auto ten_opt_list = vmap["ten_opt_list"];
1217 auto ten_list = vmap["ten_list"];
1218 EXPECT_FALSE(aliasDb.hasWriters(ten_opt_list));
1219 EXPECT_TRUE(aliasDb.hasWriters(ten_list));
1220 EXPECT_TRUE(aliasDb.mayContainAlias(ten_list, ten_opt_list));
1221 EXPECT_FALSE(aliasDb.mayAlias(ten_list, ten_opt_list));
1222 }
1223 {
1224 auto graph = std::make_shared<Graph>();
1225 std::unordered_map<std::string, Value*> vmap;
1226 parseIR(
1227 R"IR(
1228 graph(%float_3D : Float(*, *, *), %float_2D : Float(*, *)):
1229 return ()
1230 )IR",
1231 &*graph,
1232 vmap);
1233 AliasDb aliasDb(graph);
1234 EXPECT_TRUE(aliasDb.mayAlias(vmap["float_3D"], vmap["float_2D"]));
1235 }
1236
1237 {
1238 auto graph = std::make_shared<Graph>();
1239 std::unordered_map<std::string, Value*> vmap;
1240 parseIR(
1241 R"IR(
1242 graph(%float_3D_list : Float(*, *, *)[], %float_2D_list : Float(*, *)[], %ten: Tensor):
1243 return ()
1244 )IR",
1245 &*graph,
1246 vmap);
1247 AliasDb aliasDb(graph);
1248 EXPECT_TRUE(aliasDb.mayAlias(vmap["float_3D_list"], vmap["float_2D_list"]));
1249 EXPECT_TRUE(aliasDb.mayContainAlias(vmap["float_3D_list"], vmap["ten"]));
1250 EXPECT_TRUE(aliasDb.mayContainAlias(vmap["float_2D_list"], vmap["ten"]));
1251 }
1252 }
1253
TEST(AliasRegistrationTest, ConservativeWithInferredSchema) {
  // CONSERVATIVE analysis on an op with an inferred schema: the output must
  // be assumed to possibly reference the input.
  auto registry = torch::RegisterOperators().op(
      "foo::rand1",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));

  auto graph = std::make_shared<Graph>();
  Value* in = graph->addInput();
  Value* out = graph->insert(Symbol::fromQualString("foo::rand1"), {in});
  AliasDb aliasDb(graph);
  EXPECT_TRUE(aliasDb.mayAlias(in, out));
}
1270
TEST(AliasRegistrationTest, ConservativeWithSpecifiedSchema) {
  // CONSERVATIVE analysis on an op with an explicit, annotation-free schema:
  // the output is still assumed to possibly reference the input.
  auto registry = torch::RegisterOperators().op(
      "foo::rand2(Tensor arg1) -> Tensor",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));

  auto graph = std::make_shared<Graph>();
  Value* in = graph->addInput();
  Value* out = graph->insert(Symbol::fromQualString("foo::rand2"), {in});
  AliasDb aliasDb(graph);
  EXPECT_TRUE(aliasDb.mayAlias(in, out));
}
1287
TEST(AliasRegistrationTest, ConservativeWithAliasingAnnotationsShouldError) {
  // A schema with alias annotations combined with CONSERVATIVE is accepted
  // at registration, but building an AliasDb over a graph using it must fail.
  auto registry = torch::RegisterOperators().op(
      "foo::rand3(Tensor(a) arg1) -> Tensor(b)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));

  auto graph = std::make_shared<Graph>();
  Value* in = graph->addInput();
  graph->insert(Symbol::fromQualString("foo::rand3"), {in});

  expectThrows<c10::Error>(
      [&] { AliasDb aliasDb(graph); },
      "Tried to register operator foo::rand3(Tensor(a) arg1) -> Tensor(b) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
}
1308
TEST(AliasRegistrationTest, ConservativeWithAliasingAnnotationsShouldError2) {
  // Same as above, but the annotations declare input and output as aliases.
  // Registration succeeds; AliasDb construction must throw.
  auto registry = torch::RegisterOperators().op(
      "foo::rand4(Tensor(a) arg1) -> Tensor(a)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::CONSERVATIVE));

  auto graph = std::make_shared<Graph>();
  Value* in = graph->addInput();
  graph->insert(Symbol::fromQualString("foo::rand4"), {in});

  expectThrows<c10::Error>(
      [&] { AliasDb aliasDb(graph); },
      "Tried to register operator foo::rand4(Tensor(a) arg1) -> Tensor(a) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
}
1328
TEST(AliasRegistrationTest, FromSchemaWithInferredSchemaShouldError) {
  // FROM_SCHEMA needs a hand-written schema to read annotations from, so
  // requesting it for an inferred schema must throw during registration.
  expectThrows<c10::Error>(
      [] {
        auto options = torch::RegisterOperators::options()
                           .catchAllKernel([](at::Tensor) -> at::Tensor {
                             return at::rand({2, 2});
                           })
                           .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA);
        torch::RegisterOperators().op("foo::rand5", std::move(options));
      },
      "Tried to register operator foo::rand5(Tensor _0) -> Tensor _0 with AliasAnalysisKind::FROM_SCHEMA, but the schema is inferred");
}
1342
TEST(AliasRegistrationTest, FromSchemaInferredPure) {
  // FROM_SCHEMA with a schema carrying no alias annotations: the op is
  // treated as pure, so its output cannot alias its input.
  auto registry = torch::RegisterOperators().op(
      "foo::rand6(Tensor arg1) -> Tensor",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));

  auto graph = std::make_shared<Graph>();
  Value* in = graph->addInput();
  Value* out = graph->insert(Symbol::fromQualString("foo::rand6"), {in});
  AliasDb aliasDb(graph);
  EXPECT_FALSE(aliasDb.mayAlias(in, out));
}
1360
TEST(AliasRegistrationTest, FromSchemaAliased) {
  // FROM_SCHEMA with `Tensor(a) -> Tensor(a)`: the annotation ties output to
  // input, so they may alias.
  auto registry = torch::RegisterOperators().op(
      "foo::rand7(Tensor(a) arg1) -> Tensor(a)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor t) -> at::Tensor { return t * 2; })
          .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));

  auto graph = std::make_shared<Graph>();
  Value* in = graph->addInput();
  Value* out = graph->insert(Symbol::fromQualString("foo::rand7"), {in});
  AliasDb aliasDb(graph);
  EXPECT_TRUE(aliasDb.mayAlias(in, out));
}
1376
TEST(AliasRegistrationTest, FromSchemaPure) {
  // FROM_SCHEMA with `Tensor(a) -> Tensor(b)`: different alias sets, so the
  // output cannot alias the input.
  auto registry = torch::RegisterOperators().op(
      "foo::rand8(Tensor(a) arg1) -> Tensor(b)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor t) -> at::Tensor { return t * 2; })
          .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA));

  auto graph = std::make_shared<Graph>();
  Value* in = graph->addInput();
  Value* out = graph->insert(Symbol::fromQualString("foo::rand8"), {in});
  AliasDb aliasDb(graph);
  EXPECT_FALSE(aliasDb.mayAlias(in, out));
}
1391
TEST(AliasRegistrationTest, PureNoSchema) {
  // PURE_FUNCTION with an inferred schema: no aliasing is possible.
  auto registry = torch::RegisterOperators().op(
      "foo::rand9",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));

  auto graph = std::make_shared<Graph>();
  Value* in = graph->addInput();
  Value* out = graph->insert(Symbol::fromQualString("foo::rand9"), {in});
  AliasDb aliasDb(graph);
  EXPECT_FALSE(aliasDb.mayAlias(in, out));
}
1408
TEST(AliasRegistrationTest, PureWithSchema) {
  // PURE_FUNCTION with an explicit, annotation-free schema: no aliasing is
  // possible.
  auto registry = torch::RegisterOperators().op(
      "foo::rand10(Tensor arg1) -> Tensor",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor) -> at::Tensor {
            return at::rand({2, 2});
          })
          .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));

  auto graph = std::make_shared<Graph>();
  Value* in = graph->addInput();
  Value* out = graph->insert(Symbol::fromQualString("foo::rand10"), {in});
  AliasDb aliasDb(graph);
  EXPECT_FALSE(aliasDb.mayAlias(in, out));
}
1425
TEST(AliasRegistrationTest, PureWithAnnotationsShouldError) {
  // Alias annotations combined with PURE_FUNCTION are accepted at
  // registration, but constructing an AliasDb over the graph must throw.
  auto registry = torch::RegisterOperators().op(
      "foo::rand11(Tensor(a) arg1) -> Tensor(a)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor t) -> at::Tensor { return t * 2; })
          .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));

  auto graph = std::make_shared<Graph>();
  Value* in = graph->addInput();
  graph->insert(Symbol::fromQualString("foo::rand11"), {in});

  expectThrows<c10::Error>(
      [&] { AliasDb aliasDb(graph); },
      "Tried to register operator foo::rand11(Tensor(a) arg1) -> Tensor(a) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
}
1443
TEST(AliasRegistrationTest, AliasMoveAtenListOp) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> named;
  const auto ir = R"IR(
  graph():
    %x : Tensor = prim::MakeTestTensor()
    %8 : int = prim::Constant[value=0]()
    %5 : int = prim::Constant[value=1]()
    %4 : int = prim::Constant[value=2]()
    %y : Tensor[] = prim::ListConstruct(%x)
    %6 : Tensor = aten::add_(%x, %4, %5)
    %9 : Tensor = aten::cat(%y, %8)
    return (%9))IR";
  torch::jit::parseIR(ir, graph.get(), named);
  AliasDb aliasDb(graph);

  // %y has a single use, in an aten op (cat) whose output does not alias its
  // input, so %x is recorded as a contained element of %y rather than being
  // added to the wildcard set. Hence %x cannot alias the cat output.
  EXPECT_FALSE(aliasDb.mayAlias(named["x"], named["9"]));

  // The in-place write to %x (a contained element of %y) sits between the
  // list construction and the cat, so the list must not move past it.
  EXPECT_FALSE(aliasDb.moveBeforeTopologicallyValid(
      named["y"]->node(), named["9"]->node()));
}
1469
TEST(
    AliasRegistrationTest,
    AliasMoveForTupleConstructWithSingleUseAsGraphOutput) {
  // A tuple whose only use is the graph return: its inputs become contained
  // elements of the tuple instead of wildcards, so they stay distinct.
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  // The tuple is built from two tensors, so its annotated type must be a
  // 2-tuple; it was previously mis-annotated as the 1-tuple `(Tensor)`.
  auto graph_string = R"IR(
  graph():
    %x : Tensor = prim::MakeTestTensor()
    %y : Tensor = prim::MakeTestTensor()
    %z : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
    return (%z))IR";

  torch::jit::parseIR(graph_string, graph.get(), vmap);
  AliasDb aliasDb(graph, /*isFrozen=*/false);

  EXPECT_FALSE(aliasDb.mayAlias(vmap["x"], vmap["y"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["z"], vmap["x"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["z"], vmap["y"]));
}
1489
TEST(AliasRegistrationTest, RecursiveSubgraphTupleContainment) {
  // After the TupleConstruct is wrapped in a prim::FunctionalGraph subgraph,
  // the subgraph's output must still be seen as containing both tensors.
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> named;
  const auto ir = R"IR(
  graph():
    %x : Tensor = prim::MakeTestTensor()
    %y : Tensor = prim::MakeTestTensor()
    %z : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
    return (%z))IR";
  torch::jit::parseIR(ir, graph.get(), named);

  Node* wrapped = SubgraphUtils::createSingletonSubgraph(
      named["z"]->node(), prim::FunctionalGraph);
  AliasDb aliasDb(graph);

  Value* x = named["x"];
  Value* y = named["y"];
  EXPECT_TRUE(aliasDb.mayContainAlias(wrapped->output(), x));
  EXPECT_TRUE(aliasDb.mayContainAlias(wrapped->output(), y));
  EXPECT_TRUE(aliasDb.mayAlias(x, y));
}
1510
TEST(AliasRegistrationTest, WildcardAliasForTupleConstructWithUses) {
  // Tuples that are indexed (i.e. have uses other than the graph return)
  // cause their tensor inputs to be treated as wildcards, so every tensor
  // here may alias every other and each tuple may contain any of them.
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  // %a is built from two tensors, so it is annotated as a 2-tuple; it was
  // previously mis-annotated as the 1-tuple `(Tensor)`.
  auto graph_string = R"IR(
  graph():
    %x : Tensor = prim::MakeTestTensor()
    %y : Tensor = prim::MakeTestTensor()
    %z : Tensor = prim::MakeTestTensor()
    %0 : int = prim::Constant[value=0]()
    %a : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
    %b : (Tensor) = prim::TupleConstruct(%z)
    %c : Tensor = prim::TupleIndex(%a, %0)
    %d : Tensor = prim::TupleIndex(%b, %0)
    return (%c, %d))IR";

  torch::jit::parseIR(graph_string, graph.get(), vmap);
  AliasDb aliasDb(graph, /*isFrozen=*/false);

  EXPECT_TRUE(aliasDb.mayAlias(vmap["x"], vmap["y"]));
  EXPECT_TRUE(aliasDb.mayAlias(vmap["x"], vmap["z"]));
  EXPECT_TRUE(aliasDb.mayAlias(vmap["y"], vmap["z"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["a"], vmap["x"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["a"], vmap["y"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["a"], vmap["z"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["b"], vmap["x"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["b"], vmap["y"]));
  EXPECT_TRUE(aliasDb.mayContainAlias(vmap["b"], vmap["z"]));
}
1539
TEST(AliasRegistrationTest, ATenSplitIntListAliasCheck) {
  // aten::split with a list of lengths returns views of its input. Every
  // unpacked piece, and the flatten of each piece, may alias the split input.
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> named;
  const auto ir = R"IR(
  graph():
    %x : Tensor = prim::MakeTestTensor()
    %0 : int = prim::Constant[value=0]()
    %1 : int = prim::Constant[value=1]()
    %2 : int = prim::Constant[value=2]()
    %y : Tensor = aten::add(%x, %x, %0)
    %lengths_list : int[] = prim::tolist(%1, %2)
    %a : Tensor[] = aten::split(%y, %lengths_list, %0)
    %b : Tensor, %c : Tensor = prim::ListUnpack(%a)
    %b1 : Tensor = aten::flatten(%b, %0, %1)
    %c1 : Tensor = aten::flatten(%c, %0, %1)
    %d : Tensor = aten::add(%b1, %c1, %0)
    return (%d))IR";
  torch::jit::parseIR(ir, graph.get(), named);
  AliasDb aliasDb(graph, /*isFrozen=*/false);

  Value* split_input = named["y"];
  for (const char* piece : {"b", "c", "b1", "c1"}) {
    EXPECT_TRUE(aliasDb.mayAlias(split_input, named[piece])) << piece;
  }
}
1566
TEST(AliasRegistrationTest, ATenSplitIntAliasCheck) {
  // aten::split with a scalar split size returns views of its input. Every
  // unpacked piece, and the flatten of each piece, may alias the split input.
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> named;
  const auto ir = R"IR(
  graph():
    %x : Tensor = prim::MakeTestTensor()
    %0 : int = prim::Constant[value=0]()
    %1 : int = prim::Constant[value=1]()
    %2 : int = prim::Constant[value=2]()
    %y : Tensor = aten::add(%x, %x, %0)
    %a : Tensor[] = aten::split(%y, %2, %0)
    %b : Tensor, %c : Tensor = prim::ListUnpack(%a)
    %b1 : Tensor = aten::flatten(%b, %0, %1)
    %c1 : Tensor = aten::flatten(%c, %0, %1)
    %d : Tensor = aten::add(%b1, %c1, %0)
    return (%d))IR";
  torch::jit::parseIR(ir, graph.get(), named);
  AliasDb aliasDb(graph, /*isFrozen=*/false);

  Value* split_input = named["y"];
  for (const char* piece : {"b", "c", "b1", "c1"}) {
    EXPECT_TRUE(aliasDb.mayAlias(split_input, named[piece])) << piece;
  }
}
1592
TEST(AliasRegistrationTest, PureWithAnnotationsShouldError2) {
  // PURE_FUNCTION with non-aliasing annotations (`(a) -> (b)`) is still
  // rejected: any alias annotation requires FROM_SCHEMA. Registration
  // succeeds; AliasDb construction must throw.
  auto registry = torch::RegisterOperators().op(
      "foo::rand12(Tensor(a) arg1) -> Tensor(b)",
      torch::RegisterOperators::options()
          .catchAllKernel([](at::Tensor t) -> at::Tensor { return t * 2; })
          .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION));

  auto graph = std::make_shared<Graph>();
  Value* in = graph->addInput();
  graph->insert(Symbol::fromQualString("foo::rand12"), {in});

  expectThrows<c10::Error>(
      [&] { AliasDb aliasDb(graph); },
      "Tried to register operator foo::rand12(Tensor(a) arg1) -> Tensor(b) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA");
}
1610
TEST(IRNonDeterminismTest, Basic) {
  // Of the nodes below, only aten::bernoulli is expected to report itself as
  // nondeterministic.
  auto graph = std::make_shared<Graph>();
  const auto ir = R"IR(
  graph():
    %x : Tensor = prim::MakeTestTensor()
    %0 : int = prim::Constant[value=0]()
    %1 : NoneType = prim::Constant()
    %2 : Tensor = aten::bernoulli(%x, %1)
    %3 : Tensor = aten::add(%x, %2, %0)
    return (%3))IR";
  parseIR(ir, graph.get());

  for (Node* node : graph->nodes()) {
    const bool expected = node->kind() == aten::bernoulli;
    ASSERT_EQ(node->isNondeterministic(), expected);
  }
}
1631
TEST(IRNonDeterminismTest, DropoutSpecialCase) {
  // aten::dropout is only expected to be nondeterministic when its `train`
  // argument is True; with a constant False it must report deterministic.
  auto graph = std::make_shared<Graph>();
  // Every value is defined exactly once. %3 is the float used both as the
  // dropout probability and as add's alpha; %0 / %1 are train=False / True.
  auto graph_string = R"IR(
  graph():
    %x : Tensor = prim::MakeTestTensor()
    %0 : bool = prim::Constant[value=0]()
    %1 : bool = prim::Constant[value=1]()
    %3 : float = prim::Constant[value=1.0]()
    %4 : Tensor = aten::dropout(%x, %3, %0)
    %5 : Tensor = aten::dropout(%x, %3, %1)
    %6 : Tensor = aten::add(%4, %5, %3)
    return (%6))IR";
  parseIR(graph_string, graph.get());

  // Nodes are visited in IR order, so the first dropout seen is the
  // train=False one (%4) and the second is the train=True one (%5).
  int dropouts_seen = 0;
  for (Node* n : graph->nodes()) {
    if (n->kind() != aten::dropout) {
      ASSERT_FALSE(n->isNondeterministic());
      continue;
    }
    if (dropouts_seen == 0) {
      ASSERT_FALSE(n->isNondeterministic());
    } else {
      ASSERT_TRUE(n->isNondeterministic());
    }
    ++dropouts_seen;
  }
  // Guard against a vacuous pass if the graph does not contain both dropouts.
  ASSERT_EQ(dropouts_seen, 2);
}
1661
TEST(NonDeterminismBackwardsCompatibility, BackwardsCompatibility) {
  // Each of these operators must be registered with the dispatcher and must
  // carry the `nondeterministic_seeded` tag.
  static const std::vector<std::string> nondeterministic_ops = {
      "aten::dropout(Tensor input, float p, bool train) -> Tensor",
      "aten::_fused_dropout(Tensor self, float p, Generator? generator) -> (Tensor, Tensor)",
      "aten::_standard_gamma(Tensor self, Generator? generator) -> Tensor",
      "aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
      "aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
      "aten::multinomial(Tensor self, int num_samples, bool replacement, *, Generator? generator) -> Tensor",
      "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)",
      "aten::normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator) -> Tensor",
      "aten::normal.float_Tensor(float mean, Tensor std, *, Generator? generator) -> Tensor",
      "aten::normal.Tensor_float(Tensor mean, float std, *, Generator? generator) -> Tensor",
      "aten::poisson(Tensor self, Generator? generator) -> Tensor",
      "aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor",
      "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
      "aten::rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
      "aten::rand(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::rand_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randint(int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randint(int low, int high, int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randint_like(Tensor self, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randint_like(Tensor self, int low, int high, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randn(int[] size, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor",
      "aten::randn_like(Tensor self, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor",
      "aten::randperm(int n, *, int? dtype, int? layout, Device? device, bool? pin_memory) -> Tensor"};
  for (const std::string& op : nondeterministic_ops) {
    const c10::FunctionSchema schema = torch::jit::parseSchema(op);
    const auto op_handle = c10::Dispatcher::singleton().findOp(
        c10::OperatorName(schema.name(), schema.overload_name()));
    // findOp yields an empty optional for unregistered operators; check it
    // before dereferencing instead of invoking undefined behavior.
    ASSERT_TRUE(op_handle.has_value()) << "operator not registered: " << op;
    // Keep going after a missing tag so every offending schema is reported.
    EXPECT_TRUE(op_handle->hasTag(at::Tag::nondeterministic_seeded)) << op;
  }
}
1694
TEST(TypeHashing, HashTypes) {
  // HashType must distinguish different primitive types, and tuples that
  // differ only in arity.
  HashType hash_of;

  const TypePtr i64 = IntType::get();
  const TypePtr f64 = FloatType::get();
  ASSERT_NE(hash_of(i64), hash_of(f64));

  const TypePtr pair_of_ints = TupleType::create({i64, i64});
  const TypePtr triple_of_ints = TupleType::create({i64, i64, i64});
  ASSERT_NE(hash_of(pair_of_ints), hash_of(triple_of_ints));
}
1706
1707 } // namespace jit
1708 } // namespace torch
1709