#include <torch/csrc/jit/backends/backend_detail.h>

#include <ATen/code_template.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_debug_handler.h>
#include <torch/csrc/jit/backends/backend_debug_info.h>
#include <torch/csrc/jit/backends/backend_resolver.h>

#include <memory>
#include <stack>
#include <unordered_map>

namespace torch {
namespace jit {
namespace detail {
namespace {

/*
 * This is the API via which a backend's preprocess function obtains debug
 * handles corresponding to the nodes of the graph for the lowered methods of
 * the module.
 * Implementation: given a graph, for each node of the graph, request a debug
 * handle via debug_info_recorder. debug_info_recorder returns the next debug
 * handle and records the node along with its debug info, such as source range
 * and inlined callstack.
 *
 * The backend's lowering code, preprocess, calls
 * generate_debug_handles(graph), which returns debug handles corresponding
 * to the Node*s of the given graph.
 *
 * In to_backend, after lowering, stopRecording is called on the
 * BackendDebugInfoRecorder: it extracts the debug map. This map gets
 * stored as part of the lowered module.
 * During serialization, specifically bytecode serialization, a check is made
 * to see if the model being serialized has any lowered modules. If so, the
 * corresponding debug maps are extracted and serialized.
 */
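
// Illustrative only (not part of this file's API): a backend's preprocess
// function is expected to call the debug-handle generator it receives and to
// carry the returned handles into its lowered format. The names
// my_backend_preprocess and lower_graph_to_blob below are hypothetical; the
// real interface is BackendPreprocessFunction (see backend_detail.h).
//
//   c10::IValue my_backend_preprocess(
//       const Module& mod,
//       const c10::Dict<IValue, IValue>& method_compile_spec,
//       const BackendDebugHandleGenerator& generate_debug_handles) {
//     auto graph = mod.get_method("forward").graph();
//     NodeToDebugHandle handles = generate_debug_handles(graph);
//     // Record handles[n] alongside each lowered instruction so that runtime
//     // errors can be mapped back to source ranges via the debug map.
//     return lower_graph_to_blob(graph, handles, method_compile_spec);
//   }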

NodeToDebugHandle generate_debug_handles(
    BackendDebugInfoRecorder& debug_info_recorder,
    const std::shared_ptr<Graph>& graph) {
  NodeToDebugHandle node_to_debug_handles;

  std::stack<Block*> blocks_to_visit;
  // TODO: Look into using DepthFirstGraphNodeIterator.
  // At the moment it takes a non-const graph, but maybe it can be made
  // general enough to work with both.
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      DebugHandleType debug_handle = debug_info_recorder.getNextDebugHandle(n);
      node_to_debug_handles.emplace(n, debug_handle);
      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }
  return node_to_debug_handles;
}

std::unordered_map<std::string, BackendPreprocessFunction>&
backendPreprocessFunctions() {
  static std::unordered_map<std::string, BackendPreprocessFunction>
      preprocess_functions;
  return preprocess_functions;
}
} // namespace

bool hasBackendPreprocessFunction(const std::string& name) {
  return backendPreprocessFunctions().count(name);
}

void registerBackendPreprocessFunction(
    const std::string& name,
    const BackendPreprocessFunction& preprocess) {
  TORCH_CHECK(
      !detail::hasBackendPreprocessFunction(name),
      "Preprocessing function for backend ",
      name,
      " is already registered. Ensure that registration is only called once.");
  detail::backendPreprocessFunctions()[name] = preprocess;
}

BackendPreprocessFunction getBackendPreprocessFunction(
    const std::string& name) {
  TORCH_CHECK(
      hasBackendPreprocessFunction(name),
      "Preprocessing function for backend ",
      name,
      " is not registered.");
  return backendPreprocessFunctions()[name];
}
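
// Note: backends typically do not call registerBackendPreprocessFunction
// directly. Registration is usually done at static-initialization time via
// the backend_preprocess_register helper declared in backend_preprocess.h,
// roughly like the (hypothetical) example below:
//
//   static auto pre_reg = torch::jit::backend_preprocess_register(
//       "my_backend", my_backend_preprocess);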

Module codegen_backend_module(
    const std::string& backend_name,
    const Module& orig_module,
    const c10::Dict<IValue, IValue>& method_compile_spec,
    const c10::DictTypePtr& any_dict_ty) {
  const c10::QualifiedName qual_backend_name(
      {"__torch__", "torch", "classes", kBackendsNamespace, backend_name});
  // TODO: Validate method_compile_spec.

  // Clone orig_module to make sure the backend transformation is
  // functional.
  auto cloned_module = orig_module.clone();
  auto module_name = orig_module.type()->name()->qualifiedName();

  // Generate LoweredModule.
  Module loweredModule(
      "torch.jit.LoweredModule." + backend_name + "." + module_name,
      std::make_shared<CompilationUnit>(),
      /*shouldMangle=*/true);

  // Generate WrapperModule.
  Module wrapper(
      "torch.jit.LoweredWrapper." + backend_name + "." + module_name,
      std::make_shared<CompilationUnit>(),
      /*shouldMangle=*/true);

  // 1. Initialize the debug info recorder.
  // 2. Later, call debug_info_recorder.stopRecording() to gather the
  //    recorded debug info and save it in __backend_debug_info.
  BackendDebugInfoRecorder debug_info_recorder;

  // Generate attributes.
  // This is the preprocessed module.
  // For backwards compatibility, for backends that implement preprocessing in
  // the backend interface rather than as a separate function, we just pass
  // the cloned original Module.

  BackendDebugHandleGenerator debug_handle_generator =
      [&](const std::shared_ptr<Graph>& g) {
        return generate_debug_handles(debug_info_recorder, g);
      };
  loweredModule.register_attribute(
      "__processed_module",
      AnyType::get(),
      detail::getBackendPreprocessFunction(backend_name)(
          cloned_module, method_compile_spec, debug_handle_generator),
      /*is_param=*/false);

  // This is for the method_compile_spec passed in to to_<backend> or
  // loaded from an exported model.
  loweredModule.register_attribute(
      "__method_compile_spec",
      any_dict_ty,
      method_compile_spec,
      /*is_param=*/false);
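
  // For illustration only (the per-method contents are backend-defined): from
  // Python, a method_compile_spec is typically a dict keyed by method name,
  // e.g.
  //   {"forward": {"some_backend_option": True}}
  // The keys of this dict also determine which methods are generated on the
  // LoweredModule further below.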

  // This is a pointer to a backend instance that is used to access
  // compile and execute functions.
  auto cls = getCustomClass(qual_backend_name.qualifiedName());
  TORCH_INTERNAL_ASSERT(cls);
  c10::intrusive_ptr<torch::CustomClassHolder> backend;
  loweredModule.register_attribute(
      "__backend", cls, IValue::make_capsule(backend));

  // This is the list of opaque backend handles returned by
  // backend.compile.
  loweredModule.register_attribute(
      "__handles",
      any_dict_ty,
      c10::impl::GenericDict(
          any_dict_ty->getKeyType(), any_dict_ty->getValueType()),
      /*is_param=*/false);

  // Methods.

  // This is a helper function for creating a new instance of the
  // backend class.
  static const auto create_backend_ct = at::jit::CodeTemplate(R"(
            def __create_backend(self):
                self.__backend = $name()
            )");
  at::jit::TemplateEnv create_backend_te;
  create_backend_te.s("name", qual_backend_name.qualifiedName());
  loweredModule.define(
      create_backend_ct.format(create_backend_te), loweredModuleResolver());

  // Helper function to expose backend.is_available() to Module generation
  // code. Assumes self.__backend exists (i.e. __create_backend() has already
  // been invoked).
  loweredModule.define(
      R"(
            def __is_available(self):
                return self.__backend.is_available()
            )",
      loweredModuleResolver());

  // backend_debug_info_class is an instance of BackendDebugInfo that
  // stores debug information.
  // The purpose of this class is to make the debug information available
  // at model saving time for serializing it outside of the lowered module,
  // while still tying it to the module's lifetime (so it gets destroyed along
  // with it).
  // Although this information is not serialized as part of the lowered
  // module, we still need to provide a valid instance of the
  // BackendDebugInfo class when the lowered module is deserialized.
  // Since the deserialized module does not need this information,
  // we create a "dummy" instance with no extra code dependencies (to avoid
  // overhead) when the backend is created in __setstate__.
  c10::intrusive_ptr<torch::CustomClassHolder> backend_debug_info_class;
  const c10::QualifiedName backend_debug_info_class_name(
      {"__torch__",
       "torch",
       "classes",
       kBackendUtilsNamespace,
       kBackendDebugInfoClass});
  auto debug_info_cls =
      getCustomClass(backend_debug_info_class_name.qualifiedName());
  TORCH_CHECK(debug_info_cls, "BackendDebugInfo class must be available.");
  loweredModule.register_attribute(
      "__backend_debug_info",
      OptionalType::create(debug_info_cls),
      IValue::make_capsule(backend_debug_info_class));
  static const auto create_backend_debug_info_ct = at::jit::CodeTemplate(R"(
            def __create_backend_debug_info(self):
                self.__backend_debug_info = $backend_debug_info()
            )");
  at::jit::TemplateEnv create_backend_debug_info_te;
  create_backend_debug_info_te.s(
      "backend_debug_info", backend_debug_info_class_name.qualifiedName());
  loweredModule.define(
      create_backend_debug_info_ct.format(create_backend_debug_info_te),
      loweredModuleResolver());

  // getstate and setstate are for serialization/deserialization of
  // the LoweredModule.
  // setstate is in charge of initializing self.__backend by invoking
  // __create_backend().
  loweredModule.define(
      R"(
            def __getstate__(self):
                # The third element indicates whether __setstate__ must create
                # the backend instance. It's hardcoded to True since the only
                # case it can be false is when __setstate__ is called from
                # outside the module (at module creation time), because
                # __create_backend has been called already (also directly).
                return self.__method_compile_spec, self.__processed_module, True
            )",
      loweredModuleResolver());

  loweredModule.define(
      R"(
            def __setstate__(self, state):
                self.__method_compile_spec = state[0]
                self.__processed_module = state[1]
                # state[2] indicates whether to create the backend instance.
                if state[2]:
                    self.__create_backend()
                    self.__create_backend_debug_info()
                if self.__backend.is_available():
                    self.__handles = self.__backend.compile(self.__processed_module, self.__method_compile_spec)
                else:
                    raise Exception("Backend is not available.")
            )",
      loweredModuleResolver());

  // This loop generates one method on the LoweredModule for every key
  // in method_compile_spec.
  std::vector<std::string> wrapper_methods;
  for (auto& e : method_compile_spec) {
    std::string method_name = e.key().toStringRef();
    static const auto method_ct = at::jit::CodeTemplate(R"(
            def $method(self${,def_inputs}):
                typed_inputs: List[Any] = [${fwd_inputs,}]
                if self.__backend.is_available():
                    $unpack, = self.__backend.execute(self.__handles["$method"], typed_inputs)
                    ${refine,}
                    return $ret
                else:
                    raise Exception("Backend is not available.")
            )");
    static const auto wrapper_method_ct = at::jit::CodeTemplate(R"(
            def $method(self${,def_inputs}):
                return self.__loweredModule__.$method(${fwd_inputs})
            )");
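
    // Illustrative expansion of method_ct (hypothetical signature): for a
    // method `def forward(self, x: Tensor) -> Tensor`, the generated code is
    // roughly:
    //
    //   def forward(self, x: Tensor):
    //       typed_inputs: List[Any] = [x, ]
    //       if self.__backend.is_available():
    //           _0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
    //           assert isinstance(_0, Tensor)
    //           return _0
    //       else:
    //           raise Exception("Backend is not available.")
    //
    // For tuple returns, each element gets its own _<i> identifier and
    // isinstance check, and $ret becomes the comma-joined identifiers.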

    at::jit::TemplateEnv method_te, wrapper_method_te;
    method_te.s("method", method_name);
    wrapper_method_te.s("method", method_name);
    auto method = orig_module.get_method(method_name);
    auto& function = method.function();
    auto& schema = function.getSchema();

    // Generate the inputs for the function signature (def_inputs) and
    // for passing to backend.execute (fwd_inputs).
    std::vector<std::string> def_inputs, fwd_inputs;
    for (const auto& arg : schema.arguments()) {
      auto name = arg.name();

      // Skip self since it is always present in the signature.
      if (name == "self") {
        continue;
      }

      auto default_value = arg.default_value();

      if (arg.kwarg_only()) {
        // If this is a kwarg, it needs to be emitted as keyword=value
        // in the definition and keyword=keyword in the call to
        // backend.execute.
        TORCH_INTERNAL_ASSERT(default_value.has_value());
        std::stringstream def_ss, fwd_ss;
        // Annotate the type of the arg.
        def_ss << name << ": " << arg.type()->annotation_str(nullptr) << "=";
        fwd_ss << name << "=" << name;
        default_value->repr(
            def_ss, [](std::ostream&, const IValue&) -> bool { return false; });
        def_inputs.emplace_back(def_ss.str());
        fwd_inputs.emplace_back(fwd_ss.str());
      } else {
        // If this is not a kwarg, it should be emitted as-is in the
        // signature and in the call to backend.execute.
        std::stringstream def_ss;
        // Annotate the type of the arg.
        def_ss << name << ": " << arg.type()->annotation_str(nullptr);
        def_inputs.emplace_back(def_ss.str());
        fwd_inputs.emplace_back(name);
      }
    }

    // Generate a comma-delimited list of identifiers to unpack
    // outputs, as well as a list of isinstance checks to make sure
    // the backend returned the types it was supposed to.
    std::stringstream out_ss, type_check_ss;
    std::vector<std::string> type_checks;
    TORCH_INTERNAL_ASSERT(schema.returns().size() == 1);
    auto out_ty = schema.returns().at(0).type();

    out_ss << "_0";
    type_check_ss << "assert isinstance(_0, ";

    auto out_tuple_ty = out_ty->cast<TupleType>();

    if (out_tuple_ty) {
      auto tuple_elements = out_tuple_ty->elements();
      type_check_ss << tuple_elements[0]->annotation_str() << ")";
      type_checks.emplace_back(type_check_ss.str());
      for (unsigned i = 1, e = tuple_elements.size(); i < e; ++i) {
        type_check_ss.str(std::string());
        type_check_ss.clear();
        out_ss << ", _" << i;
        type_check_ss << "assert isinstance(_" << i << ", "
                      << tuple_elements[i]->annotation_str() << ")";
        type_checks.emplace_back(type_check_ss.str());
      }
    } else {
      type_check_ss << out_ty->annotation_str() << ")";
      type_checks.emplace_back(type_check_ss.str());
    }

    method_te.v("def_inputs", def_inputs);
    method_te.v("fwd_inputs", fwd_inputs);
    method_te.v("refine", type_checks);
    method_te.s("unpack", out_ss.str());

    wrapper_method_te.v("def_inputs", def_inputs);
    wrapper_method_te.v("fwd_inputs", fwd_inputs);
    wrapper_methods.push_back(wrapper_method_ct.format(wrapper_method_te));

    // If the output type is a single-element tuple then add an extra comma
    // to ensure the final output maintains this type.
    if (out_tuple_ty && out_tuple_ty->elements().size() == 1) {
      out_ss << ",";
    }

    method_te.s("ret", out_ss.str());

    loweredModule.define(method_ct.format(method_te), loweredModuleResolver());
  }

  // If the backend is available, call __setstate__ to ensure that the
  // returned Module is ready to run.
  // Otherwise, emit a warning indicating that the resulting Module is not
  // ready for execution until it is loaded to a device where the backend is
  // available.
  loweredModule.run_method("__create_backend");
  if (loweredModule.run_method("__is_available").toBool()) {
    auto state = at::ivalue::Tuple::create(
        method_compile_spec,
        loweredModule.attr("__processed_module"),
        /*create_backend*/ false);
    loweredModule.run_method("__setstate__", state);
  } else {
    TORCH_WARN(
        "Backend [",
        backend_name,
        "] is not available. Execution of this Module is still possible by "
        "saving and loading on a device where the backend is available.");
  }

  // Stop debug info recording and get the debug_info_map.
  auto debug_info_map = debug_info_recorder.stopRecording();
  loweredModule.run_method("__create_backend_debug_info");
  auto backend_debug_info = loweredModule.attr("__backend_debug_info")
                                .toCustomClass<PyTorchBackendDebugInfo>();
  backend_debug_info->setDebugInfoMap(std::move(debug_info_map));

  // Wrap the lowered module to hide the custom serialization logic.
  wrapper.register_module("__loweredModule__", loweredModule);
  for (auto& method : wrapper_methods) {
    wrapper.define(method);
  }

  return wrapper;
}
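
// Sketch of how this codegen path is typically reached from Python; the
// backend name, module, and spec below are illustrative, not part of this
// file:
//
//   scripted = torch.jit.script(MyModule())
//   lowered = torch._C._jit_to_backend(
//       "my_backend", scripted, {"forward": {"some_backend_option": True}})
//
// The object returned to Python is the WrapperModule built above, which
// forwards each generated method to the embedded LoweredModule.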
} // namespace detail
} // namespace jit
} // namespace torch