#include <gtest/gtest.h>

#include "test/cpp/jit/test_utils.h"
#include "torch/csrc/jit/frontend/tracer.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/graph_fuser.h"
#include "torch/csrc/jit/passes/lower_grad_of.h"
#include "torch/csrc/jit/passes/requires_grad_analysis.h"
#include "torch/csrc/jit/passes/shape_analysis.h"
#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
#include "torch/csrc/jit/runtime/argument_spec.h"
#include "torch/csrc/jit/runtime/autodiff.h"
#include "torch/csrc/jit/runtime/graph_iterator.h"
#include "torch/csrc/jit/runtime/profiling_graph_executor_impl.h"
#include "torch/torch.h"

#include <ATen/ATen.h>
#include "torch/csrc/autograd/engine.h"
#include "torch/csrc/autograd/generated/variable_factories.h"
#include "torch/csrc/autograd/variable.h"

namespace torch {
namespace jit {

using namespace torch::autograd;

using var_meta_type = std::vector<int64_t>;
using var_meta_list = std::vector<var_meta_type>;
using test_fn_type = std::function<variable_list(const variable_list&)>;

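// Describes a single autodiff test case: a name, the shapes of the inputs to
// generate, the op under test, and an optional clamp applied to the random
// inputs (used below to keep tanh out of its saturated regions).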
struct ADTestSpec {
  ADTestSpec(
      const char* name,
      // NOLINTNEXTLINE(modernize-pass-by-value)
      var_meta_list input_meta,
      // NOLINTNEXTLINE(modernize-pass-by-value)
      test_fn_type test_fn,
      float clampMax = -1.0f)
      : name(name),
        input_meta(input_meta),
        test_fn(test_fn),
        clampMax(clampMax) {}

  variable_list operator()(const variable_list& inputs) const {
    return test_fn(inputs);
  }

  std::vector<Variable> make_vars() const {
    std::vector<Variable> out;
    for (const auto& m : input_meta) {
      if (clampMax > 0.0f) {
        out.push_back(torch::randn(m, at::requires_grad(true))
                          .clamp(-clampMax, clampMax));
        continue;
      }
      out.push_back(torch::randn(m, at::requires_grad(true)));
    }
    return out;
  }

  const char* name;
  var_meta_list input_meta;
  test_fn_type test_fn;
  float clampMax;
};

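// Produces random gradient tensors matching the shapes and options of the
// given variables, to be fed into the backward pass.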
variable_list get_grad_outputs(const variable_list& vars) {
  return fmap(vars, [](const Variable& v) -> Variable {
    return at::randn(v.sizes(), v.options());
  });
}

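// Reference backward pass: runs the eager autograd engine on the given
// outputs/inputs with the provided output gradients.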
variable_list grad(
    const variable_list& outputs,
    const variable_list& inputs,
    const variable_list& grad_outputs) {
  const auto get_edge = [](const Variable& v) {
    return torch::autograd::impl::gradient_edge(v);
  };
  auto& engine = torch::autograd::Engine::get_default_engine();
  return engine.execute(
      fmap(outputs, get_edge),
      grad_outputs,
      /*keep_graph=*/true,
      /*create_graph=*/false,
      /*accumulate_grad=*/false,
      fmap(inputs, get_edge));
}

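// For each test case: compute reference outputs and gradients with eager
// autograd, then trace and differentiate the op with the JIT and check that
// the interpreter produces matching results.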
TEST(AutodiffTest, ADFormulas) {
  const auto cast = [](const Variable& v) {
    return static_cast<at::Tensor>(v);
  };

  using VL = variable_list;
  const var_meta_list binary_pointwise = {{2, 3, 4, 5}, {2, 3, 4, 5}};
  const var_meta_list unary_pointwise = {{2, 3, 4, 5}};
  const var_meta_list unary_pointwise_2d = {{2, 3}};
  const std::vector<ADTestSpec> ad_tests = {
      {"add",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] + v[1]}; }},
      {"sub",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] - v[1]}; }},
      {"mul",
       binary_pointwise,
       [](const VL& v) -> VL { return {v[0] * v[1]}; }},
      {"sigmoid",
       unary_pointwise,
       [](const VL& v) -> VL { return {v[0].sigmoid()}; }},
      // Clamp tanh input tensor values to [-3, 3]
      // to set a minimum on gradient absolute values
      {"tanh",
       unary_pointwise,
       [](const VL& v) -> VL { return {v[0].tanh()}; },
       3.0f},
      {"t", unary_pointwise_2d, [](const VL& v) -> VL { return {v[0].t()}; }},
      {"view",
       unary_pointwise_2d,
       [](const VL& v) -> VL {
         return {v[0].view({3, 2})};
       }},
      {"expand",
       {{2, 1}},
       [](const VL& v) -> VL {
         return {v[0].expand({2, 3})};
       }},
      {"mm",
       {{10, 12}, {12, 15}},
       [](const VL& v) -> VL { return {v[0].mm(v[1])}; }},
      // TODO: enable once we are able to capture lists across
      // forward-backward
      //{"chunk",   {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap<Variable>(v[0].chunk(4, 1)); }},
      //{"chunk",   {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap<Variable>(v[0].chunk(3, 2)); }},
      //{"split",   {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap<Variable>(v[0].split(4, 1)); }},
      //{"split",   {{10, 12, 15}}, [](const VL& v) -> VL { return
      // fmap<Variable>(v[0].split(3, 2)); }},
  };

  for (const auto& test : ad_tests) {
    // Get reference values from autograd
    auto vars_in = test.make_vars();
    auto vars_out = test(vars_in);
    auto var_grads_in = get_grad_outputs(vars_out);
    auto var_grads_out = grad(vars_out, vars_in, var_grads_in);

    // Trace and differentiate the op
    auto graph = tracer::trace(
                     fmap<IValue>(vars_in),
                     [&test](Stack in) -> Stack {
                       auto ivalue_inps = fmap(in, [](const IValue& v) {
                         return Variable(v.toTensor());
                       });
                       return fmap<IValue>(test(ivalue_inps));
                     },
                     [](const Variable& var) { return ""; })
                     .first->graph;
    EliminateDeadCode(graph); // Tracing of some ops depends on the DCE trick
    ConstantPropagation(graph);
    auto grad_spec = differentiate(graph);
    LowerGradOf(*grad_spec.df);
    // Get outputs from the interpreter
    auto tensors_in = fmap(vars_in, cast);
    auto tensor_grads_in = fmap(var_grads_in, cast);
    auto [tensors_out, tensor_grads_out] =
        runGradient(grad_spec, tensors_in, tensor_grads_in);

    // Compare results
    auto expected_tensors_out = fmap(vars_out, cast);
    auto expected_tensor_grads_out = fmap(var_grads_out, cast);
    assertAllClose(tensors_out, expected_tensors_out);
    assertAllClose(tensor_grads_out, expected_tensor_grads_out);
  }
}

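// Builds the graph for a * b * a + b by hand, differentiates it, and checks
// the captured values, vjp indices, and the structure of the forward (f) and
// backward (df) graphs.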
TEST(AutodiffTest, Differentiate) {
  // Note: can't use IRParser for this test due to issue #23989
  auto graph = std::make_shared<Graph>();
  std::vector<int64_t> sizes{2, 3, 4};
  std::vector<int64_t> strides{12, 4, 1};
  const auto type = TensorType::create(
      at::ScalarType::Float,
      at::kCPU,
      c10::VaryingShape<int64_t>{sizes},
      c10::VaryingShape<int64_t>{strides},
      true);

  // Builds graph a * b * a + b
  auto* a = graph->addInput()->setType(type);
  auto* b = graph->addInput()->setType(type);
  auto* cOne = graph->insertConstant(1);

  auto* ab = graph->insertNode(graph->create(aten::mul, /*num_outputs =*/1));
  ab->addInput(a);
  ab->addInput(b);

  auto* aba = graph->insertNode(graph->create(aten::mul, /*num_outputs =*/1));
  aba->addInput(ab->output());
  aba->addInput(a);

  auto* abaplusb =
      graph->insertNode(graph->create(aten::add, /*num_outputs =*/1));
  abaplusb->addInput(aba->output());
  abaplusb->addInput(b);
  abaplusb->addInput(cOne);

  graph->registerOutput(abaplusb->output());

  auto grad_spec = differentiate(graph);
  std::vector<size_t> expected_captured_inputs = {0, 1};
  std::vector<size_t> expected_captured_outputs = {1, 2, 3, 4, 5, 6, 7};
  std::vector<size_t> expected_input_vjps = {0, 1};
  std::vector<size_t> expected_output_vjps = {0, 1};
  ASSERT_EQ(grad_spec.f_real_outputs, 1);
  ASSERT_EQ(grad_spec.df_input_captured_inputs, expected_captured_inputs);
  ASSERT_EQ(grad_spec.df_input_captured_outputs, expected_captured_outputs);
  ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
  ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
  testing::FileCheck()
      .check_count("aten::mul", 2)
      ->check("aten::size")
      ->check("aten::add")
      ->run(*grad_spec.f);
  testing::FileCheck()
      .check("prim::GradOf[name=\"aten::add\"]")
      ->check_count("prim::GradOf[name=\"aten::mul\"]", 2)
      ->check_count("AutogradAdd", 2)
      ->run(*grad_spec.df);
}

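// Differentiates a parsed graph after propagating shapes and requires_grad:
// only the first input (%0) requires grad, so only that gradient path should
// appear in the backward graph.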
TEST(AutodiffTest, DifferentiateWithRequiresGrad) {
  const auto graph_string = R"IR(
    graph(%0 : Tensor,
          %1 : Tensor):
      %2 : int = prim::Constant[value=1]()
      %3 : Tensor = aten::mul(%1, %1)
      %4 : Tensor = aten::add(%3, %1, %2)
      %5 : Tensor = aten::add(%4, %0, %2)
      %6 : Tensor = aten::mul(%5, %0)
      %7 : Tensor = aten::add(%6, %1, %2)
      return (%4, %7))IR";
  auto g = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, g.get());

  auto a_var = autograd::make_variable(
      at::empty_strided(2, 2, at::CPU(at::kFloat).options()), true);
  auto b_var = autograd::make_variable(
      at::empty_strided(2, 2, at::CPU(at::kFloat).options()), false);

  ArgumentSpecCreator asc(*g);
  asc.specializeTypes(*g, asc.create(true, {a_var, b_var}));

  PropagateInputShapes(g);
  PropagateRequiresGrad(g);

  auto grad_spec = differentiate(g);
  std::vector<size_t> expected_input_vjps = {1, 2}; // for e and %4 = (d + a)
  std::vector<size_t> expected_output_vjps = {0}; // only a requires grad
  ASSERT_EQ(grad_spec.f_real_outputs, 2);
  ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0}));
  ASSERT_EQ(
      grad_spec.df_input_captured_outputs,
      std::vector<size_t>({2, 3, 4, 5, 6}));
  ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
  ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
  testing::FileCheck()
      .check("aten::mul")
      ->check_count("aten::add", 2)
      ->check("aten::mul")
      ->check("aten::size")
      ->check("aten::add")
      ->run(*grad_spec.f);

  testing::FileCheck()
      .check_count("prim::GradOf[name=\"aten::mul\"]", 1, /*exactly*/ true)
      ->run(*grad_spec.df);
}

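// Fixture that enables the profiling graph executor and disables autodiff
// subgraph inlining so that prim::DifferentiableGraph nodes survive in the
// optimized plan.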
class AutodiffRemoveUnusedGradientsTest : public ::testing::Test {
 protected:
  void SetUp() override {
    prev_exec = getExecutorMode();
    getExecutorMode() = true;
    prev_inline_autodiff = getAutodiffSubgraphInlining();
    debugSetAutodiffSubgraphInlining(false);
  }
  void TearDown() override {
    getExecutorMode() = prev_exec;
    debugSetAutodiffSubgraphInlining(prev_inline_autodiff);
  }

  bool prev_exec;
  bool prev_profiling;
  bool prev_inline_autodiff;
};

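// Runs aten::linear through the profiling executor (input does not require
// grad, weight and bias do) and inspects the backward graph of the resulting
// DifferentiableGraph node.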
TEST_F(AutodiffRemoveUnusedGradientsTest, Linear) {
  auto graph = std::make_shared<Graph>();
  const std::string input =
      R"IR(
graph(%inp.1 : Tensor,
      %weight.1 : Tensor,
      %bias.1 : Tensor):
  %6 : Tensor = aten::linear(%inp.1, %weight.1, %bias.1)
  return (%6))IR";
  parseIR(input, graph.get());

  auto inp = torch::randn({10, 10}).requires_grad_(false);
  auto weight = torch::randn({10, 10}).requires_grad_(true);
  auto bias = torch::randn({1, 10}).requires_grad_(true);
  auto stack = createStack({inp, weight, bias});

  ProfilingGraphExecutorImpl executor(graph, "linear");

  // initial run to profile requires_grad information
  auto plan = executor.getPlanFor(stack, 20);
  InterpreterState is{plan.code};
  is.run(stack);

  auto optimized_plan = executor.getPlanFor(stack, 20);
  DepthFirstGraphNodeIterator it(optimized_plan.graph);
  Node* diff_graph_node = nullptr;

  while ((diff_graph_node = it.next()) != nullptr) {
    if (diff_graph_node->kind() == prim::DifferentiableGraph) {
      break;
    }
  }
  ASSERT_NE(nullptr, diff_graph_node);

  auto backward_graph = diff_graph_node->g(attr::ReverseSubgraph);

  // we expect to compute grad_weight (which requires a matmul) but we don't
  // expect to compute grad_input. So, we expect exactly 1 matmul.
  // Note: this could change, e.g. if mm is used instead
  testing::FileCheck().check_count("matmul", 1, true)->run(*backward_graph);
}

} // namespace jit
} // namespace torch