// torch/csrc/jit/backends/backend_detail.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/jit/backends/backend_detail.h>

#include <ATen/code_template.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_debug_handler.h>
#include <torch/csrc/jit/backends/backend_debug_info.h>
#include <torch/csrc/jit/backends/backend_resolver.h>

#include <memory>
#include <stack>
#include <unordered_map>

namespace torch {
namespace jit {
namespace detail {
namespace {

/*
 * This is the API via which a backend's preprocess function obtains debug
 * handles corresponding to the nodes of the graph for the lowered methods of
 * the module.
 * Implementation: given a graph, for each node of the graph, request a debug
 * handle via debug_info_recorder. debug_info_recorder returns the next debug
 * handle and records the node with its corresponding debug info, such as
 * source range and inlined callstack.
 *
 * The backend's lowering code, preprocess, calls
 * generate_debug_handles(graph), which returns debug handles corresponding
 * to the Node*s of the given graph.
 *
 * In to_backend, after lowering, stopRecording is called on the
 * BackendDebugInfoRecorder: it extracts the debug map, which is then stored
 * as part of the lowered module.
 * During serialization, specifically bytecode serialization, a check is made
 * to see whether the model being serialized has any lowered modules. If so,
 * the corresponding debug map is extracted and serialized.
 */
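
/*
 * Illustrative sketch (not part of this file): a backend's
 * BackendPreprocessFunction receives a BackendDebugHandleGenerator and can
 * request debug handles for the graph of each method it lowers. The function
 * name and the use of "forward" below are hypothetical; the signature matches
 * BackendPreprocessFunction from backend_detail.h.
 *
 *   c10::IValue preprocess(
 *       const Module& mod,
 *       const c10::Dict<IValue, IValue>& method_compile_spec,
 *       const BackendDebugHandleGenerator& generate_debug_handles) {
 *     auto graph =
 *         toGraphFunction(mod.get_method("forward").function()).graph();
 *     NodeToDebugHandle handles = generate_debug_handles(graph);
 *     // ... emit a backend-specific payload using `handles` ...
 *     return mod._ivalue();
 *   }
 */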

NodeToDebugHandle generate_debug_handles(
    BackendDebugInfoRecorder& debug_info_recorder,
    const std::shared_ptr<Graph>& graph) {
  NodeToDebugHandle node_to_debug_handles;

  std::stack<Block*> blocks_to_visit;
  // TODO: Look into using DepthFirstGraphNodeIterator.
  // At the moment it takes a non-const graph, but maybe we can make it
  // general enough to work with both.
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      DebugHandleType debug_handle = debug_info_recorder.getNextDebugHandle(n);
      node_to_debug_handles.emplace(n, debug_handle);
      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }
  return node_to_debug_handles;
}

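// Registry mapping each backend name to its preprocess function. It is kept
// as a function-local static so the map is initialized on first use, which
// avoids static initialization order issues across translation units.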
std::unordered_map<std::string, BackendPreprocessFunction>&
backendPreprocessFunctions() {
  static std::unordered_map<std::string, BackendPreprocessFunction>
      preprocess_functions;
  return preprocess_functions;
}
} // namespace

bool hasBackendPreprocessFunction(const std::string& name) {
  return backendPreprocessFunctions().count(name);
}

void registerBackendPreprocessFunction(
    const std::string& name,
    const BackendPreprocessFunction& preprocess) {
  TORCH_CHECK(
      !detail::hasBackendPreprocessFunction(name),
      "Preprocessing function for backend ",
      name,
      " is already registered. Ensure that registration is only called once.");
  detail::backendPreprocessFunctions()[name] = preprocess;
}

BackendPreprocessFunction getBackendPreprocessFunction(
    const std::string& name) {
  TORCH_CHECK(
      hasBackendPreprocessFunction(name),
      "Preprocessing function for backend ",
      name,
      " is not registered.");
  return backendPreprocessFunctions()[name];
}

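// Generates the LoweredModule for the given backend and wraps it in a
// WrapperModule. The LoweredModule holds the preprocessed payload
// (__processed_module), the compile spec (__method_compile_spec), the backend
// instance (__backend), the opaque compiled handles (__handles), and one
// generated method per entry in method_compile_spec; the wrapper's methods
// simply forward to the LoweredModule, and the wrapper is what gets returned.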
Module codegen_backend_module(
    const std::string& backend_name,
    const Module& orig_module,
    const c10::Dict<IValue, IValue>& method_compile_spec,
    const c10::DictTypePtr& any_dict_ty) {
  const c10::QualifiedName qual_backend_name(
      {"__torch__", "torch", "classes", kBackendsNamespace, backend_name});
  // TODO: Validate method_compile_spec.

  // Clone orig_module to make sure backend transformation is
  // functional.
  auto cloned_module = orig_module.clone();
  auto module_name = orig_module.type()->name()->qualifiedName();

  // Generate LoweredModule.
  Module loweredModule(
      "torch.jit.LoweredModule." + backend_name + "." + module_name,
      std::make_shared<CompilationUnit>(),
      /*shouldMangle=*/true);

  // Generate WrapperModule.
  Module wrapper(
      "torch.jit.LoweredWrapper." + backend_name + "." + module_name,
      std::make_shared<CompilationUnit>(),
      /*shouldMangle=*/true);

  // 1. Initialize the debug info recorder.
  // 2. Later, call debug_info_recorder.stopRecording() to gather the
  //    recorded debug info and save it in __backend_debug_info.
  BackendDebugInfoRecorder debug_info_recorder;

  // Generate attributes.
  // This is the preprocessed module.
  // For backwards compatibility, for backends that implement preprocessing in
  // the backend interface rather than as a separate function, we just pass
  // the cloned original Module.

  BackendDebugHandleGenerator debug_handle_generator =
      [&](const std::shared_ptr<Graph>& g) {
        return generate_debug_handles(debug_info_recorder, g);
      };
  loweredModule.register_attribute(
      "__processed_module",
      AnyType::get(),
      detail::getBackendPreprocessFunction(backend_name)(
          cloned_module, method_compile_spec, debug_handle_generator),
      /*is_param=*/false);

  // This is for the method_compile_spec passed in to to_<backend> or
  // loaded from an exported model.
  loweredModule.register_attribute(
      "__method_compile_spec",
      any_dict_ty,
      method_compile_spec,
      /*is_param=*/false);

  // This is a pointer to a backend instance that is used to access
  // compile and execute functions.
  auto cls = getCustomClass(qual_backend_name.qualifiedName());
  TORCH_INTERNAL_ASSERT(cls);
  c10::intrusive_ptr<torch::CustomClassHolder> backend;
  loweredModule.register_attribute(
      "__backend", cls, IValue::make_capsule(backend));

  // This is the list of opaque backend handles returned by
  // backend.compile.
  loweredModule.register_attribute(
      "__handles",
      any_dict_ty,
      c10::impl::GenericDict(
          any_dict_ty->getKeyType(), any_dict_ty->getValueType()),
      /*is_param=*/false);

  // Methods.

  // This is a helper function for creating a new instance of the
  // backend class.
  static const auto create_backend_ct = at::jit::CodeTemplate(R"(
            def __create_backend(self):
                self.__backend = $name()
            )");
  at::jit::TemplateEnv create_backend_te;
  create_backend_te.s("name", qual_backend_name.qualifiedName());
  loweredModule.define(
      create_backend_ct.format(create_backend_te), loweredModuleResolver());

  // Helper function to expose backend.is_available() to Module generation
  // code. Assumes self.__backend exists (i.e. __create_backend() has already
  // been invoked).
  loweredModule.define(
      R"(
            def __is_available(self):
                return self.__backend.is_available()
            )",
      loweredModuleResolver());

  // backend_debug_info_class is an instance of BackendDebugInfo that
  // stores debug information.
  // The purpose of this class is to make the debug information available
  // at model saving time, so it can be serialized outside of the lowered
  // module, while still tying it to the module's lifetime (so it gets
  // destroyed along with it).
  // Although this information is not serialized as part of the lowered
  // module, we still need to provide a valid instance of the
  // BackendDebugInfo class when the lowered module is deserialized.
  // Since the deserialized module does not need this information,
  // we create a "dummy" instance with no extra code dependencies (to avoid
  // overhead) when the backend is created in __setstate__.
  c10::intrusive_ptr<torch::CustomClassHolder> backend_debug_info_class;
  const c10::QualifiedName backend_debug_info_class_name(
      {"__torch__",
       "torch",
       "classes",
       kBackendUtilsNamespace,
       kBackendDebugInfoClass});
  auto debug_info_cls =
      getCustomClass(backend_debug_info_class_name.qualifiedName());
  TORCH_CHECK(debug_info_cls, "BackendDebugInfo class must be available.");
  loweredModule.register_attribute(
      "__backend_debug_info",
      OptionalType::create(debug_info_cls),
      IValue::make_capsule(backend_debug_info_class));
  static const auto create_backend_debug_info_ct = at::jit::CodeTemplate(R"(
            def __create_backend_debug_info(self):
                self.__backend_debug_info = $backend_debug_info()
            )");
  at::jit::TemplateEnv create_backend_debug_info_te;
  create_backend_debug_info_te.s(
      "backend_debug_info", backend_debug_info_class_name.qualifiedName());
  loweredModule.define(
      create_backend_debug_info_ct.format(create_backend_debug_info_te),
      loweredModuleResolver());

  // getstate and setstate are for serialization/deserialization of
  // the LoweredModule.
  // setstate is in charge of initializing self.__backend by invoking
  // __create_backend().
  loweredModule.define(
      R"(
            def __getstate__(self):
                # The third parameter indicates whether __setstate__ must create
                # the backend instance. It's hardcoded to True since the only
                # case in which it can be False is when __setstate__ is called
                # from outside the module (at module creation time), because
                # __create_backend has already been called directly.
                return self.__method_compile_spec, self.__processed_module, True
            )",
      loweredModuleResolver());

  loweredModule.define(
      R"(
            def __setstate__(self, state):
                self.__method_compile_spec = state[0]
                self.__processed_module = state[1]
                # state[2] indicates whether to create the backend instance.
                if state[2]:
                    self.__create_backend()
                    self.__create_backend_debug_info()
                if self.__backend.is_available() :
                    self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec)
                else:
                    raise Exception("Backend is not available.")
            )",
      loweredModuleResolver());

  // This loop generates one method on the LoweredModule for every key
  // in method_compile_spec.
  std::vector<std::string> wrapper_methods;
  for (auto& e : method_compile_spec) {
    std::string method_name = e.key().toStringRef();
    static const auto method_ct = at::jit::CodeTemplate(R"(
            def $method(self${,def_inputs}):
                typed_inputs: List[Any] = [${fwd_inputs,}]
                if self.__backend.is_available() :
                  $unpack, = self.__backend.execute(self.__handles["$method"], typed_inputs)
                  ${refine,}
                  return $ret
                else:
                  raise Exception("Backend is not available.")
            )");
    static const auto wrapper_method_ct = at::jit::CodeTemplate(R"(
            def $method(self${,def_inputs}):
                return self.__loweredModule__.$method(${fwd_inputs})
            )");
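
    // For illustration (a hedged sketch, not emitted by this file verbatim):
    // for a method with schema `forward(self, x: Tensor) -> Tensor`, the two
    // templates above expand to roughly:
    //
    //   def forward(self, x: Tensor):
    //       typed_inputs: List[Any] = [x, ]
    //       if self.__backend.is_available() :
    //         _0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
    //         assert isinstance(_0, Tensor)
    //         return _0
    //       else:
    //         raise Exception("Backend is not available.")
    //
    //   def forward(self, x: Tensor):
    //       return self.__loweredModule__.forward(x)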

    at::jit::TemplateEnv method_te, wrapper_method_te;
    method_te.s("method", method_name);
    wrapper_method_te.s("method", method_name);
    auto method = orig_module.get_method(method_name);
    auto& function = method.function();
    auto& schema = function.getSchema();

    // Generate the inputs for the function signature (def_inputs) and
    // for passing to backend.execute (fwd_inputs).
    std::vector<std::string> def_inputs, fwd_inputs;
    for (const auto& arg : schema.arguments()) {
      auto name = arg.name();

      // Skip self since it is always present in the signature.
      if (name == "self") {
        continue;
      }

      auto default_value = arg.default_value();

      if (arg.kwarg_only()) {
        // If this is a kwarg, it needs to be emitted as keyword=value
        // in the definition and keyword=keyword in the call to
        // backend_execute.
        TORCH_INTERNAL_ASSERT(default_value.has_value());
        std::stringstream def_ss, fwd_ss;
        // Annotate the type of the arg.
        def_ss << name << ": " << arg.type()->annotation_str(nullptr) << "=";
        fwd_ss << name << "=" << name;
        default_value->repr(
            def_ss, [](std::ostream&, const IValue&) -> bool { return false; });
        def_inputs.emplace_back(def_ss.str());
        fwd_inputs.emplace_back(fwd_ss.str());
      } else {
        // If this is not a kwarg, it should be emitted as-is in the
        // signature and the call to backend_execute.
        std::stringstream def_ss;
        // Annotate the type of the arg.
        def_ss << name << ": " << arg.type()->annotation_str(nullptr);
        def_inputs.emplace_back(def_ss.str());
        fwd_inputs.emplace_back(name);
      }
    }

    // Generate a comma-delimited list of identifiers to unpack
    // outputs, as well as a list of isinstance checks to make sure
    // the backend returned the types it was supposed to.
    std::stringstream out_ss, type_check_ss;
    std::vector<std::string> type_checks;
    TORCH_INTERNAL_ASSERT(schema.returns().size() == 1);
    auto out_ty = schema.returns().at(0).type();

    out_ss << "_0";
    type_check_ss << "assert isinstance(_0, ";

    auto out_tuple_ty = out_ty->cast<TupleType>();

    if (out_tuple_ty) {
      auto tuple_elements = out_tuple_ty->elements();
      type_check_ss << tuple_elements[0]->annotation_str() << ")";
      type_checks.emplace_back(type_check_ss.str());
      for (unsigned i = 1, e = tuple_elements.size(); i < e; ++i) {
        type_check_ss.str(std::string());
        type_check_ss.clear();
        out_ss << ", _" << i;
        type_check_ss << "assert isinstance(_" << i << ", "
                      << tuple_elements[i]->annotation_str() << ")";
        type_checks.emplace_back(type_check_ss.str());
      }
    } else {
      type_check_ss << out_ty->annotation_str() << ")";
      type_checks.emplace_back(type_check_ss.str());
    }

    method_te.v("def_inputs", def_inputs);
    method_te.v("fwd_inputs", fwd_inputs);
    method_te.v("refine", type_checks);
    method_te.s("unpack", out_ss.str());

    wrapper_method_te.v("def_inputs", def_inputs);
    wrapper_method_te.v("fwd_inputs", fwd_inputs);
    wrapper_methods.push_back(wrapper_method_ct.format(wrapper_method_te));

    // If the output type is a single-element tuple, then add an extra comma
    // to ensure the final output maintains this type.
    if (out_tuple_ty && out_tuple_ty->elements().size() == 1) {
      out_ss << ",";
    }

    method_te.s("ret", out_ss.str());

    loweredModule.define(method_ct.format(method_te), loweredModuleResolver());
  }

  // If the backend is available, call __setstate__ to ensure that the
  // returned Module is ready to run.
  // Otherwise, issue a warning indicating that the resulting Module is not
  // ready for execution until it is loaded on a device where the backend is
  // available.
  loweredModule.run_method("__create_backend");
  if (loweredModule.run_method("__is_available").toBool()) {
    auto state = at::ivalue::Tuple::create(
        method_compile_spec,
        loweredModule.attr("__processed_module"),
        /*create_backend*/ false);
    loweredModule.run_method("__setstate__", state);
  } else {
    TORCH_WARN(
        "Backend [",
        backend_name,
        "] is not available. Execution of this Module is still possible by "
        "saving and loading on a device where the backend is available.");
  }

  // Stop debug info recording and get debug_info_map.
  auto debug_info_map = debug_info_recorder.stopRecording();
  loweredModule.run_method("__create_backend_debug_info");
  auto backend_debug_info = loweredModule.attr("__backend_debug_info")
                                .toCustomClass<PyTorchBackendDebugInfo>();
  backend_debug_info->setDebugInfoMap(std::move(debug_info_map));

  // Wrap lowered module to obfuscate custom serialization logic.
  wrapper.register_module("__loweredModule__", loweredModule);
  for (auto& method : wrapper_methods) {
    wrapper.define(method);
  }

  return wrapper;
}
} // namespace detail
} // namespace jit
} // namespace torch