#pragma once

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch::jit {

// Directly after tracing, we have an ill-formed graph with blocks inserted.
// Example:
//
// graph(%self : ClassType<Module>,
//       %input.1 : Float(3, 4)):
//   %1 : ClassType<Module> = prim::GetAttr[name="relu1"](%self)
//   %2 : ClassType<Module> = prim::GetAttr[name="relu2"](%self)
//   %3 : ClassType<Module> = prim::GetAttr[name="rrr"](%2)
//    = prim::TracedModuleForward[scope="__module.relu1"]()
//     block0():
//       %input : Float(3, 4) = aten::relu(%input.1),
//       -> ()
//    = prim::TracedModuleForward[scope="__module.relu2"](),
//     block0():
//        = prim::TracedModuleForward[scope="__module.relu2.rrr"](),
//         block0():
//           %6 : Float(3, 4) = aten::relu(%input),
//           -> ()
//       -> ()
//   return (%6)
//
// In this pass, we:
//   1) Lift Value defs to as high a scope as needed to ensure that
//      they dominate all their uses. For example, `input` in the above
//      graph needs to be lifted to the top-level block so that its use
//      in the second `relu` operator is dominated by its definition.
//   2) Lambda lift the blocks. This ensures that all values used within
//      each scope have their defs captured as explicit inputs.
//   3) Convert the scope blocks into methods on their respective Modules,
//      and convert the TracedModuleForward nodes into CallMethod nodes
//      that invoke those methods.
//
// Then, we'll have a well-formed graph with proper method calls.
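//
// As a rough sketch (value names, method names, and the handling of the
// nested `rrr` call below are illustrative, not the literal output of the
// pass), the top-level graph above would end up looking something like:
//
// graph(%self : ClassType<Module>,
//       %input.1 : Float(3, 4)):
//   %relu1 : ClassType<Module> = prim::GetAttr[name="relu1"](%self)
//   %relu2 : ClassType<Module> = prim::GetAttr[name="relu2"](%self)
//   %input : Float(3, 4) = prim::CallMethod[name="forward"](%relu1, %input.1)
//   %6 : Float(3, 4) = prim::CallMethod[name="forward"](%relu2, %input)
//   return (%6)
//
// with the method generated for `relu2` itself containing a CallMethod
// into its `rrr` submodule.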
TORCH_API void FixupTraceScopeBlocks(
    std::shared_ptr<Graph>& graph,
    Module* self);
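//
// Minimal usage sketch (hypothetical: in practice this pass is invoked by
// the tracer itself; `tracedGraph` and `rootModule` stand in for whatever
// the tracer produced):
//
//   std::shared_ptr<Graph> tracedGraph = /* graph produced by tracing */;
//   Module rootModule = /* root module the trace was taken from */;
//   FixupTraceScopeBlocks(tracedGraph, &rootModule);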

} // namespace torch::jit