#include <torch/csrc/jit/serialization/export_bytecode.h>

#include <algorithm>
#include <deque>
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include <torch/csrc/jit/operator_upgraders/version_map.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/export.h>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/api/method.h>
#include <torch/csrc/jit/backends/backend_debug_handler.h>
#include <torch/csrc/jit/backends/backend_debug_info.h>
#include <torch/csrc/jit/frontend/source_range.h>
#include <torch/csrc/jit/ir/attributes.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/type_hashing.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/method.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/csrc/jit/serialization/import_export_helpers.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/jit/serialization/python_print.h>
#include <torch/csrc/jit/serialization/source_range_serialization.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>

#include <caffe2/serialize/inline_container.h>

namespace torch::jit {

static std::vector<Method> gatherGetSetStates(const ObjectPtr& obj) {
  std::vector<Method> methods;
  // Use DFS on IValues to traverse the dependencies of module._ivalue and
  // collect every __setstate__/__getstate__ pair found along the way.
  std::vector<ObjectPtr> ivalue_stack;
  ivalue_stack.emplace_back(obj);
  while (!ivalue_stack.empty()) {
    ObjectPtr cur = ivalue_stack.back();
    ivalue_stack.pop_back();
    auto type = cur->type();
    Function* setstate = type->findMethod("__setstate__");
    Function* getstate = type->findMethod("__getstate__");
    if (getstate && setstate) {
      if (setstate->isGraphFunction()) {
        methods.emplace_back(cur, setstate);
      }
      if (getstate->isGraphFunction()) {
        methods.emplace_back(cur, getstate);
      }
    } else {
      for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
        IValue field = cur->getSlot(i);
        if (field.isObject()) {
          ivalue_stack.emplace_back(field.toObject());
        }
      }
    }
  }
  return methods;
}

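// Collect methods of `module` (and its submodules) whose names match
// interface-method call sites (prim::CallMethod on an InterfaceType input)
// found in `graph`.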
static std::vector<Method> findAllDependentFunctions(
    const Module& module,
    Graph& graph) {
  std::vector<Method> methods;
  std::unordered_set<c10::string_view> called_method_names;
  auto nodes = findAllNodes(graph, c10::prim::CallMethod, true);
  for (Node* node : nodes) {
    if (auto iface = node->input(0)->type()->castRaw<InterfaceType>()) {
      const FunctionSchema* schema = iface->getMethod(node->s(attr::name));
      called_method_names.insert(schema->name());
    }
  }

  for (const auto& submodule : module.modules()) {
    for (const auto& m : submodule.get_methods()) {
      if (called_method_names.find(m.function().qualname().name()) !=
          called_method_names.end()) {
        methods.emplace_back(m);
      }
    }
  }
  return methods;
}

// NOTE: the order of the returned functions is:
// 1. Functions originating from the methods passed in come first.
// 2. All dependent functions come afterwards.
// This order matters because the mobile Module currently looks up
// methods with a linear search.
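// Illustrative ordering (hypothetical module M; names are for exposition
// only): inlineFunctions({M.forward, M.__getstate__}, true) would yield
// [forward, __getstate__, <interface methods those graphs call via
// prim::CallMethod>, ...].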
static std::vector<std::unique_ptr<GraphFunction>> inlineFunctions(
    const std::vector<Method>& initial_methods,
    bool incl_dependent_functions) {
  std::set<std::pair<std::string, Function*>> visited;
  std::deque<Method> stack;
  std::copy(
      initial_methods.begin(),
      initial_methods.end(),
      std::back_inserter(stack));
  std::vector<std::unique_ptr<GraphFunction>> inlined_functions;
  while (!stack.empty()) {
    Method cur = stack.front();
    stack.pop_front();
    auto tup = std::make_pair(
        cur.owner()._ivalue()->type()->name()->qualifiedName(),
        &cur.function());
    if (visited.find(tup) != visited.end()) {
      continue;
    }
    visited.insert(tup);
    const auto& f = toGraphFunction(cur.function());
    auto graph = f.graph()->copyUnique();
    Inline(*graph);
    c10::QualifiedName qn(*cur.owner()._ivalue()->type()->name(), f.name());

    if (incl_dependent_functions) {
      std::vector<Method> dependent_methods =
          findAllDependentFunctions(cur.owner(), *graph);
      std::copy(
          dependent_methods.begin(),
          dependent_methods.end(),
          std::back_inserter(stack));
    }
    auto inlined_func = std::make_unique<GraphFunction>(
        qn, std::move(graph), f.function_creator());
    inlined_func->setSchema(f.getSchema());
    inlined_functions.emplace_back(std::move(inlined_func));
  }
  return inlined_functions;
}

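// Lowers a JIT graph into mobile::Code. Walks the instruction stream produced
// by MobileCode and, along the way: resolves each newly seen OP/OPN to a
// bound operator function, rewrites CALL on interface methods into
// INTERFACE_CALL, rejects RET values whose list/dict elements are class
// types, and records one debug handle per emitted instruction.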
mobile::Code compileGraphToMobileCode(
    const std::string& name,
    const std::shared_ptr<Graph>& graph,
    const CompilationOptions& compilation_options,
    BackendDebugInfoRecorder& debug_info_recorder) {
  MobileCode code(
      graph,
      name,
      compilation_options.enable_default_value_for_unspecified_arg,
      compilation_options.enable_default_args_before_out_args,
      compilation_options.enable_emit_promoted_ops);

  mobile::Code mobile_code;

  // operator names
  std::vector<std::string> method_names;
  std::vector<int64_t> op_debug_handles;
  int next_new_op_index = 0;

  auto op_to_specified_args = code.op_to_num_specified_args();

  for (size_t i = 0; i < code.instructions().size(); ++i) {
    Instruction ins = code.instructions()[i];

    if ((ins.op == OP || ins.op == OPN) && ins.X == next_new_op_index) {
      // Found a new op (assumes new operators ordered by ascending ins.X)
      auto node = code.instructions_source()[i];
      const c10::OperatorName& opname = node->schema().operator_name();
      auto unique_name = c10::toString(opname);
      // For an operator with varargs, adding default arguments would be
      // confusing and is not allowed. num_args = -1 means the number of
      // arguments is not available for this operator, so we don't do any
      // backward-compatibility adaptation at runtime.
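      // Illustrative example (assumed schema, not taken from this file): for
      // "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1)"
      // serialized with two specified arguments, num_args == 2 lets the
      // runtime append the defaulted `alpha` when the current schema
      // expects three arguments.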
      std::optional<int> num_args = std::nullopt;
      auto it = op_to_specified_args.find(unique_name);
      if (it != op_to_specified_args.end()) {
        num_args = it->second;
      }
      mobile_code.operator_input_sizes_.emplace_back(num_args.value_or(-1));
      mobile_code.op_names_.emplace_back(opname);
      auto func = mobile::makeOperatorFunction(opname, num_args);
      TORCH_INTERNAL_ASSERT(
          func.has_value(),
          "Operator with name: ",
          toString(opname),
          " not found");
      mobile_code.operators_.emplace_back(*func);
      next_new_op_index++;
    }
    // CALL nodes at this point represent built-in (i.e. non-Graph)
    // functions that were not inlined. Here we convert the CALL
    // instructions for these functions into INTERFACE_CALL instructions
    // so that at runtime we look up the Function* on the type of the
    // 0th argument on the stack and call it directly.
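    // Illustrative rewrite (fields per Instruction{op, X, N} used below):
    //   CALL X=<function index>
    // becomes
    //   INTERFACE_CALL X=<index of the method-name constant appended to the
    //   constant table> N=<number of inputs, including self>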
    if (ins.op == CALL) {
      auto node = code.instructions_source()[i];
      if (node->kind() == prim::CallMethod) {
        // NB: replacing instruction
        auto method_name_idx =
            code.constant_table().size() + method_names.size();
        method_names.emplace_back(node->s(attr::name));
        ins = Instruction{
            INTERFACE_CALL,
            static_cast<int32_t>(method_name_idx),
            static_cast<uint16_t>(node->inputs().size())};
      } else {
        TORCH_INTERNAL_ASSERT(
            false, "Unsupported node kind on CALL opcode for mobile");
      }
    } else if (ins.op == RET) {
      auto node = code.instructions_source()[i];
      for (const auto& input : node->inputs()) {
        const auto& input_type = input->type();
        if (input_type->kind() == TypeKind::ListType ||
            input_type->kind() == TypeKind::DictType) {
          for (const TypePtr& element_type : input_type->containedTypes()) {
            TORCH_CHECK(
                element_type->kind() != TypeKind::ClassType,
                "Returning a list or dictionary with a PyTorch class type ",
                "is not supported in mobile modules "
                "(List[Foo] or Dict[int, Foo] for class Foo(torch.nn.Module)). "
                "Workaround: instead of using a PyTorch class as the element type, ",
                "use a combination of list, dictionary, and simple types.");
          }
        }
      }
    } else {
      TORCH_CHECK(
          isOpSupportedInMobile(ins.op),
          toString(ins.op),
          " is not supported in mobile module.");
    }
    auto node = code.instructions_source()[i];
    int64_t debug_handle = debug_info_recorder.getNextDebugHandle(node);
    // Note 1-to-1 correspondence between instructions and debug handles
    mobile_code.instructions_.emplace_back(ins);
    mobile_code.debug_handles_.emplace_back(debug_handle);
  }

  // copy constants
  mobile_code.constants_ = code.constant_table();

  // Make a copy of the constants and append the method names
  // that we emitted for the converted INTERFACE_CALL nodes above.
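  // e.g. if the original constant table had K entries, the first
  // INTERFACE_CALL emitted above has X == K and resolves to constants_[K].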
  for (auto& method_name : method_names) {
    mobile_code.constants_.emplace_back(method_name);
  }

  mobile_code.types_ = code.type_table();
  mobile_code.register_size_ = code.register_size();
  return mobile_code;
}

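// Compile a single GraphFunction into a mobile::Function. Note that the
// local BackendDebugInfoRecorder below only hands out debug handles; since
// stopRecording() is never called on it, the recorded callstack map is
// discarded.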
std::unique_ptr<mobile::Function> convertJitFunctionToMobileFunction(
    const GraphFunction& function,
    const CompilationOptions& options) {
  BackendDebugInfoRecorder debug_handle;
  auto mobileCode = compileGraphToMobileCode(
      function.name(), function.graph(), options, debug_handle);
  const auto& schema = function.getSchema();
  return std::make_unique<mobile::Function>(
      function.qualname(), std::move(mobileCode), schema);
}

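// The returned table has the following shape (matching the keys assembled
// below):
//   ("instructions",  (("OPCODE", X, N), ...))
//   ("operators",     ((name, overload_name[, num_specified_args]), ...))
//   ("constants",     (...))
//   ("types",         (annotation strings))
//   ("register_size", size)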
IValue convertMobileFunctionToCodeTable(
    const mobile::Function& func,
    const CompilationOptions& compilation_options) {
  auto code = func.get_code();
  std::vector<IValue> instructions;
  instructions.reserve(code.instructions_.size());
  for (Instruction ins : code.instructions_) {
    instructions.emplace_back(to_tuple({toString(ins.op), ins.X, ins.N}));
  }

  std::vector<IValue> operators;
  operators.reserve(code.op_names_.size());
  for (unsigned i = 0; i < code.op_names_.size(); ++i) {
    const auto& opname = code.op_names_[i];
    const int size = code.operator_input_sizes_[i];
    if (compilation_options.enable_default_value_for_unspecified_arg) {
      operators.emplace_back(to_tuple({opname.name, opname.overload_name}));
    } else {
      operators.emplace_back(
          to_tuple({opname.name, opname.overload_name, size}));
    }
  }

  std::vector<IValue> types;
  for (const TypePtr& t : code.types_) {
    std::string type_str = t->annotation_str();
    types.emplace_back(type_str);
  }

  auto register_size = static_cast<int>(code.register_size_);
  auto codeTable = Table(
      {{"instructions", to_tuple(instructions)},
       {"operators", to_tuple(operators)},
       {"constants", to_tuple(code.constants_)},
       {"types", to_tuple(types)},
       {"register_size", register_size}});

  return codeTable;
}

static void checkSchema(const c10::FunctionSchema& schema) {
  TORCH_CHECK(
      schema.overload_name().empty(), // @TODO: is this check correct?
      "Overloads are not supported in mobile modules.");
  TORCH_CHECK(
      !schema.is_vararg(), "Python *args are not supported in mobile modules.");
  TORCH_CHECK(
      !schema.is_varret(),
      "A variable number of return values is not supported in mobile modules.");
}

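// Heuristic: a module counts as backend-lowered if any atom of its
// qualified type name is exactly "LoweredModule".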
static bool isLoweredModule(const Module& m) {
  c10::QualifiedName type_name;
  if (m.type()->name()) {
    type_name = m.type()->name().value();
  }
  bool isLoweredModule = false;
  for (const auto& atom : type_name.atoms()) {
    if (atom == "LoweredModule") {
      isLoweredModule = true;
      break;
    }
  }
  return isLoweredModule;
}

// Check whether the global static map of backend debug info contains
// debug info for this module or any of its children. If so, merge all
// those maps into debug_map.
static void getBackendDebugInfoMap(
    const Module& m,
    BackendDebugInfoMapType& debug_map) {
  if (isLoweredModule(m)) {
    auto backend_debug_info =
        m.attr("__backend_debug_info").toCustomClass<PyTorchBackendDebugInfo>();
    const auto& map = backend_debug_info->getDebugInfoMap();
    if (map) {
      debug_map.insert(map.value().begin(), map.value().end());
    }
  }
  for (const auto& c : m.children()) {
    getBackendDebugInfoMap(c, debug_map);
  }
}

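// Compute the minimum operator version the module requires: for each
// operator used by any method, consult the operator-upgrader version map and
// take the largest bumped_at_version seen, floored at
// caffe2::serialize::kMinSupportedFileFormatVersion.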
static uint64_t get_min_operator_version_from_version_map(
    const mobile::Module& module) {
  uint64_t min_version = caffe2::serialize::kMinSupportedFileFormatVersion;
  for (const auto& func : module.compilation_unit().methods()) {
    for (const auto& op_name : func->get_code().op_names_) {
      auto schema_name = op_name.overload_name.empty()
          ? op_name.name
          : op_name.name + "." + op_name.overload_name;
      auto version_entry = get_operator_version_map().find(schema_name);
      if (version_entry != get_operator_version_map().end()) {
        const auto& entry = version_entry->second;
        min_version = std::max(
            min_version, uint64_t(entry[entry.size() - 1].bumped_at_version));
      }
    }
  }
  return min_version;
}

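// Usage sketch (illustrative; assumes getOptionsFromGlobalContext() declared
// in export_bytecode.h):
//   Module scripted = /* a scripted torch::jit::Module */;
//   mobile::Module m =
//       jitModuleToMobile(scripted, getOptionsFromGlobalContext());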
mobile::Module jitModuleToMobile(
    const Module& module,
    const CompilationOptions& options) {
  std::shared_ptr<mobile::CompilationUnit> mcu =
      std::make_shared<mobile::CompilationUnit>();
  BackendDebugInfoRecorder debug_info_recorder;

  std::vector<Method> methods_to_export = module.get_methods();
  std::vector<Method> getsetstates = gatherGetSetStates(module._ivalue());
  std::copy(
      getsetstates.begin(),
      getsetstates.end(),
      std::back_inserter(methods_to_export));

  for (const auto& func :
       inlineFunctions(methods_to_export, options.incl_interface_call)) {
    auto mobile_code = compileGraphToMobileCode(
        func->name(), func->graph(), options, debug_info_recorder);
    const auto& schema = func->getSchema();
    checkSchema(schema);
    auto mobile_func = std::make_unique<mobile::Function>(
        func->qualname(), std::move(mobile_code), schema);
    mcu->register_function(std::move(mobile_func));
  }

  mobile::Module m(module._ivalue(), mcu);
  m.setHasDebugHandles(true);
  BackendDebugInfoMapType backend_debug_info_map;
  getBackendDebugInfoMap(module, backend_debug_info_map);
  auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording();
  debug_handle_cs_ptr_map.insert(
      backend_debug_info_map.begin(), backend_debug_info_map.end());
  m.setDebugTable(MobileDebugTable(
      debug_handle_cs_ptr_map.begin(), debug_handle_cs_ptr_map.end()));
  m.set_min_operator_version(
      static_cast<int64_t>(get_min_operator_version_from_version_map(m)));
  m.set_bytecode_version(options.model_version);
  return m;
}

} // namespace torch::jit