#include <torch/csrc/dynamo/python_compiled_autograd.h>

#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/dynamo/compiled_autograd.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pythoncapi_compat.h>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

/*
[Note: Compiled Autograd]

Compiled autograd replaces the standard autograd engine by converting
the autograd graph to an FX graph that can be torch.compiled. It caches
this conversion using a shadow graph. We compare the new graph to the
shadow graph by walking the two graphs simultaneously and computing a
CacheKey for each original node to find the next edge in the shadow graph.
Two different graphs might share a common prefix in the shadow graph,
but then diverge at the first difference. Tensors, SavedVariables, and
SymInts found stored on the nodes in the autograd graph are lifted to
become inputs to the graph. All other properties (ints, floats, types,
etc.) are specialized using the CacheKey, and any property that differs
will land on a different cache node in the shadow graph.

To interact with the hundreds of different autograd::Node types, we use
a visitor pattern that walks each Node structure recursively.

- The first pass, compiled_args/collect, extracts all the inputs to the
  graph and builds a CacheKey for us to specialize on. On a cache hit,
  we stop here and this is the only pass.

- On a cache miss, a second pass kicks in to extract the FX graph using
  apply_with_saved, which uses another visitor pattern. The before()
  visitor swaps out all the Tensors, SavedVariables, and SymInts for
  fake/symbolic versions to allow tracing. We then run the standard
  apply() method, and after() restores things to how we found them.

When we see tensor hooks, we record them directly in the output graph
without tracing into them. We do this to avoid executing unsafe code
at trace time.

Notes:
- We require hooks to not change the shapes of tensors.
- We require non-hook autograd nodes to be traceable.
*/
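
/*
How this file is driven from Python -- a rough sketch; the exact Python-side
wrapper around these bindings lives elsewhere in torch._dynamo, and the
snippet below is illustrative only, assuming a user-supplied
`make_compiler_context` factory:

  autograd_compiler = torch._C._dynamo.autograd_compiler
  prior = autograd_compiler.set_autograd_compiler(make_compiler_context)
  loss.backward()   # the engine dispatches to compiled_autograd() below
  autograd_compiler.set_autograd_compiler(prior)  # restore previous state
  autograd_compiler.clear_cache()                 # optionally drop the shadow graph

On every cache miss the registered callable is invoked with no arguments to
produce a capture context; _compiled_autograd_impl() then calls its
begin_capture/tensor_pre_hook/pre_hook/post_hook/set_node_origin/end_capture
methods while tracing, and close() when tracing finishes.
*/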

namespace torch::dynamo::autograd {
using c10::SymInt;

static PyObject* wrap_int_list(const std::vector<int64_t>& inputs) {
  PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(inputs.size()));
  for (const auto i : c10::irange(inputs.size())) {
    PyTuple_SET_ITEM(pyinput, i, PyLong_FromSsize_t(inputs[i]));
  }
  return pyinput;
}

static PyObject* convert_hook_list(std::vector<c10::SafePyObject>& inputs) {
  // in-place; consumes the input hooks
  PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(inputs.size()));
  for (const auto i : c10::irange(inputs.size())) {
    PyTuple_SET_ITEM(pyinput, i, inputs[i].release());
  }
  return pyinput;
}

static PyObject* check(PyObject* pyresult) {
  if (C10_UNLIKELY(pyresult == nullptr)) {
    // see https://github.com/pytorch/pytorch/pull/34845
    python_error err;
    err.persist();
    // NOLINTNEXTLINE(misc-throw-by-value-catch-by-reference)
    throw err;
  }
  return pyresult;
}

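// Overload for CPython calls that signal failure by returning false/0; it
// funnels into the PyObject* overload above so the pending Python error is
// rethrown as a C++ exception.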
static void check(bool result) {
  if (C10_UNLIKELY(!result))
    check(nullptr);
}

// snapshot of the Python verbose logging toggle
static PyObject* python_verbose_logger = nullptr;
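// Collects human-readable diagnostics about why the compiled autograd cache
// was missed and forwards them to the Python logger registered via
// set_verbose_logger(). Only instantiated when verbose logging is enabled.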
struct VerboseLogger {
  static std::optional<VerboseLogger> maybe_create() {
    if (python_verbose_logger == nullptr) {
      return std::nullopt;
    }
    return VerboseLogger();
  }

  void verbose_log_fn(std::string_view msg) const {
    TORCH_CHECK(python_verbose_logger != nullptr);
    check(PyObject_CallFunction(python_verbose_logger, "s", msg.data()));
  }

  void log_node_check(
      const Node& fn,
      size_t size_inputs_num,
      std::unordered_set<CacheKey> cached_keys,
      const CacheKey& key,
      size_t node_idx) {
    std::string node_name =
        fn.name() + " (NodeCall " + std::to_string(node_idx) + ")";

    cumulative_sizes_per_node[size_inputs_num] = node_name;

    if (!logged_node_miss && cached_keys.find(key) == cached_keys.end()) {
      _log_node_miss(typeid(fn), cached_keys, key, node_name);
      logged_node_miss = true;
    }
  }

  void _log_node_miss(
      const std::type_info& node_type,
      std::unordered_set<CacheKey> cached_keys,
      const CacheKey& key,
      const std::string& node_name) const {
    std::ostringstream oss;
    oss << "Cache miss due to new autograd node: " << node_name
        << " with key size " << std::to_string(key.key_size)
        << ", previous key sizes=[";

    for (auto it = cached_keys.begin(); it != cached_keys.end(); it++) {
      if (it->node_type != node_type) {
        continue;
      }
      oss << it->key_size;
      if (std::next(it) != cached_keys.end()) {
        oss << ",";
      }
    }
    oss << "]";
    verbose_log_fn(oss.str());
  }

  void log_dynamic_shapes_check(size_t size_idx) const {
    if (cumulative_sizes_per_node.empty()) {
      return;
    }

    auto it = cumulative_sizes_per_node.lower_bound(size_idx);
    TORCH_CHECK(it != cumulative_sizes_per_node.end());
    size_t start_idx =
        it == cumulative_sizes_per_node.begin() ? 0 : std::prev(it)->first;
    verbose_log_fn(
        "Cache miss due to changed shapes: marking size idx " +
        std::to_string(size_idx - start_idx) + " of " + it->second +
        " as dynamic");
  }

  // track which size index belongs to which node
  std::map<size_t, std::string> cumulative_sizes_per_node;
  // only log a cache miss due to a node key once
  bool logged_node_miss = false;
};

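/*
CacheNode is one level of the shadow graph described in
[Note: Compiled Autograd]. A hedged illustration (the node types and keys are
made up for the example): if backward #1 visits
GraphRoot -> MulBackward0 -> AccumulateGrad, the walk creates a chain of
three CacheNodes keyed by each node's CacheKey and stores compiled_fn on the
last one. If backward #2 produces the same CacheKeys for the first two nodes
but a different one for the last (e.g. a different scalar was specialized
into the key), the walk reuses the shared prefix and only creates a new leaf,
which is compiled separately.
*/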
struct CacheNode {
  // A node in the shadow graph; we follow next edges until we reach the end
  // of the graph
  static CacheNode* root() {
    static CacheNode _root;
    return &_root;
  }

  CacheNode* lookup(const CacheKey& key, bool create = true) {
    auto it = next.find(key);
    if (it == next.end()) {
      if (!create)
        return nullptr;
      // the caller's key is in temporary memory, so we must copy it
      CacheKeyBuffer buffer(key.key, key.key_size);
      CacheKey key_with_storage(key.node_type, buffer.get(), key.key_size);
      it = next.emplace(key_with_storage, std::make_unique<CacheNode>()).first;
      key_storage.emplace_back(std::move(buffer));
    }
    return it->second.get();
  }

  void clear() {
    next.clear();
    key_storage.clear();
    expected_sizes.clear();
    runtime_wrapper = nullptr;
    compiled_fn = nullptr;
  }

  bool is_empty() const {
    return next.empty() && !compiled_fn;
  }

  CacheNode() : runtime_wrapper(nullptr), compiled_fn(nullptr) {}
  ~CacheNode() {
    if (!Py_IsInitialized()) {
      // leak on shutdown; the interpreter is already finalized
      runtime_wrapper.release();
      compiled_fn.release();
    }
  }
  CacheNode(CacheNode&&) = delete;
  CacheNode(const CacheNode&) = delete;
  CacheNode& operator=(const CacheNode&) = delete;
  CacheNode& operator=(CacheNode&&) = delete;

  bool check_dynamic_sizes(
      AutogradCompilerCall& call,
      const std::optional<VerboseLogger>& vlogger) {
    /*
    We start off by assuming everything is static, then we mark things as
    dynamic when we see them change. This function:
      1) Checks for a cache hit
      2) Updates expected_sizes to track what is dynamic
      3) Populates call.dyn_size_inputs by filtering call.all_size_inputs
    */
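    // A hedged example of the promotion logic below (the concrete sizes are
    // made up): if the first backward records all_size_inputs = [2, 3], both
    // entries start out STATIC. A later backward with [2, 5] changes index 1,
    // so it is promoted to DYNAMIC and we report a cache miss to force
    // recompilation; from then on any value at index 1 keeps the cache hit
    // and is merely appended to call.dyn_size_inputs.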
    bool cache_hit = compiled_fn.get() != nullptr;
    auto len = call.all_size_inputs.size();
    const SizeInput* data = call.all_size_inputs.data();
    if (expected_sizes.empty()) {
      expected_sizes.reserve(len);
      for (const auto i : c10::irange(len)) {
        expected_sizes.emplace_back(data[i]);
      }
    }

    TORCH_INTERNAL_ASSERT(expected_sizes.size() == call.all_size_inputs.size());
    for (const auto i : c10::irange(len)) {
      auto& expected = expected_sizes[i];
      bool was_dynamic = expected.dyn_type == SizeInput::DYNAMIC;
      bool changed_value = expected.value != data[i].value;
      if (changed_value) {
        if (!was_dynamic) {
          cache_hit = false;
          if (vlogger.has_value()) {
            vlogger->log_dynamic_shapes_check(i);
          }
        }
        expected = SizeInput(SizeInput::DYNAMIC, data[i].value);
      }

      if (changed_value || was_dynamic) {
        if (call.dyn_size_inputs.empty()) {
          call.dyn_size_inputs.reserve(len);
        }
        call.dyn_size_inputs.emplace_back(data[i].value);
      }
    }

    if (!cache_hit) {
      // we missed the cache because static size inputs didn't match; force
      // recompilation with the varying size input marked as dynamic
      runtime_wrapper = nullptr;
      compiled_fn = nullptr;
    }
    return cache_hit;
  }

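  // Packs only the DYNAMIC entries of expected_sizes into a Python tuple so
  // they can be handed to the capture context's begin_capture() as size
  // inputs; STATIC entries are specialized on via expected_sizes instead.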
  PyObject* wrap_dynamic_inputs() const {
    size_t dynamic_count = 0;
    size_t idx = 0;
    for (const auto& i : expected_sizes) {
      if (i.dyn_type == SizeInput::DYNAMIC) {
        ++dynamic_count;
      }
    }
    PyObject* pyinput = PyTuple_New(static_cast<Py_ssize_t>(dynamic_count));
    for (const auto& i : expected_sizes) {
      if (i.dyn_type == SizeInput::DYNAMIC) {
        PyTuple_SET_ITEM(pyinput, idx++, PyLong_FromSsize_t(i.value));
      }
    }
    TORCH_INTERNAL_ASSERT(idx == dynamic_count);
    return pyinput;
  }

  std::vector<std::optional<SymInt>> unwrap_dynamic_inputs(
      PyObject* pyresult) const {
    TORCH_INTERNAL_ASSERT(PyList_CheckExact(pyresult));
    size_t idx = 0;
    size_t result_len = PyList_GET_SIZE(pyresult);
    std::vector<std::optional<SymInt>> result;
    result.reserve(expected_sizes.size());
    for (const auto& i : expected_sizes) {
      if (i.dyn_type == SizeInput::DYNAMIC) {
        TORCH_INTERNAL_ASSERT(idx < result_len);
        result.emplace_back(
            py::cast<c10::SymInt>(PyList_GET_ITEM(pyresult, idx++)));
      } else {
        result.emplace_back();
      }
    }
    TORCH_INTERNAL_ASSERT(
        idx == result_len && result.size() == expected_sizes.size());
    return result;
  }

  std::unordered_map<CacheKey, std::unique_ptr<CacheNode>> next;
  std::vector<CacheKeyBuffer> key_storage;
  std::vector<SizeInput> expected_sizes;

  THPObjectPtr runtime_wrapper;
  THPObjectPtr compiled_fn;
};

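// Accumulates the gradients flowing into each not-yet-traced node while we
// replay the graph, mirroring the per-Node InputBuffer bookkeeping the eager
// engine does.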
struct InputBuffers : public std::unordered_map<Node*, InputBuffer> {
  InputBuffer& lookup(Node* function) {
    auto it = emplace(function, InputBuffer(function->num_inputs())).first;
    return it->second;
  }
};

static PyObject* the_autograd_compiler = nullptr;
static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args);

static PyObject* clear_cache(PyObject* dummy, PyObject* args) {
  HANDLE_TH_ERRORS;
  CacheNode::root()->clear();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS;
}

static PyObject* is_cache_empty(PyObject* dummy, PyObject* args) {
  HANDLE_TH_ERRORS;
  if (CacheNode::root()->is_empty()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
  END_HANDLE_TH_ERRORS;
}

static PyObject* set_verbose_logger(PyObject* dummy, PyObject* args) {
  HANDLE_TH_ERRORS;
  PyObject* logger = nullptr;
  if (!PyArg_ParseTuple(args, "O", &logger)) {
    Py_RETURN_FALSE;
  }

  if (logger == Py_None) {
    python_verbose_logger = nullptr;
  } else {
    python_verbose_logger = logger;
  }
  Py_RETURN_TRUE;
  END_HANDLE_TH_ERRORS;
}

// NOLINTNEXTLINE(*array*)
static PyMethodDef _methods[] = {
    {"set_autograd_compiler", set_autograd_compiler, METH_VARARGS, nullptr},
    {"clear_cache", clear_cache, METH_NOARGS, nullptr},
    {"is_cache_empty", is_cache_empty, METH_NOARGS, nullptr},
    {"set_verbose_logger", set_verbose_logger, METH_VARARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};

static struct PyModuleDef _module = {
    PyModuleDef_HEAD_INIT,
    "torch._C._dynamo.autograd_compiler",
    "Hooks for compiling autograd",
    -1,
    _methods};

PyObject* wrap_lifted_ivalue_args(
    const std::vector<LiftedIValueArg>& lifted_ivalue_args) {
  PyObject* pyivalueargs =
      PyList_New(static_cast<Py_ssize_t>(lifted_ivalue_args.size()));
  size_t idx = 0;
  for (const auto& arg : lifted_ivalue_args) {
    if (arg.actual_ptr->isInt() || arg.actual_ptr->isSymInt()) {
      PyList_SET_ITEM(
          pyivalueargs, idx++, PyLong_FromSsize_t(arg.actual_ptr->toInt()));
    } else if (arg.actual_ptr->isDouble() || arg.actual_ptr->isSymFloat()) {
      PyList_SET_ITEM(
          pyivalueargs, idx++, PyFloat_FromDouble(arg.actual_ptr->toDouble()));
    } else {
      TORCH_INTERNAL_ASSERT(false, "Unexpected lifted ivalue type");
    }
  }
  return pyivalueargs;
}

void set_ivalue_proxies(
    PyObject* fake_ivalue_args,
    std::vector<LiftedIValueArg>& lifted_ivalue_args) {
  TORCH_INTERNAL_ASSERT(PyList_Check(fake_ivalue_args));
  TORCH_INTERNAL_ASSERT(
      static_cast<size_t>(PyList_Size(fake_ivalue_args)) ==
      lifted_ivalue_args.size());

  for (const auto& i : c10::irange(lifted_ivalue_args.size())) {
    auto& arg = lifted_ivalue_args[i];
    if (arg.actual_ptr->isInt() || arg.actual_ptr->isSymInt()) {
      arg.proxy = at::IValue(
          py::cast<c10::SymInt>(PyList_GET_ITEM(fake_ivalue_args, i)));
      TORCH_INTERNAL_ASSERT(arg.proxy.isSymInt());
    } else if (arg.actual_ptr->isDouble() || arg.actual_ptr->isSymFloat()) {
      arg.proxy = at::IValue(
          py::cast<c10::SymFloat>(PyList_GET_ITEM(fake_ivalue_args, i)));
    } else {
      TORCH_INTERNAL_ASSERT(false, "Unexpected lifted ivalue type");
    }
  }
}

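// Calls begin_capture() on the Python capture context with the lifted
// tensors, the dynamic sizes, and the lifted ivalue args, then installs the
// returned fake tensors and ivalue proxies onto compiler_call and returns a
// TraceState carrying the symbolic sizes, so that tracing through
// apply_with_saved sees symbolic values.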
static TraceState call_begin_capture(
    PyObject* self,
    CacheNode& cache,
    AutogradCompilerCall& compiler_call,
    size_t num_outputs) {
  static PyObject* method_name = PyUnicode_InternFromString("begin_capture");
  THPObjectPtr pyinput(THPVariable_WrapList(compiler_call.tensor_args.inputs));
  THPObjectPtr pysizeinput(cache.wrap_dynamic_inputs());
  THPObjectPtr pyivalueargsinput(
      wrap_lifted_ivalue_args(compiler_call.lifted_ivalue_args.args));
  THPObjectPtr pyresult(check(PyObject_CallMethodObjArgs(
      self,
      method_name,
      pyinput.get(),
      pysizeinput.get(),
      pyivalueargsinput.get(),
      nullptr)));

  PyObject *fake_inputs{nullptr}, *fake_sizes{nullptr},
      *fake_ivalue_args{nullptr};
  check(PyArg_ParseTuple(
      pyresult.get(), "OOO", &fake_inputs, &fake_sizes, &fake_ivalue_args));

  variable_list proxy_inputs = THPVariable_UnpackList(fake_inputs);
  TORCH_INTERNAL_ASSERT(
      proxy_inputs.size() == compiler_call.tensor_args.inputs.size());
  for (const auto i : c10::irange(proxy_inputs.size())) {
    TensorArg& arg =
        compiler_call.tensor_args.lookup(compiler_call.tensor_args.inputs[i]);
    arg.proxy_tensor = proxy_inputs[i];
  }

  set_ivalue_proxies(fake_ivalue_args, compiler_call.lifted_ivalue_args.args);
  return TraceState(cache.unwrap_dynamic_inputs(fake_sizes), num_outputs);
}

static PyObject* call_end_capture(PyObject* self, const variable_list& inputs) {
  static PyObject* method_name = PyUnicode_InternFromString("end_capture");
  THPObjectPtr pyinput(THPVariable_WrapList(inputs));
  return check(PyObject_CallMethodOneArg(self, method_name, pyinput.get()));
}

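// RAII owner for the Python capture context: when tracing finishes it calls
// the context's close() method, unless a Python error is already in flight.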
struct ClosingTHPObjectPtr : public THPObjectPtr {
  ClosingTHPObjectPtr(PyObject* o) : THPObjectPtr(o) {}
  ~ClosingTHPObjectPtr() {
    if (PyErr_Occurred()) {
      // do nothing, do not attempt to close
      return;
    }
    static PyObject* method_name = PyUnicode_InternFromString("close");
    if (PyObject_CallMethodNoArgs(get(), method_name) == nullptr) {
      PyErr_WriteUnraisable(get());
      PyErr_Clear();
    }
  }
};

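// High-level flow of _compiled_autograd_impl (see [Note: Compiled Autograd]):
//   1) Walk the autograd graph from graph_root, collecting inputs and a
//      CacheKey per node while descending through the shadow graph.
//   2) On a cache miss (including changed static sizes), create a capture
//      context via the_autograd_compiler and trace every needed node with
//      apply_with_saved to build the FX graph, then store the returned
//      runtime_wrapper/compiled_fn on the reached CacheNode.
//   3) Package the lifted tensors, dynamic sizes, ivalue args, and hooks into
//      the output THPObjectPtr arguments so the caller can invoke the
//      compiled function.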
// Only call this function while holding GIL
CacheNode* _compiled_autograd_impl(
    const std::shared_ptr<Node>& graph_root,
    GraphTask& graph_task,
    bool accumulate_grad,
    const edge_list& output_edges,
    THPObjectPtr* graph_arg_inputs,
    THPObjectPtr* graph_arg_sizes,
    THPObjectPtr* graph_arg_ivalue_args,
    THPObjectPtr* graph_arg_hooks) {
  std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
  std::vector<std::shared_ptr<Node>> worklist{graph_root};
  AutogradCompilerCall compiler_call;

  for (const auto i : c10::irange(output_edges.size())) {
    compiler_call.node_calls
        .lookup(output_edges[i].function)
        // NOLINTNEXTLINE(*-narrowing-conversions)
        .mark_output(output_edges[i].input_nr, i);
  }
  const bool check_exec_info = !graph_task.exec_info_.empty();
  CacheNode* cache = CacheNode::root();
  std::vector<NodeCall*> calls;
  calls.reserve(
      check_exec_info ? graph_task.exec_info_.size() : dependencies.size() + 1);

  int i = 0;
  std::optional<VerboseLogger> vlogger = VerboseLogger::maybe_create();
  while (!worklist.empty()) {
    std::shared_ptr<Node> fn = std::move(worklist.back());
    worklist.pop_back();
    NodeCall& call = compiler_call.node_calls.lookup(fn);
    calls.emplace_back(&call);

    { // update cache and gather args into `compiler_call`
      CompiledNodeArgs node_args(compiler_call, call);
      node_args.collect(call);
      if (node_args.cond(call.needed)) {
        fn->compiled_args(node_args);
        node_args.collect(call.node->next_edges());
      }
      CacheKey key = node_args.key();
      if (vlogger.has_value()) {
        std::unordered_set<CacheKey> cached_keys;
        for (const auto& [k, _] : cache->next) {
          cached_keys.emplace(k);
        }
        vlogger->log_node_check(
            *fn,
            compiler_call.all_size_inputs.size(),
            std::move(cached_keys),
            key,
            i);
      }
      cache = cache->lookup(key);
    }

    for (const auto& edge : fn->next_edges()) {
      if (!edge.is_valid()) {
        continue;
      }
      if (check_exec_info) {
        auto it = graph_task.exec_info_.find(edge.function.get());
        if (it == graph_task.exec_info_.end() || !it->second.should_execute()) {
          continue;
        }
        if (!it->second.needed_) {
          compiler_call.node_calls.lookup(edge.function).needed = false;
        }
      }
      auto it = dependencies.find(edge.function.get());
      TORCH_INTERNAL_ASSERT(it != dependencies.end());
      if (--it->second == 0) {
        dependencies.erase(it);
        worklist.emplace_back(edge.function);
      }
    }
    i++;
  }

  // TODO(jansel): some dynamic sizes seem to be ints not symints
  if (!cache->check_dynamic_sizes(compiler_call, vlogger)) {
    // cache miss, need to capture FX graph
    ClosingTHPObjectPtr py_compiler(
        check(PyObject_CallNoArgs((the_autograd_compiler))));

    TraceState state = call_begin_capture(
        py_compiler, *cache, compiler_call, output_edges.size());
    InputBuffers input_buffers;

    for (size_t i = 0; i < calls.size(); i++) {
      NodeCall& call = *calls[i];
      // TODO(jansel): consider adding some of this stuff:
      // guard(local_graph_task); NodeGuard ndguard(task.fn_); const auto
      // opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
      // c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
      // CheckpointValidGuard cpvguard(graph_task);
      // at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION);
      // if (C10_UNLIKELY(step_callbacks.has_value())) { ... }

      variable_list inputs =
          std::move(input_buffers.lookup(call.node.get()).buffer);
      input_buffers.erase(call.node.get());

      if (!call.tensor_pre_hooks.empty()) {
        THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
        for (const auto& hook : call.tensor_pre_hooks) {
          pyinputs = check(PyObject_CallMethod(
              py_compiler,
              "tensor_pre_hook",
              "Oii",
              pyinputs.get(),
              hook.first,
              hook.second));
        }
        inputs = THPVariable_UnpackList(pyinputs);
      }
      for (const auto& graph_output : call.graph_output) {
        int input_nr = graph_output.first;
        int output_index = graph_output.second;
        TORCH_INTERNAL_ASSERT(
            output_index < static_cast<int>(state.outputs.size()));
        TORCH_INTERNAL_ASSERT(!state.outputs[output_index].defined());
        state.outputs[output_index] = inputs[input_nr];
      }
      if (!call.needed) {
        continue;
      }
      if (!call.pre_hooks.empty()) {
        THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
        for (const auto hook : call.pre_hooks) {
          pyinputs = check(PyObject_CallMethod(
              py_compiler.get(), "pre_hook", "Oi", pyinputs.get(), hook));
        }
        inputs = THPVariable_UnpackList(pyinputs);
      }

      std::string _node_name = call.node->name();
      THPObjectPtr node_name(PyUnicode_FromString(_node_name.data()));
      TORCH_INTERNAL_ASSERT(node_name != nullptr);
      THPObjectPtr set_node_origin(
          PyObject_GetAttrString(py_compiler.get(), "set_node_origin"));

      PyObject* pyobj = Py_None;
      if (auto pynode = std::dynamic_pointer_cast<PyNode>(call.node)) {
        pyobj = pynode->obj;
      }

      check(PyObject_CallFunction(
          set_node_origin, "OIO", node_name.get(), i, pyobj, nullptr));

      SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call);
      variable_list outputs = call.node->apply_with_saved(inputs, saved);

      saved.debug_asserts();
      saved.before(call.node->next_edges());
      validate_outputs(
          call.node->next_edges(), outputs, [&](const std::string& msg) {
            std::ostringstream ss;
            ss << "[Compiled Autograd Tracing: " << call.node->name() << "] "
               << msg;
            return ss.str();
          });
      saved.after(call.node->next_edges());
      saved.debug_asserts();

      if (!call.post_hooks.empty()) {
        THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
        THPObjectPtr pyoutputs(THPVariable_WrapList(outputs));
        for (const auto hook : call.post_hooks) {
          pyoutputs = check(PyObject_CallMethod(
              py_compiler.get(),
              "post_hook",
              "OOi",
              pyoutputs.get(),
              pyinputs.get(),
              hook));
        }
        outputs = THPVariable_UnpackList(pyoutputs);
      }
      for (const auto i : c10::irange(outputs.size())) {
        auto& output = outputs[i];
        const auto& next = call.node->next_edge(i);
        if (next.is_valid() && output.defined()) {
          input_buffers.lookup(next.function.get())
              .add(
                  next.input_nr, std::move(output), std::nullopt, std::nullopt);
        }
      }
    }

    PyObject* res = check(call_end_capture(py_compiler, state.outputs));
    TORCH_CHECK(PyTuple_Check(res), "Expected end_capture to return tuple");
    TORCH_CHECK(
        PyTuple_Size(res) == 2,
        "Expected end_capture to return tuple of size 2");
    cache->runtime_wrapper = Py_NewRef(PyTuple_GetItem(res, 0));
    TORCH_CHECK(
        PyCallable_Check(cache->runtime_wrapper),
        "Expected end_capture to return runtime_wrapper");
    cache->compiled_fn = Py_NewRef(PyTuple_GetItem(res, 1));
    TORCH_CHECK(
        PyCallable_Check(cache->compiled_fn),
        "Expected end_capture to return compiled_fn");
    state.debug_asserts();
  } // End cache miss region

  // TODO(jansel): clear grads we will overwrite below
  if (!graph_task.keep_graph_) {
    for (auto& call : calls) {
      call->node->release_variables();
    }
  }

  *graph_arg_inputs = THPVariable_WrapList(compiler_call.tensor_args.inputs);
  *graph_arg_sizes = wrap_int_list(compiler_call.dyn_size_inputs);
  *graph_arg_ivalue_args =
      wrap_lifted_ivalue_args(compiler_call.lifted_ivalue_args.args);
  *graph_arg_hooks = convert_hook_list(compiler_call.hooks);
  return cache;
}

struct LockGuardWithErrorLogs {
  LockGuardWithErrorLogs(std::mutex& mtx) : mtx_(mtx) {
    // Note: the standard allows try_lock to fail spuriously during races for
    // performance reasons, but it shouldn't happen here since we:
    // 1. disable multithreaded autograd
    // 2. have plenty of latency between backward calls
    TORCH_INTERNAL_ASSERT(
        mtx_.try_lock(),
        "Trying to run compiled autograd within another compiled autograd call (e.g. reentrant checkpointing), this is not supported yet.");
  }

  ~LockGuardWithErrorLogs() {
    mtx_.unlock();
  }

  std::mutex& mtx_;
};

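// Entry point registered with the engine via Engine::set_compiled_autograd().
// It traces (or reuses) the compiled graph under the GIL and then invokes
// runtime_wrapper(compiled_fn, inputs, sizes, ivalue_args, hooks) to produce
// the gradients for output_edges.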
variable_list compiled_autograd(
    const std::shared_ptr<Node>& graph_root,
    GraphTask& graph_task,
    bool accumulate_grad,
    const edge_list& output_edges) {
  TORCH_CHECK(
      c10::impl::TorchDispatchModeTLS::stack_len() == 0,
      "TorchDispatchMode not yet implemented for compiled autograd")
  static std::mutex mtx;
  LockGuardWithErrorLogs lock_guard(mtx);
  pybind11::gil_scoped_acquire gil;
  at::ThreadLocalStateGuard tls_guard(graph_task.thread_locals_);

  THPObjectPtr inputs;
  THPObjectPtr sizes;
  THPObjectPtr ivalue_args;
  THPObjectPtr hooks;
  CacheNode* cache = _compiled_autograd_impl(
      graph_root,
      graph_task,
      accumulate_grad,
      output_edges,
      &inputs,
      &sizes,
      &ivalue_args,
      &hooks);

  THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs(
      cache->runtime_wrapper.get(),
      cache->compiled_fn.get(),
      inputs.get(),
      sizes.get(),
      ivalue_args.get(),
      hooks.get(),
      nullptr)));
  variable_list outputs = THPVariable_UnpackList(pyresult);
  TORCH_INTERNAL_ASSERT(outputs.size() == output_edges.size());
  return outputs;
}

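// Python binding: installs (or, when passed None, removes) the compiler
// factory used on cache misses, wires compiled_autograd into the engine, and
// returns the previously registered factory so callers can restore it.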
static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args) {
  HANDLE_TH_ERRORS;
  PyObject* obj = nullptr;
  if (!PyArg_ParseTuple(args, "O", &obj)) {
    return nullptr;
  }

  PyObject* prior = the_autograd_compiler;
  if (obj == Py_None) { // disable
    the_autograd_compiler = nullptr; // decref not needed due to `prior`
    Engine::set_compiled_autograd(nullptr);
  } else { // enable
    Py_INCREF(obj);
    the_autograd_compiler = obj;
    Engine::set_compiled_autograd(&compiled_autograd);
  }

  if (prior == nullptr) {
    Py_RETURN_NONE;
  } else {
    return prior;
  }
  END_HANDLE_TH_ERRORS;
}

PyObject* torch_c_dynamo_compiled_autograd_init() {
  return PyModule_Create(&_module);
}

} // namespace torch::dynamo::autograd