xref: /aosp_15_r20/external/pytorch/torch/csrc/dynamo/python_compiled_autograd.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/dynamo/python_compiled_autograd.h>

#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/dynamo/compiled_autograd.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pythoncapi_compat.h>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

/*
[Note: Compiled Autograd]

Compiled autograd replaces the standard autograd engine by converting
the autograd graph to an FX graph that can be torch.compiled. It caches
this conversion using a shadow graph. We compare the new graph to the
shadow graph by walking the two graphs simultaneously and computing a
CacheKey for each original node to find the next edge in the shadow graph.
Two different graphs might share a common prefix in the shadow
graph, but then diverge at the first difference. Tensors, SavedVariables,
and SymInts found stored on the nodes in the autograd graph are lifted to
become inputs to the graph. All other properties (ints, floats, types,
etc.) are specialized using the CacheKey and will result in landing on
a different cache node in the shadow graph if some property differs.

To interact with the hundreds of different autograd::Node types,
we use a visitor pattern that walks each Node structure recursively.

- The first pass, compiled_args/collect, extracts all the inputs to the
graph and builds a CacheKey for us to specialize on.  On a cache hit,
we stop here and this is the only pass.

- On a cache miss, a second pass kicks in to extract the FX graph using
apply_with_saved, which uses another visitor pattern.  The before()
visitor swaps out all the Tensors, SavedVariables, and SymInts for
fake/symbolic versions to allow tracing.  We then run the standard apply()
method, and after() restores things to how we found them.

When we see tensor hooks, we record them directly in the output graph
without tracing into them.  We do this to avoid executing unsafe code
at trace time.

Notes:
  - We require hooks not to change the shapes of tensors.
  - We require non-hook autograd nodes to be traceable.
*/
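/*
Illustrative sketch (comment only, not compiled code): the cache walk
described above reduces to roughly the loop below, which mirrors the real
traversal in _compiled_autograd_impl later in this file. Worklist handling,
exec_info filtering, and verbose logging are omitted; the loop header is
pseudocode.

    CacheNode* cache = CacheNode::root();
    for (each autograd Node* fn, in dependency order) {
      CompiledNodeArgs node_args(compiler_call, call); // pass 1: collect
      fn->compiled_args(node_args);    // lift Tensors/SavedVariables/SymInts
      node_args.collect(call.node->next_edges());
      cache = cache->lookup(node_args.key()); // follow/extend the shadow graph
    }
    bool cache_hit = cache->check_dynamic_sizes(compiler_call, vlogger);
    // hit:  reuse cache->compiled_fn
    // miss: run pass 2 (apply_with_saved) to trace an FX graph and store it
*/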

namespace torch::dynamo::autograd {
using c10::SymInt;

static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
  PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(inputs.size()));
  for (const auto i : c10::irange(inputs.size())) {
    PyTuple_SET_ITEM(pyinput, i, PyLong_FromSsize_t(inputs[i]));
  }
  return pyinput;
}

static PyObject* convert_hook_list(std::vector<c10::SafePyObject>& inputs) {
  // inplace, consumes the input hooks
  PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(inputs.size()));
  for (const auto i : c10::irange(inputs.size())) {
    PyTuple_SET_ITEM(pyinput, i, inputs[i].release());
  }
  return pyinput;
}

static PyObject* check(PyObject* pyresult) {
  if (C10_UNLIKELY(pyresult == nullptr)) {
    // see https://github.com/pytorch/pytorch/pull/34845
    python_error err;
    err.persist();
    // NOLINTNEXTLINE(misc-throw-by-value-catch-by-reference)
    throw err;
  }
  return pyresult;
}

static void check(bool result) {
  if (C10_UNLIKELY(!result))
    check(nullptr);
}

// snapshot of python verbose logging toggle
static PyObject* python_verbose_logger = nullptr;
struct VerboseLogger {
  static std::optional<VerboseLogger> maybe_create() {
    if (python_verbose_logger == nullptr) {
      return std::nullopt;
    }
    return VerboseLogger();
  }

  void verbose_log_fn(std::string_view msg) const {
    TORCH_CHECK(python_verbose_logger != nullptr);
    check(PyObject_CallFunction(python_verbose_logger, "s", msg.data()));
  }

  void log_node_check(
      const Node& fn,
      size_t size_inputs_num,
      std::unordered_set<CacheKey> cached_keys,
      const CacheKey& key,
      size_t node_idx) {
    std::string node_name =
        fn.name() + " (NodeCall " + std::to_string(node_idx) + ")";

    cumulative_sizes_per_node[size_inputs_num] = node_name;

    if (!logged_node_miss && cached_keys.find(key) == cached_keys.end()) {
      _log_node_miss(typeid(fn), cached_keys, key, node_name);
      logged_node_miss = true;
    }
  }

  void _log_node_miss(
      const std::type_info& node_type,
      std::unordered_set<CacheKey> cached_keys,
      const CacheKey& key,
      const std::string& node_name) const {
    std::ostringstream oss;
    oss << "Cache miss due to new autograd node: " << node_name
        << " with key size " << std::to_string(key.key_size)
        << ", previous key sizes=[";

    for (auto it = cached_keys.begin(); it != cached_keys.end(); it++) {
      if (it->node_type != node_type) {
        continue;
      }
      oss << it->key_size;
      if (std::next(it) != cached_keys.end()) {
        oss << ",";
      }
    }
    oss << "]";
    verbose_log_fn(oss.str());
  }

  void log_dynamic_shapes_check(size_t size_idx) const {
    if (cumulative_sizes_per_node.empty()) {
      return;
    }

    auto it = cumulative_sizes_per_node.lower_bound(size_idx);
    TORCH_CHECK(it != cumulative_sizes_per_node.end());
    size_t start_idx =
        it == cumulative_sizes_per_node.begin() ? 0 : std::prev(it)->first;
    verbose_log_fn(
        "Cache miss due to changed shapes: marking size idx " +
        std::to_string(size_idx - start_idx) + " of " + it->second +
        " as dynamic");
  }

  // track which size index belongs to which node
  std::map<size_t, std::string> cumulative_sizes_per_node;
  // only log cache miss due to node key once
  bool logged_node_miss = false;
};

struct CacheNode {
  // A node in the shadow graph; we follow next edges until we reach the end
  // of the graph.
  static CacheNode* root() {
    static CacheNode _root;
    return &_root;
  }

  CacheNode* lookup(const CacheKey& key, bool create = true) {
    auto it = next.find(key);
    if (it == next.end()) {
      if (!create)
        return nullptr;
      // caller's key is in temporary memory, must copy it
      CacheKeyBuffer buffer(key.key, key.key_size);
      CacheKey key_with_storage(key.node_type, buffer.get(), key.key_size);
      it = next.emplace(key_with_storage, std::make_unique<CacheNode>()).first;
      key_storage.emplace_back(std::move(buffer));
    }
    return it->second.get();
  }

  void clear() {
    next.clear();
    key_storage.clear();
    expected_sizes.clear();
    runtime_wrapper = nullptr;
    compiled_fn = nullptr;
  }

  bool is_empty() const {
    return next.empty() && !compiled_fn;
  }

  CacheNode() : runtime_wrapper(nullptr), compiled_fn(nullptr) {}
  ~CacheNode() {
    if (!Py_IsInitialized()) {
      // leak on shutdown
      runtime_wrapper.release();
      compiled_fn.release();
    }
  }
  CacheNode(CacheNode&&) = delete;
  CacheNode(const CacheNode&) = delete;
  CacheNode& operator=(const CacheNode&) = delete;
  CacheNode& operator=(CacheNode&&) = delete;

  bool check_dynamic_sizes(
      AutogradCompilerCall& call,
      const std::optional<VerboseLogger>& vlogger) {
    /*
    We start off by assuming everything is static, then we mark things
    as dynamic when we see them change.  This function:
      1) Checks for a cache hit
      2) Updates expected_sizes to track what is dynamic
      3) Populates call.dyn_size_inputs by filtering call.all_size_inputs
    */
    bool cache_hit = compiled_fn.get() != nullptr;
    auto len = call.all_size_inputs.size();
    const SizeInput* data = call.all_size_inputs.data();
    if (expected_sizes.empty()) {
      expected_sizes.reserve(len);
      for (const auto i : c10::irange(len)) {
        expected_sizes.emplace_back(data[i]);
      }
    }

    TORCH_INTERNAL_ASSERT(expected_sizes.size() == call.all_size_inputs.size());
    for (const auto i : c10::irange(len)) {
      auto& expected = expected_sizes[i];
      bool was_dynamic = expected.dyn_type == SizeInput::DYNAMIC;
      bool changed_value = expected.value != data[i].value;
      if (changed_value) {
        if (!was_dynamic) {
          cache_hit = false;
          if (vlogger.has_value()) {
            vlogger->log_dynamic_shapes_check(i);
          }
        }
        expected = SizeInput(SizeInput::DYNAMIC, data[i].value);
      }

      if (changed_value || was_dynamic) {
        if (call.dyn_size_inputs.empty()) {
          call.dyn_size_inputs.reserve(len);
        }
        call.dyn_size_inputs.emplace_back(data[i].value);
      }
    }

    if (!cache_hit) {
      // we missed cache because static size inputs didn't match; force
      // recompilation with the varying size input as dynamic
      runtime_wrapper = nullptr;
      compiled_fn = nullptr;
    }
    return cache_hit;
  }

  PyObject* wrap_dynamic_inputs() const {
    size_t dynamic_count = 0;
    size_t idx = 0;
    for (const auto& i : expected_sizes) {
      if (i.dyn_type == SizeInput::DYNAMIC) {
        ++dynamic_count;
      }
    }
    PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(dynamic_count));
    for (const auto& i : expected_sizes) {
      if (i.dyn_type == SizeInput::DYNAMIC) {
        PyTuple_SET_ITEM(pyinput, idx++, PyLong_FromSsize_t(i.value));
      }
    }
    TORCH_INTERNAL_ASSERT(idx == dynamic_count);
    return pyinput;
  }

  std::vector<std::optional<SymInt>> unwrap_dynamic_inputs(
      PyObject* pyresult) const {
    TORCH_INTERNAL_ASSERT(PyList_CheckExact(pyresult));
    size_t idx = 0;
    size_t result_len = PyList_GET_SIZE(pyresult);
    std::vector<std::optional<SymInt>> result;
    result.reserve(expected_sizes.size());
    for (const auto& i : expected_sizes) {
      if (i.dyn_type == SizeInput::DYNAMIC) {
        TORCH_INTERNAL_ASSERT(idx < result_len);
        result.emplace_back(
            py::cast<c10::SymInt>(PyList_GET_ITEM(pyresult, idx++)));
      } else {
        result.emplace_back();
      }
    }
    TORCH_INTERNAL_ASSERT(
        idx == result_len && result.size() == expected_sizes.size());
    return result;
  }

  std::unordered_map<CacheKey, std::unique_ptr<CacheNode>> next;
  std::vector<CacheKeyBuffer> key_storage;
  std::vector<SizeInput> expected_sizes;

  THPObjectPtr runtime_wrapper;
  THPObjectPtr compiled_fn;
};
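/*
Illustrative example (assumed values, comment only): how check_dynamic_sizes
above promotes a size input from static to dynamic across backward calls.

    call 1: all_size_inputs = [2, 3]; both recorded as STATIC in
            expected_sizes and the graph is compiled specialized on 2 and 3.
    call 2: all_size_inputs = [4, 3]; index 0 changed, so expected_sizes[0]
            becomes DYNAMIC, runtime_wrapper/compiled_fn are cleared, and we
            recompile with that size lifted as a symbolic input (its value is
            also appended to call.dyn_size_inputs).
    call 3: all_size_inputs = [8, 3]; index 0 is already DYNAMIC, so this is
            a cache hit and only the dynamic value 8 is passed at runtime.
*/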

struct InputBuffers : public std::unordered_map<Node*, InputBuffer> {
  InputBuffer& lookup(Node* function) {
    auto it = emplace(function, InputBuffer(function->num_inputs())).first;
    return it->second;
  }
};

static PyObject* the_autograd_compiler = nullptr;
static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args);

static PyObject* clear_cache(PyObject* dummy, PyObject* args) {
  HANDLE_TH_ERRORS;
  CacheNode::root()->clear();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS;
}

static PyObject* is_cache_empty(PyObject* dummy, PyObject* args) {
  HANDLE_TH_ERRORS;
  if (CacheNode::root()->is_empty()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
  END_HANDLE_TH_ERRORS;
}

static PyObject* set_verbose_logger(PyObject* dummy, PyObject* args) {
  HANDLE_TH_ERRORS;
  PyObject* logger = nullptr;
  if (!PyArg_ParseTuple(args, "O", &logger)) {
    Py_RETURN_FALSE;
  }

  if (logger == Py_None) {
    python_verbose_logger = nullptr;
  } else {
    python_verbose_logger = logger;
  }
  Py_RETURN_TRUE;
  END_HANDLE_TH_ERRORS;
}

// NOLINTNEXTLINE(*array*)
static PyMethodDef _methods[] = {
    {"set_autograd_compiler", set_autograd_compiler, METH_VARARGS, nullptr},
    {"clear_cache", clear_cache, METH_NOARGS, nullptr},
    {"is_cache_empty", is_cache_empty, METH_NOARGS, nullptr},
    {"set_verbose_logger", set_verbose_logger, METH_VARARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};

static struct PyModuleDef _module = {
    PyModuleDef_HEAD_INIT,
    "torch._C._dynamo.autograd_compiler",
    "Hooks for compiling autograd",
    -1,
    _methods};

PyObject* wrap_lifted_ivalue_args(
    const std::vector<LiftedIValueArg>& lifted_ivalue_args) {
  PyObject* pyivalueargs =
      PyList_New(static_cast<Py_ssize_t>(lifted_ivalue_args.size()));
  size_t idx = 0;
  for (const auto& arg : lifted_ivalue_args) {
    if (arg.actual_ptr->isInt() || arg.actual_ptr->isSymInt()) {
      PyList_SET_ITEM(
          pyivalueargs, idx++, PyLong_FromSsize_t(arg.actual_ptr->toInt()));
    } else if (arg.actual_ptr->isDouble() || arg.actual_ptr->isSymFloat()) {
      PyList_SET_ITEM(
          pyivalueargs, idx++, PyFloat_FromDouble(arg.actual_ptr->toDouble()));
    } else {
      TORCH_INTERNAL_ASSERT(false, "Unexpected lifted ivalue type");
    }
  }
  return pyivalueargs;
}

void set_ivalue_proxies(
    PyObject* fake_ivalue_args,
    std::vector<LiftedIValueArg>& lifted_ivalue_args) {
  TORCH_INTERNAL_ASSERT(PyList_Check(fake_ivalue_args));
  TORCH_INTERNAL_ASSERT(
      static_cast<size_t>(PyList_Size(fake_ivalue_args)) ==
      lifted_ivalue_args.size());

  for (const auto& i : c10::irange(lifted_ivalue_args.size())) {
    auto& arg = lifted_ivalue_args[i];
    if (arg.actual_ptr->isInt() || arg.actual_ptr->isSymInt()) {
      arg.proxy = at::IValue(
          py::cast<c10::SymInt>(PyList_GET_ITEM(fake_ivalue_args, i)));
      TORCH_INTERNAL_ASSERT(arg.proxy.isSymInt());
    } else if (arg.actual_ptr->isDouble() || arg.actual_ptr->isSymFloat()) {
      arg.proxy = at::IValue(
          py::cast<c10::SymFloat>(PyList_GET_ITEM(fake_ivalue_args, i)));
    } else {
      TORCH_INTERNAL_ASSERT(false, "Unexpected lifted ivalue type");
    }
  }
}

static TraceState call_begin_capture(
    PyObject* self,
    CacheNode& cache,
    AutogradCompilerCall& compiler_call,
    size_t num_outputs) {
  static PyObject* method_name = PyUnicode_InternFromString("begin_capture");
  THPObjectPtr pyinput(THPVariable_WrapList(compiler_call.tensor_args.inputs));
  THPObjectPtr pysizeinput(cache.wrap_dynamic_inputs());
  THPObjectPtr pyivalueargsinput(
      wrap_lifted_ivalue_args(compiler_call.lifted_ivalue_args.args));
  THPObjectPtr pyresult(check(PyObject_CallMethodObjArgs(
      self,
      method_name,
      pyinput.get(),
      pysizeinput.get(),
      pyivalueargsinput.get(),
      nullptr)));

  PyObject *fake_inputs{nullptr}, *fake_sizes{nullptr},
      *fake_ivalue_args{nullptr};
  check(PyArg_ParseTuple(
      pyresult.get(), "OOO", &fake_inputs, &fake_sizes, &fake_ivalue_args));

  variable_list proxy_inputs = THPVariable_UnpackList(fake_inputs);
  TORCH_INTERNAL_ASSERT(
      proxy_inputs.size() == compiler_call.tensor_args.inputs.size());
  for (const auto i : c10::irange(proxy_inputs.size())) {
    TensorArg& arg =
        compiler_call.tensor_args.lookup(compiler_call.tensor_args.inputs[i]);
    arg.proxy_tensor = proxy_inputs[i];
  }

  set_ivalue_proxies(fake_ivalue_args, compiler_call.lifted_ivalue_args.args);
  return TraceState(cache.unwrap_dynamic_inputs(fake_sizes), num_outputs);
}

static PyObject* call_end_capture(PyObject* self, const variable_list& inputs) {
  static PyObject* method_name = PyUnicode_InternFromString("end_capture");
  THPObjectPtr pyinput(THPVariable_WrapList(inputs));
  return check(PyObject_CallMethodOneArg(self, method_name, pyinput.get()));
}
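/*
Sketch of the Python-side compiler object these helpers talk to. This is an
assumption inferred only from the calls made in this file (argument names
are made up), not a definitive interface: the object returned by calling
the_autograd_compiler is expected to provide roughly:

    begin_capture(inputs, sizes, ivalue_args)
        -> (fake_inputs, fake_sizes, fake_ivalue_args)
    tensor_pre_hook(inputs, hook_id, tensor_index) -> inputs
    pre_hook(inputs, hook_id) -> inputs
    post_hook(outputs, inputs, hook_id) -> outputs
    set_node_origin(node_name, node_index, maybe_pynode_obj)
    end_capture(outputs) -> (runtime_wrapper, compiled_fn)
    close()
*/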

struct ClosingTHPObjectPtr : public THPObjectPtr {
  ClosingTHPObjectPtr(PyObject* o) : THPObjectPtr(o) {}
  ~ClosingTHPObjectPtr() {
    if (PyErr_Occurred()) {
      // do nothing, do not attempt to close
      return;
    }
    static PyObject* method_name = PyUnicode_InternFromString("close");
    if (PyObject_CallMethodNoArgs(get(), method_name) == nullptr) {
      PyErr_WriteUnraisable(get());
      PyErr_Clear();
    }
  }
};

// Only call this function while holding the GIL.
CacheNode* _compiled_autograd_impl(
    const std::shared_ptr<Node>& graph_root,
    GraphTask& graph_task,
    bool accumulate_grad,
    const edge_list& output_edges,
    THPObjectPtr* graph_arg_inputs,
    THPObjectPtr* graph_arg_sizes,
    THPObjectPtr* graph_arg_ivalue_args,
    THPObjectPtr* graph_arg_hooks) {
  std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
  std::vector<std::shared_ptr<Node>> worklist{graph_root};
  AutogradCompilerCall compiler_call;

  for (const auto i : c10::irange(output_edges.size())) {
    compiler_call.node_calls
        .lookup(output_edges[i].function)
        // NOLINTNEXTLINE(*-narrowing-conversions)
        .mark_output(output_edges[i].input_nr, i);
  }
  const bool check_exec_info = !graph_task.exec_info_.empty();
  CacheNode* cache = CacheNode::root();
  std::vector<NodeCall*> calls;
  calls.reserve(
      check_exec_info ? graph_task.exec_info_.size() : dependencies.size() + 1);

  int i = 0;
  std::optional<VerboseLogger> vlogger = VerboseLogger::maybe_create();
  while (!worklist.empty()) {
    std::shared_ptr<Node> fn = std::move(worklist.back());
    worklist.pop_back();
    NodeCall& call = compiler_call.node_calls.lookup(fn);
    calls.emplace_back(&call);

    { // update cache and gather args into `compiler_call`
      CompiledNodeArgs node_args(compiler_call, call);
      node_args.collect(call);
      if (node_args.cond(call.needed)) {
        fn->compiled_args(node_args);
        node_args.collect(call.node->next_edges());
      }
      CacheKey key = node_args.key();
      if (vlogger.has_value()) {
        std::unordered_set<CacheKey> cached_keys;
        for (const auto& [k, _] : cache->next) {
          cached_keys.emplace(k);
        }
        vlogger->log_node_check(
            *fn,
            compiler_call.all_size_inputs.size(),
            std::move(cached_keys),
            key,
            i);
      }
      cache = cache->lookup(key);
    }

    for (const auto& edge : fn->next_edges()) {
      if (!edge.is_valid()) {
        continue;
      }
      if (check_exec_info) {
        auto it = graph_task.exec_info_.find(edge.function.get());
        if (it == graph_task.exec_info_.end() || !it->second.should_execute()) {
          continue;
        }
        if (!it->second.needed_) {
          compiler_call.node_calls.lookup(edge.function).needed = false;
        }
      }
      auto it = dependencies.find(edge.function.get());
      TORCH_INTERNAL_ASSERT(it != dependencies.end());
      if (--it->second == 0) {
        dependencies.erase(it);
        worklist.emplace_back(edge.function);
      }
    }
    i++;
  }

  // TODO(jansel): some dynamic sizes seem to be ints not symints
  if (!cache->check_dynamic_sizes(compiler_call, vlogger)) {
    // cache miss, need to capture FX graph
    ClosingTHPObjectPtr py_compiler(
        check(PyObject_CallNoArgs((the_autograd_compiler))));

    TraceState state = call_begin_capture(
        py_compiler, *cache, compiler_call, output_edges.size());
    InputBuffers input_buffers;

    for (size_t i = 0; i < calls.size(); i++) {
      NodeCall& call = *calls[i];
      // TODO(jansel): consider adding some of this stuff:
      // guard(local_graph_task); NodeGuard ndguard(task.fn_); const auto
      // opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
      // c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
      // CheckpointValidGuard cpvguard(graph_task);
      // at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION);
      // if (C10_UNLIKELY(step_callbacks.has_value())) { ... }

      variable_list inputs =
          std::move(input_buffers.lookup(call.node.get()).buffer);
      input_buffers.erase(call.node.get());

      if (!call.tensor_pre_hooks.empty()) {
        THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
        for (const auto& hook : call.tensor_pre_hooks) {
          pyinputs = check(PyObject_CallMethod(
              py_compiler,
              "tensor_pre_hook",
              "Oii",
              pyinputs.get(),
              hook.first,
              hook.second));
        }
        inputs = THPVariable_UnpackList(pyinputs);
      }
      for (const auto& graph_output : call.graph_output) {
        int input_nr = graph_output.first;
        int output_index = graph_output.second;
        TORCH_INTERNAL_ASSERT(
            output_index < static_cast<int>(state.outputs.size()));
        TORCH_INTERNAL_ASSERT(!state.outputs[output_index].defined());
        state.outputs[output_index] = inputs[input_nr];
      }
      if (!call.needed) {
        continue;
      }
      if (!call.pre_hooks.empty()) {
        THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
        for (const auto hook : call.pre_hooks) {
          pyinputs = check(PyObject_CallMethod(
              py_compiler.get(), "pre_hook", "Oi", pyinputs.get(), hook));
        }
        inputs = THPVariable_UnpackList(pyinputs);
      }

      std::string _node_name = call.node->name();
      THPObjectPtr node_name(PyUnicode_FromString(_node_name.data()));
      TORCH_INTERNAL_ASSERT(node_name != nullptr);
      THPObjectPtr set_node_origin(
          PyObject_GetAttrString(py_compiler.get(), "set_node_origin"));

      PyObject* pyobj = Py_None;
      if (auto pynode = std::dynamic_pointer_cast<PyNode>(call.node)) {
        pyobj = pynode->obj;
      }

      check(PyObject_CallFunction(
          set_node_origin, "OIO", node_name.get(), i, pyobj, nullptr));

      SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call);
      variable_list outputs = call.node->apply_with_saved(inputs, saved);

      saved.debug_asserts();
      saved.before(call.node->next_edges());
      validate_outputs(
          call.node->next_edges(), outputs, [&](const std::string& msg) {
            std::ostringstream ss;
            ss << "[Compiled Autograd Tracing: " << call.node->name() << "] "
               << msg;
            return ss.str();
          });
      saved.after(call.node->next_edges());
      saved.debug_asserts();

      if (!call.post_hooks.empty()) {
        THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
        THPObjectPtr pyoutputs(THPVariable_WrapList(outputs));
        for (const auto hook : call.post_hooks) {
          pyoutputs = check(PyObject_CallMethod(
              py_compiler.get(),
              "post_hook",
              "OOi",
              pyoutputs.get(),
              pyinputs.get(),
              hook));
        }
        outputs = THPVariable_UnpackList(pyoutputs);
      }
      for (const auto i : c10::irange(outputs.size())) {
        auto& output = outputs[i];
        const auto& next = call.node->next_edge(i);
        if (next.is_valid() && output.defined()) {
          input_buffers.lookup(next.function.get())
              .add(
                  next.input_nr, std::move(output), std::nullopt, std::nullopt);
        }
      }
    }

    PyObject* res = check(call_end_capture(py_compiler, state.outputs));
    TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple");
    TORCH_CHECK(
        PyTuple_Size(res) == 2,
        "Expected end_capture to return tuple of size 2");
    cache->runtime_wrapper = Py_NewRef(PyTuple_GetItem(res, 0));
    TORCH_CHECK(
        PyCallable_Check(cache->runtime_wrapper),
        "Expected end_capture to return runtime_wrapper");
    cache->compiled_fn = Py_NewRef(PyTuple_GetItem(res, 1));
    TORCH_CHECK(
        PyCallable_Check(cache->compiled_fn),
        "Expected end_capture to return compiled_fn");
    state.debug_asserts();
  } // End cache miss region

  // TODO(jansel): clear grads we will overwrite below
  if (!graph_task.keep_graph_) {
    for (auto& call : calls) {
      call->node->release_variables();
    }
  }

  *graph_arg_inputs = THPVariable_WrapList(compiler_call.tensor_args.inputs);
  *graph_arg_sizes = wrap_int_list(compiler_call.dyn_size_inputs);
  *graph_arg_ivalue_args =
      wrap_lifted_ivalue_args(compiler_call.lifted_ivalue_args.args);
  *graph_arg_hooks = convert_hook_list(compiler_call.hooks);
  return cache;
}

struct LockGuardWithErrorLogs {
  LockGuardWithErrorLogs(std::mutex& mtx) : mtx_(mtx) {
    // Note: the standard allows try_lock to fail spuriously during races for
    // performance reasons, but it shouldn't happen here since we:
    // 1. disable multithreaded autograd
    // 2. have plenty of latency between backward calls
    TORCH_INTERNAL_ASSERT(
        mtx_.try_lock(),
        "Trying to run compiled autograd within another compiled autograd call (e.g. reentrant checkpointing), this is not supported yet.");
  }

  ~LockGuardWithErrorLogs() {
    mtx_.unlock();
  }

  std::mutex& mtx_;
};

variable_list compiled_autograd(
    const std::shared_ptr<Node>& graph_root,
    GraphTask& graph_task,
    bool accumulate_grad,
    const edge_list& output_edges) {
  TORCH_CHECK(
      c10::impl::TorchDispatchModeTLS::stack_len() == 0,
      "TorchDispatchMode not yet implemented for compiled autograd")
  static std::mutex mtx;
  LockGuardWithErrorLogs lock_guard(mtx);
  pybind11::gil_scoped_acquire gil;
  at::ThreadLocalStateGuard tls_guard(graph_task.thread_locals_);

  THPObjectPtr inputs;
  THPObjectPtr sizes;
  THPObjectPtr ivalue_args;
  THPObjectPtr hooks;
  CacheNode* cache = _compiled_autograd_impl(
      graph_root,
      graph_task,
      accumulate_grad,
      output_edges,
      &inputs,
      &sizes,
      &ivalue_args,
      &hooks);

  THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs(
      cache->runtime_wrapper.get(),
      cache->compiled_fn.get(),
      inputs.get(),
      sizes.get(),
      ivalue_args.get(),
      hooks.get(),
      NULL)));
  variable_list outputs = THPVariable_UnpackList(pyresult);
  TORCH_INTERNAL_ASSERT(outputs.size() == output_edges.size());
  return outputs;
}

static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args) {
  HANDLE_TH_ERRORS;
  PyObject* obj = nullptr;
  if (!PyArg_ParseTuple(args, "O", &obj)) {
    return nullptr;
  }

  PyObject* prior = the_autograd_compiler;
  if (obj == Py_None) { // disable
    the_autograd_compiler = nullptr; // decref not needed due to `prior`
    Engine::set_compiled_autograd(nullptr);
  } else { // enable
    Py_INCREF(obj);
    the_autograd_compiler = obj;
    Engine::set_compiled_autograd(&compiled_autograd);
  }

  if (prior == nullptr) {
    Py_RETURN_NONE;
  } else {
    return prior;
  }
  END_HANDLE_TH_ERRORS;
}

PyObject* torch_c_dynamo_compiled_autograd_init() {
  return PyModule_Create(&_module);
}

} // namespace torch::dynamo::autograd