// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/backends/backend_init.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/backends/backend_init.h>
2 
3 #include <pybind11/iostream.h>
4 #include <torch/csrc/jit/backends/backend_detail.h>
5 #include <torch/csrc/jit/backends/backend_resolver.h>
6 #include <torch/csrc/jit/python/module_python.h>
7 #include <torch/csrc/jit/python/pybind_utils.h>
8 #include <torch/csrc/utils/pybind.h>
9 
10 namespace torch {
11 namespace jit {
12 
13 // Get all types that are shared in the module hierarchy rooted at \p mod.
getSharedModuleTypes(Module & mod)14 std::unordered_set<TypePtr> getSharedModuleTypes(Module& mod) {
15   // Maintain a set of all TypePtrs.
16   std::unordered_set<TypePtr> types;
17   // Maintain another set of TypePtrs that have been encountered more than once.
18   std::unordered_set<TypePtr> duplicate_types;
19 
20   // Iterate over all modules in the hierarchy, including the root.
21   for (auto module : mod.modules()) {
22     auto module_type = module.type();
23     if (types.count(module_type) > 0) {
24       duplicate_types.insert(module_type);
25     }
26 
27     types.insert(module_type);
28   }
29 
30   return duplicate_types;
31 }
32 
33 // Selectively lower \p mod to a backend. \p to_backend
34 // is called to lower modules. \p modules_to_lower contains
35 // qualified names of submodules of \p mod that should be lowered.
toBackendSelectiveImpl(Module & mod,const py::function & to_backend,const std::vector<std::string> & modules_to_lower,const std::unordered_set<TypePtr> & duplicate_types)36 void toBackendSelectiveImpl(
37     Module& mod,
38     const py::function& to_backend,
39     const std::vector<std::string>& modules_to_lower,
40     const std::unordered_set<TypePtr>& duplicate_types) {
41   // This map will be used later to remap types in ancestor module graphs for
42   // all lowered submodules.
43   std::unordered_map<TypePtr, TypePtr> type_remap;
44 
45   // For each module that should be lowered:
46   for (const auto& module_to_lower : modules_to_lower) {
47     // Use QualifiedName to parse the qualified module names.
48     c10::QualifiedName qual_module_name(module_to_lower);
49     auto& atoms = qual_module_name.atoms();
50 
51     // Search through the module hierarchy using the atoms of
52     // qual_module_name until current points to the module to
53     // be lowered and parent points to its parent.
54     Module current = mod;
55     Module parent;
56 
57     for (size_t i = 0, e = atoms.size(); i < e; ++i) {
58       IValue submodule = current.attr(atoms[i]);
59       if (submodule.isModule()) {
60         if (i == e - 1) {
61           parent = current;
62         }
63         current = submodule.toModule();
64       } else {
65         std::stringstream err;
66         err << "Attribute named " << atoms[i] << " is not a Module";
67         throw std::runtime_error(err.str());
68       }
69     }
70 
71     // Check that the parent type is not shared and therefore can be edited.
72     if (duplicate_types.count(parent.type()) > 0) {
73       throw py::cast_error(c10::str(
74           "Selective lowering is only supported for module hierarchies with unique types for selected modules; ",
75           parent.type()->repr_str(),
76           " is shared"));
77     }
78 
79     // Call to_backend on the module that needs to be lowered. It needs to be
80     // wrapped before doing so because _to_jit_backend accepts wrapped modules.
81     // The result needs to be unwrapped in order to access its type below.
82     auto lowered_submodule =
83         py::cast<Module>(to_backend(py::module::import("torch.jit._recursive")
84                                         .attr("wrap_cpp_module")(current))
85                              .attr("_c"));
86 
87     // Adjust the parent's type so that the type of the submodule matches
88     // the type of lowered_submodule.
89     auto parent_type = parent.type();
90 
91     parent_type->unsafeChangeAttributeType(
92         atoms.back(), lowered_submodule.type());
93     parent.setattr(atoms.back(), lowered_submodule._ivalue());
94 
95     // Record the type mapping from old type -> lowered type.
96     type_remap[current.type()] = lowered_submodule.type();
97   }
98 
99   // Having lowered all of the modules that needed to be lowered, remap types in
100   // all graphs in the hierarchy so that the graphs all use the new lowered
101   // type.
102   auto type_remap_fn = [&type_remap](TypePtr in) {
103     auto it = type_remap.find(in);
104     if (it == type_remap.end())
105       return in;
106     return it->second;
107   };
108 
109   // modules() iterates over all modules in the hierarchy including the root.
110   for (auto module : mod.modules()) {
111     auto module_type = module.type();
112     for (auto& fn : module_type->methods()) {
113       auto method = module.get_method(fn->name());
114       auto graph = method.graph();
115       graph->remapTypes(type_remap_fn);
116       auto new_schema = fn->getSchema().cloneWithRemappedTypes(type_remap_fn);
117       fn->setSchema(new_schema);
118     }
119   }
120 }
121 
// Generate a backend-lowered module for \p orig_module.
//
// \p method_compile_spec is converted to a TorchScript Dict[str, Any]. That
// dictionary, the Dict[str, Any] type itself, and \p backend_name are then
// forwarded to detail::codegen_backend_module, whose result is returned.
Module codegen_func(
    const std::string& backend_name,
    const Module& orig_module,
    const py::dict& method_compile_spec) {
  // The TorchScript type Dict[str, Any].
  const auto str_to_any_type =
      DictType::create(StringType::get(), AnyType::get());
  auto compile_spec =
      toIValue(method_compile_spec, str_to_any_type).toGenericDict();
  return detail::codegen_backend_module(
      backend_name, orig_module, compile_spec, str_to_any_type);
}
134 
initJitBackendBindings(PyObject * module)135 void initJitBackendBindings(PyObject* module) {
136   // Bind a function for lowering to each JIT backend. The name of the backend
137   // must be the first argument. For example, to lower a Module to
138   // "example_backend", declared as
139   //
140   //  static auto cls = torch::jit::backend<ExampleBackend>("example_backend");
141   //
142   // this function must be called like
143   //
144   //  torch._C._jit_to_backend("example_backend", module, spec)
145   auto m = py::handle(module).cast<py::module>();
146   m.def(
147       "_jit_to_backend",
148       [=](const std::string& backend_name,
149           py::handle orig_module,
150           const py::dict& method_compile_spec) {
151         py::scoped_ostream_redirect cerr(
152             std::cerr, py::module_::import("sys").attr("stderr"));
153         py::scoped_ostream_redirect cout(
154             std::cout, py::module_::import("sys").attr("stdout"));
155         return py::module::import("torch.jit._recursive")
156             .attr("wrap_cpp_module")(codegen_func(
157                 backend_name,
158                 py::cast<Module>(orig_module.attr("_c")),
159                 method_compile_spec));
160       });
161 
162   m.def(
163       "_jit_to_backend_selective",
164       [=](py::handle orig_module,
165           const py::function& to_backend,
166           const std::vector<std::string>& modules_to_lower) {
167         py::scoped_ostream_redirect cerr(
168             std::cerr, py::module_::import("sys").attr("stderr"));
169         py::scoped_ostream_redirect cout(
170             std::cout, py::module_::import("sys").attr("stdout"));
171         if (auto original_module =
172                 as_module(py::cast<py::object>(orig_module))) {
173           // Clone the Module to avoid editing types that are shared with
174           // Modules in other instances outside this hierarchy.
175           Module& mod = original_module.value();
176           auto cloned_mod = mod.clone();
177           // Get all shared module types. Type sharing is only a problem if the
178           // parent modules of the ones to lower are in this set.
179           auto shared_types = getSharedModuleTypes(cloned_mod);
180           toBackendSelectiveImpl(
181               cloned_mod, to_backend, modules_to_lower, shared_types);
182           // Wrap the result in a RecursiveScriptModule because that's what
183           // the caller passed in.
184           return py::module::import("torch.jit._recursive")
185               .attr("wrap_cpp_module")(cloned_mod);
186         }
187 
188         throw py::cast_error(c10::str(
189             "Object ", py::str(orig_module), " is not a ScriptModule"));
190       });
191 }
192 } // namespace jit
193 } // namespace torch
194