xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_fuser.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/ATen.h>
4 #include <ATen/core/interned_strings.h>
5 #include <ATen/core/ivalue.h>
6 #include <c10/util/irange.h>
7 
8 #include <torch/csrc/autograd/engine.h>
9 #include <torch/csrc/autograd/generated/variable_factories.h>
10 #include <torch/csrc/autograd/variable.h>
11 #include <torch/csrc/jit/api/module.h>
12 #include <torch/csrc/jit/codegen/cuda/interface.h>
13 #include <torch/csrc/jit/codegen/fuser/interface.h>
14 #include <torch/csrc/jit/frontend/ir_emitter.h>
15 #include <torch/csrc/jit/frontend/tracer.h>
16 #include <torch/csrc/jit/ir/alias_analysis.h>
17 #include <torch/csrc/jit/ir/attributes.h>
18 #include <torch/csrc/jit/ir/irparser.h>
19 #include <torch/csrc/jit/passes/canonicalize.h>
20 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
21 #include <torch/csrc/jit/passes/constant_propagation.h>
22 #include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
23 #include <torch/csrc/jit/passes/dead_code_elimination.h>
24 #include <torch/csrc/jit/passes/graph_fuser.h>
25 #include <torch/csrc/jit/passes/lower_grad_of.h>
26 #include <torch/csrc/jit/passes/lower_tuples.h>
27 #include <torch/csrc/jit/passes/requires_grad_analysis.h>
28 #include <torch/csrc/jit/passes/shape_analysis.h>
29 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
30 #include <torch/csrc/jit/runtime/argument_spec.h>
31 #include <torch/csrc/jit/runtime/autodiff.h>
32 #include <torch/csrc/jit/runtime/custom_operator.h>
33 #include <torch/csrc/jit/runtime/graph_executor.h>
34 #include <torch/csrc/jit/runtime/interpreter.h>
35 #include <torch/csrc/jit/runtime/symbolic_script.h>
36 #include <torch/csrc/jit/serialization/import.h>
37 #include <torch/csrc/jit/testing/file_check.h>
38 
39 #include <onnx/onnx_pb.h>
40 
41 #include <c10/util/Exception.h>
42 
#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <iostream>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>
54 
55 namespace torch {
56 namespace jit {
57 
58 class FuserTest : public ::testing::Test {
SetUp()59   void SetUp() override {
60     old_nvfuser_value_ = fuser::cuda::setEnabled(false);
61   }
TearDown()62   void TearDown() override {
63     fuser::cuda::setEnabled(old_nvfuser_value_);
64   }
65 
66  private:
67   bool old_nvfuser_value_;
68 };
69 
TEST_F(FuserTest, TestSimple_CUDA) {
#if defined(FBCODE_CAFFE2)
  return;
#endif
  // Single elementwise multiply; verifies the fused kernel matches eager mode.
  const auto graph_string = R"IR(
      graph(%0 : Tensor,
            %1 : Tensor):
        %2 : Tensor = aten::mul(%0, %1)
        return (%2))IR";
  Graph graph;
  torch::jit::parseIR(graph_string, &graph);

  auto a = at::rand({3, 4}, at::kCUDA);
  // Transposed input exercises the fuser on non-contiguous strides.
  auto b = at::rand({4, 3}, at::kCUDA).transpose(0, 1);
  auto outputs = debugLaunchGraph(graph, {a, b});
  ASSERT_EQ(outputs.size(), 1);
  auto o2 = a * b;
  // A single elementwise multiply should be bit-exact vs. the eager result.
  // Use double throughout: item<double>() would otherwise narrow to float.
  double max_diff = (o2 - outputs[0]).abs().max().item<double>();
  ASSERT_EQ(max_diff, 0);
}
92 
TEST_F(FuserTest, TestOne_CUDA) {
#if defined(FBCODE_CAFFE2)
  return;
#endif
  // Runs an LSTM-cell-like elementwise graph whose five inputs have their
  // strides permuted via transpose(ti, tj), then compares the fused result
  // against the equivalent eager-mode computation.
  auto testOne = [](int ti, int tj) {
    const auto graph_string = R"IR(
      graph(%0 : Tensor,
            %1 : Tensor,
            %2 : Tensor,
            %3 : Tensor,
            %4 : Tensor):
        %5 : Tensor = aten::sigmoid(%4)
        %6 : Tensor = aten::sigmoid(%3)
        %7 : Tensor = aten::tanh(%2)
        %8 : Tensor = aten::sigmoid(%1)
        %9 : Tensor = aten::mul(%6, %0)
        %10 : Tensor = aten::mul(%5, %7)
        %11 : int = prim::Constant[value=1]()
        %12 : Tensor = aten::add(%9, %10, %11)
        %13 : Tensor = aten::tanh(%12)
        %14 : Tensor = aten::mul(%8, %13)
        return (%14, %12))IR";
    Graph graph;
    torch::jit::parseIR(graph_string, &graph);

    graph.lint();

    std::vector<at::Tensor> inputs;
    // We want to generate input/output tensors with dimension 128x128x32, but
    // with different internal strides.  To do this, we generate a tensor
    // with the "wrong" dimensions, and then use transpose to get an
    // appropriately sized view.
    std::generate_n(
        std::back_inserter(inputs), graph.inputs().size(), [ti, tj] {
          std::array<int64_t, 3> dims = {128, 128, 32};
          std::swap(dims[ti], dims[tj]);
          return at::rand(dims, at::kCUDA).transpose(ti, tj);
        });

    // Eager-mode reference mirroring the IR graph above.
    auto t22 = inputs[4].sigmoid();
    auto t20 = inputs[3].sigmoid();
    auto t18 = inputs[2].tanh();
    auto t16 = inputs[1].sigmoid();
    auto t14 = t20 * inputs[0];
    auto t11 = t22 * t18;
    auto out1 = t14 + t11;
    auto t5 = out1.tanh();
    auto out0 = t16 * t5;

    auto outputs = debugLaunchGraph(graph, inputs);
    ASSERT_EQ(outputs.size(), graph.outputs().size());
    ASSERT_TRUE(out0.is_same_size(outputs.front()));
    // tanh/sigmoid may differ slightly between fused and eager kernels, so
    // compare with a small tolerance. Keep the diff as double to avoid the
    // double->float narrowing the original had; ASSERT_LT reports both values
    // on failure, unlike ASSERT_TRUE(a < b).
    double max_diff = (outputs.front() - out0).abs().max().item<double>();
    ASSERT_LT(max_diff, 1e-6);
  };
  testOne(0, 0);
  testOne(0, 1);
  testOne(1, 2);
  testOne(0, 2);
}
153 
TEST_F(FuserTest, FusedConcat_CUDA) {
#if defined(FBCODE_CAFFE2)
  return;
#endif
  // Three variants of the same graph, concatenating along dim 0, 1, and 2.
  const auto graph_string0 = R"IR(
    graph(%0 : Tensor,
          %1 : Tensor):
      %2 : Tensor = aten::mul(%0, %1)
      %3 : Tensor = prim::FusedConcat[dim=0](%0, %2)
      return (%2, %3))IR";
  const auto graph_string1 = R"IR(
    graph(%0 : Tensor,
          %1 : Tensor):
      %2 : Tensor = aten::mul(%0, %1)
      %3 : Tensor = prim::FusedConcat[dim=1](%0, %2)
      return (%2, %3))IR";
  const auto graph_string2 = R"IR(
    graph(%0 : Tensor,
          %1 : Tensor):
      %2 : Tensor = aten::mul(%0, %1)
      %3 : Tensor = prim::FusedConcat[dim=2](%0, %2)
      return (%2, %3))IR";

  auto a = at::rand({3, 4, 5}, at::kCUDA);
  // Transposed input exercises the fuser on non-contiguous strides.
  auto b = at::rand({4, 3, 5}, at::kCUDA).transpose(0, 1);
  const auto o_r = a * b;

  std::vector<std::string> graph_strings{
      graph_string0, graph_string1, graph_string2};
  for (const auto i : c10::irange(graph_strings.size())) {
    Graph g;
    torch::jit::parseIR(graph_strings[i], &g);

    auto outputs = debugLaunchGraph(g, {a, b});
    ASSERT_EQ(outputs.size(), 2);

    // First output: the elementwise product should be bit-exact. Keep diffs
    // as double — item<double>() would otherwise narrow to float.
    double max_diff = (o_r - outputs[0]).abs().max().item<double>();
    ASSERT_EQ(max_diff, 0);

    // Second output: concatenation of %0 with the product along dim i.
    const auto o2_r = at::cat({a, o_r}, i);
    double max_diff2 = (o2_r - outputs[1]).abs().max().item<double>();
    ASSERT_EQ(max_diff2, 0);
  } // (stray ';' after this brace removed)
}
198 
TEST_F(FuserTest, FusionAliasing) {
#if defined(FBCODE_CAFFE2)
  return;
#endif
  // Graph with an in-place aten::add_ sitting between two runs of pointwise
  // ops; fusion must not move ops across the mutation.
  const auto graph_string = R"IR(
    graph(%0 : Tensor,
          %1 : Tensor):
      %12 : int = prim::Constant[value=1]()
      %2.1 : Tensor = aten::mul(%0, %1)
      %2 : Tensor = aten::mul(%2.1, %1)
      %3 : Tensor = aten::add_(%2, %1, %12)
      %4 : Tensor = aten::mul(%2, %1)
      %5 : Tensor = aten::add(%2, %4, %12)
      return (%5))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());
  graph->lint();

  FuseGraph(graph);

  // We should not be able to fuse across the in-place operation here: the
  // add_ acts as a barrier, yielding two separate fusion groups around it.
  testing::FileCheck()
      .check("prim::FusionGroup_0")
      ->check("aten::add_")
      ->check("prim::FusionGroup_1")
      ->run(*graph);
}
226 
TEST_F(FuserTest, KernelCaching) {
#if defined(FBCODE_CAFFE2)
  return;
#endif

  // Constructs two functionally equivalent graphs
  const auto graph0_string = R"IR(
    graph(%0 : Float(2, 3, 4),
          %1 : Float(2, 3, 4)):
      %c0 : Float(2, 3, 4) = aten::mul(%0, %1)
      %d0 : Float(2, 3, 4) = aten::mul(%c0, %0)
      return (%d0))IR";
  auto g0 = std::make_shared<Graph>();
  torch::jit::parseIR(graph0_string, g0.get());

  const auto graph1_string = R"IR(
    graph(%0 : Float(2, 3, 4),
          %1 : Float(2, 3, 4)):
      %c1 : Float(2, 3, 4) = aten::mul(%0, %1)
      %d1 : Float(2, 3, 4) = aten::mul(%c1, %0)
      return (%d1))IR";
  auto g1 = std::make_shared<Graph>();
  torch::jit::parseIR(graph1_string, g1.get());

  // Scans a fused graph for the FusionGroup node FuseGraph should produce;
  // fails the test if none exists.
  auto findFusionGroup = [](const std::shared_ptr<Graph>& graph) -> Node* {
    for (Node* node : graph->nodes()) {
      if (node->kind() == prim::FusionGroup) {
        return node;
      }
    }
    TORCH_CHECK(
        false, "testRegisterFusionCachesKernel: could not create FusionGroup");
    return nullptr;
  };

  // Creates two alpha-equivalent fusion groups
  torch::jit::overrideCanFuseOnCPU(true);
  FuseGraph(g0);
  FuseGraph(g1);
  torch::jit::overrideCanFuseOnCPU(false);
  Node* fusion0 = findFusionGroup(g0);
  Node* fusion1 = findFusionGroup(g1);

  // Registers both with the fusion compiler.
  const auto key0 = registerFusion(fusion0);
  const auto key1 = registerFusion(fusion1);

  // Because the graphs are alpha-equivalent, they should return the same key
  // and therefore share a KernelSpec to share kernels for specializations
  ASSERT_EQ(key1, key0);
}
279 } // namespace jit
280 } // namespace torch
281