xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_add_if_then_else.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <test/cpp/jit/test_utils.h>
4 #include <torch/csrc/jit/ir/irparser.h>
5 #include <torch/csrc/jit/passes/add_if_then_else.h>
6 
7 namespace torch {
8 namespace jit {
9 
TEST(AddIfThenElseOpTest,AddIfThenElseOpSimple)10 TEST(AddIfThenElseOpTest, AddIfThenElseOpSimple) {
11   const auto src = R"IR(
12         graph(%cond: bool, %a: Tensor, %b: Tensor):
13             %result: Tensor = prim::If(%cond)
14                 block0():
15                     -> (%a)
16                 block1():
17                     -> (%b)
18             return (%result)
19     )IR";
20 
21   auto graph = std::make_shared<Graph>();
22   parseIR(src, graph.get());
23   EXPECT_TRUE(AddIfThenElseOp(graph));
24 
25   testing::FileCheck()
26       .check_count("= prim::IfThenElse", 1, /*exactly*/ true)
27       ->check_count("= prim::If", 0, /*exactly*/ true)
28       ->run(*graph);
29 }
30 
TEST(AddIfThenElseOpTest,NoIfThenElseOpMultipleOutputs)31 TEST(AddIfThenElseOpTest, NoIfThenElseOpMultipleOutputs) {
32   const auto src = R"IR(
33         graph(%cond: bool, %a: Tensor, %b: Tensor):
34             %result1: Tensor, %result2: Tensor = prim::If(%cond)
35                 block0():
36                     -> (%a, %b)
37                 block1():
38                     -> (%b, %a)
39             return (%result1, %result2)
40     )IR";
41 
42   auto graph = std::make_shared<Graph>();
43   parseIR(src, graph.get());
44   EXPECT_FALSE(AddIfThenElseOp(graph));
45 
46   testing::FileCheck()
47       .check_count("= prim::IfThenElse", 0, /*exactly*/ true)
48       ->check_count("= prim::If", 1, /*exactly*/ true)
49       ->run(*graph);
50 }
51 
52 } // namespace jit
53 } // namespace torch
54