#include <torch/csrc/jit/runtime/autodiff.h>

#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/update_differentiable_graph_requires_grad.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/symbolic_script.h>
#include <algorithm>
#include <memory>

namespace torch::jit {

using value_map = std::unordered_map<Value*, Value*>;
using value_set = std::unordered_set<Value*>;

// need_trim_grad_ops contains functions that return multiple outputs in
// forward, but only the first output requires grad.
// Example:
// kthvalue returns (kthvalue, index of kthvalue). Currently autodiff only
// supports at most one output that requires grad, so we need to drop the
// grad for the index, which doesn't require grad.
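// A schematic example (illustrative IR, not exact): for
//   %values, %indices = aten::topk(%x, %k, %dim, %largest, %sorted)
// only the vjp flowing into %values is kept; the grad slot for %indices is
// dropped before the backward graph for the node is built (see
// build_script_grad below).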
static bool needTrimGrad(Node* n) {
  static OperatorSet need_trim_grad_ops = {
      "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
      "aten::topk(Tensor self, int k, int dim, bool largest, bool sorted) -> (Tensor, Tensor)",
      "aten::max_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
      "aten::max_pool2d_with_indices(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> (Tensor, Tensor)"};
  return n->isMemberOf(need_trim_grad_ops);
}

bool isDifferentiable(const Node* n) {
  // TODO: scalar-tensor ops should be canonicalized
  static OperatorSet differentiable_ops = {
      "aten::_slow_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> Tensor",
      "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)",
  };

  // TODO: add support for the following fusible operators.
  // They're a little tricky to implement; max/min require mutability for best
  // performance:
  //   "aten::atan2(Tensor self) -> Tensor",
  //   "aten::max(Tensor self) -> Tensor",
  //   "aten::min(Tensor self) -> Tensor"

  if (n->kind() == prim::Constant || n->kind() == prim::AutogradZero ||
      n->kind() == prim::AutogradAdd || n->kind() == prim::ConstantChunk ||
      n->kind() == prim::profile || n->kind() == prim::profile_ivalue)
    return true;

  if (n->isMemberOf(differentiable_ops))
    return true;

  if (n->matches(
          "aten::dropout(Tensor input, float p, bool train) -> Tensor",
          attr::train)) {
    return n->get<bool>(attr::train).value();
  }

  if (n->matches(
          "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor")) {
    return n->get<c10::List<int64_t>>(attr::size) &&
        n->is_constant(attr::implicit);
  }

  auto schema = n->maybeSchema();
  if (schema && hasGradientInfoForSchema(*schema)) {
    return true;
  }

  // linear blocks may appear as inputs to graph executors, but they are removed
  // before differentiation occurs
  if (n->kind() == prim::GradOf) {
    auto body = n->blocks().at(0);
    return std::all_of(
        body->nodes().begin(),
        body->nodes().end(),
        static_cast<bool (*)(const Node*)>(isDifferentiable));
  }

  // Gradient formulas are only defined for floating point scalars,
  // so we fall back to autograd for other cases.
  for (const Value* input : n->inputs()) {
    if (input->type() == NumberType::get()) {
      return false;
    }
  }

  return false;
}

bool isDifferentiable(Graph& g) {
  return std::all_of(
      g.nodes().begin(),
      g.nodes().end(),
      static_cast<bool (*)(const Node*)>(isDifferentiable));
}

// NB: Write gradients using TorchScript
// For example, the node aten::mul() should be defined as follows
// def forward(x, y):
//     return x*y, (x, y)
// def backward(ctx, grad_output):
//     x, y = ctx
//     return (y * grad_output).sum_to_size(x), (x * grad_output).sum_to_size(y)
//
// Here ctx is a tuple that carries all input/intermediate results needed in
// backward from the forward pass.
//
// This Python code is compiled into a GradientPair, which includes a forward
// graph and a backward graph. The forward graph is used to replace the node in
// grad_desc.f, and the backward graph is used to construct GradOf(node) in
// reverse_block. Grad values (a.k.a. gradOutputs) are propagated through
// node->owningGraph() in **reversed** order, thus GradientPair.forward should
// be inserted **after** the node being replaced, so that we don't traverse the
// graph an infinite number of times.
//
// The output of the compiled forward graph is [real_outputs, ctx].
// The input of the compiled backward graph is [ctx, grad_values].
// We run LowerSimpleTuples afterwards to eliminate all tuples generated in
// this process. The original node and TupleConstruct nodes in the forward graph
// will be cleaned up later using EliminateDeadCode(block). TupleUnpack nodes in
// the backward graph will be removed in eliminateDeadCode(ReverseDetails),
// defined in this file.
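// Schematically (illustrative IR, not exact), for %z = aten::mul(%x, %y) the
// replacement produces something like:
//   forward:  %z, %ctx = <inlined GradientPair.forward>(%x, %y)
//   backward: %dx, %dy = <inlined GradientPair.backward>(%ctx, %dz)
// where %ctx packs (x, y) and %dz is the incoming grad for %z.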
static std::optional<std::vector<Value*>> build_script_grad(
    Node* node,
    const ArrayRef<Value*>& grads) {
  auto graph = node->owningGraph();
  auto maybe_schema = node->maybeSchema();
  if (!maybe_schema) {
    return std::nullopt;
  }
  auto compiled_graphs = gradientInfoForSchema(*maybe_schema);
  if (!compiled_graphs) {
    return std::nullopt;
  }
  // Use the forward graph to replace the node in grad_desc.f
  value_list new_outputs;
  {
    WithInsertPoint guard(node->next());
    auto fw_graph = compiled_graphs->forward;
    new_outputs = insertGraph(*graph, *fw_graph, node->inputs());
    new_outputs = unpackOutputs(new_outputs);
    auto outputs = node->outputs();
    AT_ASSERT(new_outputs.size() == outputs.size() + 1);
    for (const auto i : c10::irange(outputs.size())) {
      new_outputs.at(i)->setType(outputs[i]->type());
      outputs[i]->replaceAllUsesWith(new_outputs.at(i));
    }
  }

  // Use the backward graph to construct the reverse_block
  auto bw_graph = compiled_graphs->backward;
  auto grad_vec = grads.vec();
  if (needTrimGrad(node)) {
    grad_vec.erase(grad_vec.begin() + 1, grad_vec.end());
  }
  auto it = grad_vec.begin();
  grad_vec.insert(it, new_outputs.back());
  ArrayRef<Value*> grad(grad_vec);
  auto grad_inputs = insertGraph(*graph, *bw_graph, grad);
  grad_inputs = unpackOutputs(grad_inputs);
  return grad_inputs;
}

namespace {
class GradientHelper {
 public:
  GradientHelper(Node* n) : node(n) {}

  std::vector<Value*> gradient(ArrayRef<Value*> grad_values) {
    if (!isDifferentiable(node)) {
      throw std::runtime_error(
          std::string("differentiation of ") + node->kind().toDisplayString() +
          " is not supported, or it is missing necessary type information");
    }
    // If AD is defined using TorchScript, use it instead of the symbolic
    // gradient below.
    auto script_grads = build_script_grad(node, grad_values);
    if (script_grads)
      return *script_grads;

    // No definition found in TorchScript; fall back to buildSymbolicGradient.
    // TODO: migrate everything to TorchScript.
    return buildSymbolicGradient(grad_values);
  }

 private:
  Node* node;

  std::vector<Value*> buildSymbolicGradient(
      const ArrayRef<Value*>& grad_values) {
    auto inputs = node->inputs();
    auto outputs = node->outputs();

    if (node->kind() == prim::AutogradAdd) {
      // NB: AutogradAdds don't broadcast
      return {grad_values.at(0), grad_values.at(0)};
    } else if (node->kind() == prim::profile) {
      return {grad_values.at(0)};
    } else if (node->kind() == prim::ConstantChunk) {
      auto* g = node->owningGraph();

      Value* input_list = nullptr;
      if (grad_values.size() == 1 &&
          grad_values[0]->type()->isSubtypeOf(*ListType::ofTensors())) {
        input_list = grad_values[0];
      } else {
        input_list =
            g->insertNode(g->createList(TensorType::get(), grad_values))
                ->output();
      }

      auto* cDim = g->insertConstant(node->i(attr::dim));
      auto* cat_node = g->insertNode(g->create(aten::cat, 1));
      cat_node->addInput(input_list);
      cat_node->addInput(cDim);
      return {cat_node->output()};
    } else if (
        node->kind() == prim::Constant || node->kind() == prim::AutogradZero) {
      return {};
    } else if (
        node->matches(
            "aten::_slow_conv2d_forward(Tensor self, Tensor weight, int[] kernel_size, Tensor? bias, int[] stride, int[] padding) -> Tensor")) {
      auto graph = node->owningGraph();
      auto backward_value = graph->insert(
          aten::_slow_conv2d_backward,
          {grad_values.at(0),
           inputs.at(0),
           inputs.at(1),
           node->namedInput(attr::kernel_size),
           node->namedInput(attr::stride),
           node->namedInput(attr::padding),
           graph->insertConstant(c10::List<bool>({true, true, true}))});
      // graph->insert returns a tuple automatically if multiple outputs are
      // returned. So unpack them again.
      Node* tuple_unpack_node =
          graph->insertNode(graph->createTupleUnpack(backward_value));
      auto tuple_outputs = tuple_unpack_node->outputs();
      AT_ASSERT(tuple_outputs.size() == size_t(3));
      return {
          tuple_outputs[0],
          tuple_outputs[1],
          nullptr,
          tuple_outputs[2],
          nullptr,
          nullptr};

    } else if (
        node->matches(
            "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)")) {
      auto graph = node->owningGraph();
      auto backward_value = graph->insert(
          aten::native_batch_norm_backward,
          {grad_values.at(0),
           inputs.at(0),
           inputs.at(1),
           inputs.at(3),
           inputs.at(4),
           outputs.at(1),
           outputs.at(2),
           inputs.at(5),
           inputs.at(7),
           graph->insertConstant(c10::List<bool>({true, true, true}))});
      // graph->insert returns a tuple automatically if multiple outputs are
      // returned. So unpack them again.
      Node* tuple_unpack_node =
          graph->insertNode(graph->createTupleUnpack(backward_value));
      auto tuple_outputs = tuple_unpack_node->outputs();
      AT_ASSERT(tuple_outputs.size() == size_t(3));
      return {
          tuple_outputs[0],
          tuple_outputs[1],
          tuple_outputs[2],
          nullptr,
          nullptr,
          nullptr,
          nullptr,
          nullptr};
    }

    throw std::runtime_error(
        std::string("failed to differentiate `") +
        node->kind().toDisplayString() + "`");
  }
};
} // namespace

// If we have a function y = f(x) with Jacobian J, the backward of f is dx =
// J^t dy. Note that because the backward always implements this matrix
// multiply, we know that it maps an input vector of zeros to an output vector
// of zeros, regardless of what operations it chooses to do inside to actually
// implement the matrix multiply (most use some optimized form and never
// generate J^t explicitly). More generally, we know that all of the backward
// computations are linear, and we can use this property to do more aggressive
// optimizations later. It is OK to replace any backward function with
// known-zero inputs with something that produces known-zero outputs. This
// function encloses each known-linear backward function in a 'GradOf'
// sub-block so that we can perform optimizations using this information. In
// particular, specializeAutogradZero will observe whether all the inputs to the
// linear block are AutogradZeroTensor, which the autograd uses to represent
// zeros, and then propagate the zeros to the outputs of the block.
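// A schematic example (illustrative IR, not exact): the backward of an
// aten::mul node may end up wrapped as
//   %dx, %dy = prim::GradOf[name="aten::mul"](%dz)
//     block0():
//       %dx0 = aten::mul(%dz, %y)
//       %dy0 = aten::mul(%dz, %x)
//       -> (%dx0, %dy0)
// If %dz is later known to be an AutogradZeroTensor, the whole block can be
// replaced by zeros without running its body.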
static std::vector<Value*> linearGradientForNode(
    Node* node,
    ArrayRef<Value*> grad_values) {
  auto& graph = *node->owningGraph();

  // FIXME: if the forward op has multiple outputs, we only support the first
  // one requiring grad
  if (needTrimGrad(node)) {
    grad_values = grad_values.at(0);
  }
  auto linear = graph.insertNode(graph.create(prim::GradOf, {grad_values}, 0));
  // to make reading gradient graphs easier, remember the name of the forward op
  linear->s_(attr::name, node->kind().toDisplayString());
  auto block = linear->addBlock();
  WithInsertPoint guard(block);
  auto results = GradientHelper(node).gradient(grad_values);
  return fmap(results, [block, linear](Value* grad) -> Value* {
    if (!grad || grad->mustBeNone())
      return nullptr;
    block->registerOutput(grad);
    return linear->addOutput()->copyMetadata(grad);
  });
}

struct ReverseDetails {
  ReverseDetails(value_map&& grad_map, Block* reverse_block)
      : grad_map(std::move(grad_map)), reverse_block(reverse_block) {}

  value_map grad_map;
  Block* reverse_block;
};

// AutogradAdd is a special addition function that handles Undef
// AutogradAdd(a, b) == a + b if defined(a) and defined(b)
// AutogradAdd(Undef, b) == b
// AutogradAdd(a, Undef) == a
// AutogradAdd(Undef, Undef) == Undef
static Value* createAutogradAdd(Value* a, Value* b) {
  auto graph = a->owningGraph();
  return graph->insertNode(graph->create(prim::AutogradAdd, {a, b}))->output();
}

namespace {
bool outputRequiresGrad(Value* output) {
  if (output->type()->castRaw<TensorType>() == nullptr) {
    return output->requires_grad();
  }
  std::optional<bool> requiresGrad =
      output->type()->expectRef<TensorType>().requiresGrad();
  if (requiresGrad.has_value()) {
    return *requiresGrad;
  }

  Node* n = output->node();
  if (n->kind() != prim::profile) {
    return true;
  }
  if (!n->hasAttribute(attr::profiled_type)) {
    return true;
  }
  return n->ty(attr::profiled_type)->requires_grad();
}
} // namespace

// Before:
//   - grad_desc has field f initialized to the original 0-stage graph
// After:
//   - the last node of f (f->nodes().reverse()[0]) is a gradient node
//     whose block has vjp inputs for all outputs that require_grad
//     and vjp outputs for all primal inputs that require_grad
//   - grad_desc has df_input_vjps and df_output_vjps set
//     (but df_input_vjps will be modified later as well)
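// A schematic example (illustrative, not exact IR): for a graph
//   graph(%x, %w):  %y = aten::mul(%x, %w)  return (%y)
// with both inputs requiring grad, addReverseInline builds a prim::Reverse node
// (kept un-inserted, for debugging only) whose block takes %dy as input,
// computes %dx and %dw via GradOf blocks, and registers them as block outputs;
// df_input_vjps becomes [0] and df_output_vjps becomes [0, 1].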
static ReverseDetails addReverseInline(Gradient& grad_desc) {
  auto& graph = *grad_desc.f;
  // note: reverse_node is intentionally not inserted into the graph to avoid
  // accidentally acting on it (e.g. in eliminate dead code);
  // use `std::cout << *reverse_node` to view its state.
  auto reverse_node = graph.create(prim::Reverse, 0);
  auto reverse_block = reverse_node->addBlock();
  WithInsertPoint guard(reverse_block);

  value_map grad_map; // x -> dx mapping
  const auto get_grad = [&](Value* v) -> Value* {
    auto it = grad_map.find(v);
    if (it == grad_map.end()) {
      auto autograd_zero = graph.insertNode(graph.createAutogradZero());
      it = grad_map.emplace(v, autograd_zero->output()).first;
    }
    return it->second;
  };
  const auto set_grad = [&](Value* x, Value* dx) {
    if (Value* prev_grad = grad_map[x]) {
      GRAPH_DEBUG("grad_map[", x->debugName(), "] = ", *grad_map[x]->node());
      grad_map[x] = createAutogradAdd(prev_grad, dx);
    } else {
      GRAPH_DEBUG("grad_map[", x->debugName(), "] = ", dx->debugName());
      grad_map[x] = dx;
    }
  };

  auto outputs = graph.outputs();
  for (size_t i = 0, num_outputs = outputs.size(); i < num_outputs; ++i) {
    Value* output = outputs[i];
    if (!outputRequiresGrad(output))
      continue;
    Value* output_grad = reverse_block->addInput()->setType(output->type());
    GRAPH_DEBUG(
        "Adding output_grad ",
        output_grad->debugName(),
        " for ",
        output->debugName());
    set_grad(output, output_grad);
    grad_desc.df_input_vjps.push_back(i);
  }

  for (auto it = graph.nodes().rbegin(), end = graph.nodes().rend(); it != end;
       ++it) {
    Node* node = *it;
    auto inputs = node->inputs();
    auto outputs = node->outputs();
    if (std::all_of(outputs.begin(), outputs.end(), [](Value* v) {
          return !v->requires_grad();
        })) {
      continue;
    }

    value_list grad_inputs =
        linearGradientForNode(node, fmap(node->outputs(), get_grad));
    LowerSimpleTuples(reverse_block);

    AT_ASSERT(grad_inputs.size() == node->inputs().size());
    for (size_t i = 0, num_inputs = grad_inputs.size(); i < num_inputs; ++i) {
      if (!inputs[i]->requires_grad())
        continue;
      // NB: Not returning a gradient w.r.t. a value that requires grad is
      // normal if the input is non-differentiable. This happens e.g. in the
      // aten::type_as case.
      if (!grad_inputs[i])
        continue;
      set_grad(inputs[i], grad_inputs[i]);
    }
  }

  auto inputs = graph.inputs();
  for (size_t i = 0, num_inputs = inputs.size(); i < num_inputs; ++i) {
    Value* input = inputs[i];
    if (!input->requires_grad())
      continue;
    // NB: Not having a gradient defined w.r.t. an input to the graph which
    // requires grad can happen and is not an error. It might have been used
    // only in non-differentiable contexts (e.g. as the second input to
    // aten::type_as). In that case we simply ignore it as an output, because it
    // won't ever produce any meaningful values.
    if (grad_map.count(input) == 0)
      continue;
    reverse_block->registerOutput(get_grad(input));
    grad_desc.df_output_vjps.push_back(i);
  }

  Inline(graph);
  return ReverseDetails(std::move(grad_map), reverse_block);
}

// Returns a topologically-sorted list of values produced in f, and used in its
// reverse program.
static value_list getReverseCaptures(Gradient& grad_desc) {
  auto& graph = *grad_desc.f;
  auto primal_block = graph.block();

  value_set reverse_captures_set;
  value_list reverse_captures; // Invariant: topo sorted
  auto check_uses = [&](Value* v) {
    for (auto use : v->uses()) {
      if (use.user->owningBlock() == primal_block)
        continue;
      if (/* bool unseen = */ reverse_captures_set.emplace(v).second) {
        reverse_captures.push_back(v);
      }
    }
  };
  for (Value* input : graph.inputs()) {
    check_uses(input);
  }
  for (Node* node : graph.nodes()) {
    for (Value* output : node->outputs())
      check_uses(output);
  }
  return reverse_captures;
}

// Any temporary value from the primal graph needs to be captured for later use
// in the reverse graph, to avoid costly recomputation. However, a lot of the
// nodes we have in our graphs are simply constants, which are cheap to execute
// and replicate, so it's better to just copy them into the reverse graph,
// without polluting the output lists unnecessarily.
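// For example (schematic): a node %c = prim::Constant[value=1]() defined in the
// primal graph and used inside a GradOf block is cloned into the reverse block,
// so %c does not have to become an extra output of f / input of df.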
static void liftConstants(Block* block, Block* move_to_this_block);

// is node defined inside container?
static bool inBlock(Node* node, Block* container) {
  Block* b = node->owningBlock();
  while (b) {
    if (b == container) {
      return true;
    }
    b = b->owningNode() ? b->owningNode()->owningBlock() : nullptr;
  }
  return false;
}

static void liftConstants(Node* node, Block* move_to_this_block) {
  static const auto err = [](Value*) -> Value* {
    throw std::runtime_error("unexpected input");
  };
  auto& graph = *node->owningGraph();
  for (Value* input : node->inputs()) {
    if (input->node()->kind() != prim::Constant)
      continue;
    // if this constant is _already_ defined in the backward pass
    // block, we do not need to duplicate and move it because
    // it already won't be part of the capture set
    if (inBlock(input->node(), move_to_this_block))
      continue;
    Node* lifted_constant = graph.createClone(input->node(), err);
    move_to_this_block->prependNode(lifted_constant);
    GRAPH_DEBUG(
        "Lifting constant ",
        input->debugName(),
        " from GradOf's block and adding ",
        lifted_constant->output()->debugName(),
        " to the backprop block");
    node->replaceInputWith(input, lifted_constant->output());
  }
  for (Block* sub : node->blocks()) {
    liftConstants(sub, move_to_this_block);
  }
}

static void liftConstants(Block* block, Block* move_to_this_block) {
  for (Node* node : block->nodes()) {
    liftConstants(node, move_to_this_block);
  }
  liftConstants(block->return_node(), move_to_this_block);
}

// We need to fold aten::_size_if_not_equal at differentiation time, while we
// still know the shapes of aten::_size_if_not_equal's arguments. Otherwise,
// those values would become inputs to the reverse graph and we would lose this
// information, since we don't profile Scalars or Lists yet.
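// For example (schematic): if profiling recorded concrete sizes [2, 3] and
// [1, 3] for the two tensors feeding an aten::_size_if_not_equal node, its use
// in a _grad_sum_to_size call is replaced by the constant [2, 3]; if the sizes
// match, it is replaced by a None constant instead.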
static void foldSizeIfNotEqual(Block* reverse_block);

static void foldSizeIfNotEqual(Node* node) {
  for (Value* input : node->inputs()) {
    if (input->node()->kind() != aten::_size_if_not_equal) {
      continue;
    }

    auto ptt_input =
        input->node()->input(0)->node()->input()->type()->expect<TensorType>();
    auto ptt_output =
        input->node()->input(1)->node()->input()->type()->expect<TensorType>();

    auto input_size = ptt_input->sizes().concrete_sizes();
    auto output_size = ptt_output->sizes().concrete_sizes();

    if (!input_size || !output_size) {
      continue;
    }
    // insert in front of _grad_sum_to_size
    WithInsertPoint guard(node);
    Value* size = nullptr;
    if (input_size != output_size) {
      size = node->owningGraph()->insertConstant(*input_size);
    } else {
      size = node->owningGraph()->insertConstant(IValue());
    }
    node->replaceInputWith(input, size);
  }

  for (auto ib : node->blocks()) {
    foldSizeIfNotEqual(ib);
  }
}

// See the comment above: fold aten::_size_if_not_equal while the shapes of its
// arguments are still known.
static void foldSizeIfNotEqual(Block* reverse_block) {
  for (auto n : reverse_block->nodes()) {
    foldSizeIfNotEqual(n);
  }
  foldSizeIfNotEqual(reverse_block->return_node());
}

static void deduplicateSizeCaptures(
    Gradient& grad_desc,
    ReverseDetails& rev_info) {
  Block* primal_block = grad_desc.f->block();
  const auto usedOnlyInReverse = [primal_block](Value* v) {
    const auto& uses = v->uses();
    return std::all_of(uses.begin(), uses.end(), [primal_block](const Use& u) {
      return u.user->owningBlock() != primal_block;
    });
  };
  auto captures = getReverseCaptures(grad_desc);
  value_set capture_set(captures.begin(), captures.end());
  for (Value* capture : captures) {
    Node* node = capture->node();
    if (!node->matches("aten::size(Tensor self) -> int[]")) {
      continue;
    }
    if (usedOnlyInReverse(capture) && capture_set.count(node->input())) {
      WithInsertPoint insert_guard{*rev_info.reverse_block->nodes().begin()};
      auto* size =
          node->input()->owningGraph()->insert(aten::size, {node->input()});
      GRAPH_DEBUG(
          "deduplicateSizeCaptures: Replacing ",
          capture->debugName(),
          " with ",
          size->debugName());
      capture->replaceAllUsesWith(size);
      node->destroy();
    }
  }
}

static void eliminateDeadCode(ReverseDetails& rev_info) {
  // addReverseInline has to call linearGradientForNode if *any* of the inputs
  // require grad, but it will emit vjps for *all* inputs. Use DCE to remove
  // unnecessary nodes. Additionally, requires_grad() on intermediates is an
  // overapproximation of the real state, so we might have emitted some
  // gradients, only to realize that they were unnecessary once we reach a
  // point that doesn't require grad.
  // Of course, we need to filter out corresponding entries of grad_map, because
  // we don't want to accidentally access freed pointers later.
  std::function<void(const std::unordered_set<const Value*>&)> cb =
      [&](const std::unordered_set<const Value*>& live_values) {
        std::vector<Value*> to_erase;
        for (auto& entry : rev_info.grad_map) {
          if (!live_values.count(entry.second)) {
            to_erase.push_back(entry.first);
          }
        }
        for (Value* v : to_erase) {
          GRAPH_DEBUG(
              "Erasing unused value ", v->debugName(), " from grad_map");
          rev_info.grad_map.erase(v);
        }
      };
  EliminateDeadCode(rev_info.reverse_block, std::move(cb));
}

static void Optimize(Gradient& grad_desc, ReverseDetails& rev_info) {
  // TODO: we are sometimes emitting expressions like
  // _grad_sum_to_size(_grad_sum_to_size(x, s1), s2), which are equivalent to
  // _grad_sum_to_size(x, s2), and could save us some captures, but I'm not
  // 100% sure how to optimize this at this stage, since we don't know which
  // GradOf blocks will be stitched together to form the derivative. I guess a
  // smart analysis could implement this, but I didn't have time before the 1.0
  // release, so I put this only as a peephole optimization.
  liftConstants(rev_info.reverse_block, rev_info.reverse_block);
  // TODO: see if this pass can be replaced with a peephole pass
  foldSizeIfNotEqual(rev_info.reverse_block);
  // We generally add a lot of aten::size calls (for derivatives of broadcasting
  // operators), and they often end up duplicated, and would get captured
  // multiple times. Make sure we deduplicate them before lifting.
  EliminateCommonSubexpression(grad_desc.f);
  deduplicateSizeCaptures(grad_desc, rev_info);
  eliminateDeadCode(rev_info);
}

// Takes a grad_desc.f returned from `addReverseInline` and splits off the
// reverse_block into its own graph, storing it in df.
// All intermediates needed in the second stage are added to
// outputs of f, and taken as inputs in df. For a more
// detailed description see Note [Gradient graphs] in autodiff.h.
// This function also initializes the fields in grad_desc that were undefined
// after `addReverseInline` (and extends `df_input_vjps` with vjps for captured
// temporaries).
static void lambdaLiftReverse(Gradient& grad_desc, ReverseDetails& rev_info) {
  auto& graph = *grad_desc.f;
  auto reverse_block = rev_info.reverse_block;

  // --------------------------------------------------------------------------
  // 1. Find values of f that need to be captured.
  // --------------------------------------------------------------------------
  // First, we need to find all values that are produced in f
  // and used in df. They will need to be added as inputs of df,
  // and some of them may also need to be appended as outputs of f if
  // they are not already an input or an output of f.
  // Invariant: topo sorted
  value_list reverse_captures = getReverseCaptures(grad_desc);

  // --------------------------------------------------------------------------
  // 2. Prepare input/output lists for f and df
  // --------------------------------------------------------------------------
  // It's simple to construct primal_inputs/reverse_outputs,
  // but primal_outputs/reverse_inputs are much more subtle.
  // Here's a summary of what they are supposed to look like:
  //
  // Primal outputs:
  //   [original outputs], [temporaries]
  //
  // Reverse inputs:
  //   [output vjps (aka grad_outputs)], [temporary vjps],
  //   [captured primal values, in topological order]

  // -- Construct primal_outputs, df_input_captures, f_real_outputs ----
  grad_desc.f_real_outputs = graph.outputs().size();

  std::unordered_map<Value*, size_t> orig_primal_outputs_idx;
  std::unordered_map<Value*, size_t> orig_primal_inputs_idx;
  // NOTE: we use emplace to avoid replacing an existing index if an output is
  // repeated
  for (size_t i = 0, num_outputs = graph.outputs().size(); i < num_outputs; ++i)
    orig_primal_outputs_idx.emplace(graph.outputs()[i], i);
  for (size_t i = 0, num_inputs = graph.inputs().size(); i < num_inputs; ++i)
    orig_primal_inputs_idx[graph.inputs()[i]] = i;

  // NB: reverse_captures are already deduplicated, and in topo order
  for (Value* capture_val : reverse_captures) {
    // If it's already an output we don't have to add anything,
    // but register the fact that it needs to be captured.
    if (orig_primal_outputs_idx.count(capture_val) > 0) {
      grad_desc.df_input_captured_outputs.push_back(
          orig_primal_outputs_idx[capture_val]);
      // If it's an input, we could add it as an output but in fact it's
      // more efficient to use a special kind of capture.
    } else if (orig_primal_inputs_idx.count(capture_val) > 0) {
      grad_desc.df_input_captured_inputs.push_back(
          orig_primal_inputs_idx.at(capture_val));
      // Otherwise it's just a regular intermediate value that we need to add as
      // an output
    } else {
      // we need to create a new temporary output for this capture because it
      // wasn't available.

      auto out_index = graph.registerOutput(capture_val);
      GRAPH_DEBUG(
          "Capturing a temporary ",
          capture_val->debugName(),
          " as ",
          graph.outputs()[out_index]->debugName(),
          " for forward graph");
      grad_desc.df_input_captured_outputs.emplace_back(
          graph.outputs().size() - 1);
    }
  }

  // -- Add VJPs for temporaries, adjust df_input_vjps -------------------------
  // NB [possible optimization]: use the newly added vjp input as soon as the
  // first vjp for that value is generated, to reduce the lifespan of this input
  // (currently we add it to the final vjp after all adds).
  for (size_t i = grad_desc.f_real_outputs; i < graph.outputs().size(); ++i) {
    Value* tmp = graph.outputs().at(i);
    // Add VJP inputs only for intermediates that actually required grad.
    // Note that we check the contents of the grad_map instead of
    // tmp->requires_grad(), because it's actually a more faithful source.
    // tmp->requires_grad() is really an overapproximation (i.e. it can have
    // false positives), while the gradients we will emit for this value can get
    // DCE-d in the optimization pass (because it has no influence on the real
    // f's outputs that we differentiate).
    if (rev_info.grad_map.count(tmp) == 0)
      continue;

    Value* tmp_vjp_in = reverse_block->addInput()->setType(tmp->type());
    Value* tmp_vjp_prev = rev_info.grad_map.at(tmp);
    // This is quite weird because we can't first make a sum and then replace
    // all uses of tmp_vjp_prev (that would replace its use in the sum too!), so
    // we create an incorrect sum that doesn't use the prev vjp, replace uses,
    // and then fix the sum.
    Value* new_vjp = createAutogradAdd(tmp_vjp_in, tmp_vjp_in);
    if (tmp_vjp_prev->node()->kind() == prim::Param) {
      // can't move a node after a block param node
      new_vjp->node()->moveBefore(
          *tmp_vjp_prev->node()->owningBlock()->nodes().begin());
    } else {
      new_vjp->node()->moveAfter(tmp_vjp_prev->node());
    }

    tmp_vjp_prev->replaceAllUsesWith(new_vjp);
    new_vjp->node()->replaceInput(1, tmp_vjp_prev);
    GRAPH_DEBUG("grad_map[", tmp->debugName(), "] = ", *new_vjp->node());
    grad_desc.df_input_vjps.emplace_back(i);
  }

  // Add the captures as formal arguments to the reverse_block.
  // Afterwards, its inputs are: [output vjps][temporary vjps][captures].
  // Construct a map from each captured 'value' to its index in the input list,
  // used to extract this block into its own function.
  std::unordered_map<Value*, size_t> capture_to_formal_index;
  const auto& add_capture = [&](Value* captured) {
    capture_to_formal_index[captured] = reverse_block->inputs().size();
    auto new_input = reverse_block->addInput()->copyMetadata(captured);
    GRAPH_DEBUG(
        "Capturing ",
        captured->debugName(),
        " as ",
        new_input->debugName(),
        " for an embedded backward block");
  };
  for (auto& offset : grad_desc.df_input_captured_inputs)
    add_capture(graph.inputs()[offset]);
  for (auto& offset : grad_desc.df_input_captured_outputs)
    add_capture(graph.outputs()[offset]);

  grad_desc.df = std::make_shared<Graph>();
  grad_desc.df->block()->cloneFrom(reverse_block, [&](Value* v) {
    return grad_desc.df->inputs()[capture_to_formal_index.at(v)];
  });

  GRAPH_DUMP(" forward graph: ", &graph);
  GRAPH_DEBUG(" backward graph: ", *(reverse_block->owningNode()));
  // reverse_node was just there to hold onto reverse_block in a debuggable way;
  // we can remove it now.
  reverse_block->owningNode()->destroy();
}

static void packReturnValuesIntoTuple(const std::shared_ptr<Graph>& graph) {
  auto returnNode = graph->block()->return_node();
  WithInsertPoint wip(returnNode);
  auto tuple = graph->insertNode(graph->createTuple(returnNode->inputs()));
  returnNode->removeAllInputs();
  returnNode->addInput(tuple->output());
}

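// Schematic usage of the entry point below (illustrative): given a
// differentiable TorchScript graph `g`, a caller typically does
//   std::shared_ptr<Graph> copy = g->copy();
//   Gradient grad = differentiate(copy);
// and then stitches grad.f (forward) and grad.df (backward) together at
// runtime using f_real_outputs, df_input_vjps, df_input_captured_inputs,
// df_input_captured_outputs, and df_output_vjps.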
Gradient differentiate(std::shared_ptr<Graph>& graph) {
  Gradient grad_desc;
  // Take ownership of the graph
  TORCH_CHECK(
      graph.use_count() == 1,
      "differentiate will mutate and destroy the graph, so it requires "
      "graph.use_count() == 1, but found ",
      graph.use_count());
  std::swap(graph, grad_desc.f);
  // XXX: Take care when handling outputs - they can be duplicated!

  GRAPH_DUMP("grad_desc.f: ", grad_desc.f);
  WithInsertPoint guard(grad_desc.f->block());
  // Fills in df_input_vjps and df_output_vjps
  auto rev_info = addReverseInline(grad_desc);
  Optimize(grad_desc, rev_info);
  // Clean up old nodes which have been replaced by forward graphs in TorchScript
  EliminateDeadCode(grad_desc.f->block());

  // Fills in f, df, f_real_outputs, df_input_captures,
  // and modifies df_input_vjps (new vjps are added for temporaries)
  lambdaLiftReverse(grad_desc, rev_info);
  packReturnValuesIntoTuple(grad_desc.df);

  // We have created a differentiable forward graph
  // which will be run with tensors that have their gradients detached,
  // so profiled types will have an outdated requires_grad=True; update the
  // requires_grad property
  UpdateDifferentiableGraphRequiresGrad(grad_desc.f, false);
  return grad_desc;
}
} // namespace torch::jit