1 #pragma once 2 3 #include <torch/csrc/jit/api/module.h> 4 #include <torch/csrc/jit/ir/ir.h> 5 6 namespace torch::jit { 7 8 // Directly after tracing, we have an ill-formed graph with blocks inserted. 9 // Example: 10 // 11 // graph(%self : ClassType<Module>, 12 // %input.1 : Float(3, 4)): 13 // %1 : ClassType<Module> = prim::GetAttr[name="relu1"](%self) 14 // %2 : ClassType<Module> = prim::GetAttr[name="relu2"](%self) 15 // %3 : ClassType<Module> = prim::GetAttr[name="rrr"](%2) 16 // = prim::TracedModuleForward[scope="__module.relu1"]() 17 // block0(): 18 // %input : Float(3, 4) = aten::relu(%input.1), 19 // -> () 20 // = prim::TracedModuleForward[scope="__module.relu2"](), 21 // block0(): 22 // = prim::TracedModuleForward[scope="__module.relu2.rrr"](), 23 // block0(): 24 // %6 : Float(3, 4) = aten::relu(%input), 25 // -> () 26 // -> () 27 // return (%6) 28 // 29 // In this pass, we: 30 // 1) Lift Value defs to as high of a scope as needed to ensure that 31 // they dominate all their uses. For example, `input` in the above 32 // graph needs to be lifted to the top-level block so that its use 33 // in the second `relu` operator is dominated. 34 // 2) Lambda lift the blocks. This ensures that all values used within 35 // each scope have their defs captured. 36 // 3) Convert the scope blocks into methods on their respective Modules, 37 // and convert TracedModuleForward nodes to CallMethod nodes into those 38 // methods. 39 // 40 // Then, we'll have a well-formed graph with proper method calls. 41 TORCH_API void FixupTraceScopeBlocks( 42 std::shared_ptr<Graph>& graph, 43 Module* self); 44 45 } // namespace torch::jit 46