// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/api/module.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/core/symbol.h>
#include <ATen/record_function.h>
#include <c10/util/Exception.h>
#include <c10/util/StringUtil.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/frontend/schema_matching.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/frozen_linear_transpose.h>
#include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/operator.h>

#include <iostream>

namespace torch::jit {

namespace {

std::string getInputDebugName(const Node& n, const int idx) {
  return n.inputs().at(idx)->debugName();
}

void assert_ignored_methods_not_called(
    torch::jit::Function& fn,
    const std::unordered_set<std::string>& ignored_methods) {
  if (ignored_methods.empty()) {
    return;
  }
  const bool recurse = true;
  std::vector<Node*> all_nodes = findAllNodes(
      *toGraphFunction(fn).graph(), c10::prim::CallMethod, recurse);

  // Extract method names from these nodes.
  std::unordered_set<std::string> encountered_ignored_methods;

  for (Node* n : all_nodes) {
    if (ignored_methods.count(n->s(attr::name)) > 0 &&
        getInputDebugName(*n, 0) == "self") {
      encountered_ignored_methods.insert(
          getInputDebugName(*n, 0) + "." + n->s(attr::name));
    }
  }
  if (encountered_ignored_methods.empty()) {
    return;
  }

  const std::string encountered_ignored_methods_str =
      c10::Join(", ", encountered_ignored_methods);

  TORCH_CHECK(
      false,
      "Preserved method '",
      fn.name(),
      "' references ignored method(s) '",
      encountered_ignored_methods_str,
      "'. This is not permitted.");
}

void assert_ignored_attributes_not_referenced(
    torch::jit::Function& fn,
    const std::unordered_set<std::string>& ignored_attributes) {
  if (ignored_attributes.empty()) {
    return;
  }

  const bool recurse = true;
  std::vector<Node*> all_nodes =
      findAllNodes(*toGraphFunction(fn).graph(), c10::prim::GetAttr, recurse);

  // Extract attribute names from these nodes.
  std::unordered_set<std::string> encountered_ignored_attributes;

  for (Node* n : all_nodes) {
    if (ignored_attributes.count(n->s(attr::name)) > 0 &&
        getInputDebugName(*n, 0) == "self") {
      encountered_ignored_attributes.insert(
          getInputDebugName(*n, 0) + "." + n->s(attr::name));
    }
  }
  if (encountered_ignored_attributes.empty()) {
    return;
  }

  const std::string encountered_ignored_attributes_str =
      c10::Join(", ", encountered_ignored_attributes);

  TORCH_CHECK(
      false,
      "Preserved method '",
      fn.name(),
      "' references ignored attribute(s) '",
      encountered_ignored_attributes_str,
      "'. This is not permitted.");
}

} // namespace

static ObjectPtr create_module_object(
    c10::QualifiedName class_name,
    std::shared_ptr<CompilationUnit> cu,
    bool shouldMangle = false) {
  // If the name is unqualified, prepend a `__torch__`, similar to what Python
  // does with `__main__` for top-level code.
  if (class_name.prefix().empty()) {
    class_name = c10::QualifiedName("__torch__", class_name.name());
  }
  if (shouldMangle && cu->get_class(class_name) != nullptr) {
    class_name = cu->mangle(class_name);
  }
  auto cls = ClassType::create(std::move(class_name), cu, /*is_module=*/true);
  cu->register_type(cls);
  return c10::ivalue::Object::create(
      c10::StrongTypePtr(std::move(cu), std::move(cls)), 0);
}

Module::Module(c10::QualifiedName class_name)
    : Object(create_module_object(
          std::move(class_name),
          std::make_shared<CompilationUnit>())) {}

Module::Module(
    std::shared_ptr<CompilationUnit> cu,
    const c10::ClassTypePtr& type)
    : Object(c10::ivalue::Object::create(
          c10::StrongTypePtr(std::move(cu), type),
          type->numAttributes())) {}

Module::Module(
    c10::QualifiedName class_name,
    std::shared_ptr<CompilationUnit> cu,
    bool shouldMangle)
    : Object(create_module_object(
          std::move(class_name),
          std::move(cu),
          shouldMangle)) {}
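
// Example (an illustrative sketch; "MyModule" is a hypothetical class name):
//   auto cu = std::make_shared<CompilationUnit>();
//   Module m("MyModule", cu, /*shouldMangle=*/true);
//   // The unqualified name is registered as "__torch__.MyModule" (and mangled
//   // if a class with that name already exists in `cu`).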

// First-class mode runs models as first-class objects and does not force
// inlining everywhere. This is experimental while we bring the system up,
// since it degrades performance and may introduce bugs. test_jit.py provides
// context managers that enable it for specific tests.
thread_local bool inline_everything = false;
bool& getInlineEverythingMode() {
  return inline_everything;
}
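
// Example (a sketch): callers can flip the flag through the mutable reference,
// typically with a save/restore pattern:
//   bool old = getInlineEverythingMode();
//   getInlineEverythingMode() = true;
//   // ... compile something with inlining forced ...
//   getInlineEverythingMode() = old;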

void Module::to(at::Device device, at::ScalarType dtype, bool non_blocking) {
  to_impl(device, dtype, non_blocking);
}

void Module::to(at::ScalarType dtype, bool non_blocking) {
  to_impl(/*device=*/std::nullopt, dtype, non_blocking);
}

void Module::to(at::Device device, bool non_blocking) {
  to_impl(device, /*dtype=*/std::nullopt, non_blocking);
}
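
// Example (a sketch; assumes a CUDA device is available):
//   m.to(at::Device(at::kCUDA), at::kHalf, /*non_blocking=*/false); // move + cast
//   m.to(at::kFloat, /*non_blocking=*/false);                       // cast only
//   m.to(at::Device(at::kCPU), /*non_blocking=*/false);             // move only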
167 
module_state_to(const autograd::Variable & variable,const std::optional<at::Device> & device,const std::optional<at::ScalarType> & dtype,bool non_blocking)168 static void module_state_to(
169     const autograd::Variable& variable,
170     const std::optional<at::Device>& device,
171     const std::optional<at::ScalarType>& dtype,
172     bool non_blocking) {
173   // Need to access the `at::Tensor` as a `Variable` here.
174   // Use the data's original device or dtype if not supplied here.
175   auto new_data = variable.to(
176       device.value_or(variable.device()),
177       dtype.value_or(variable.scalar_type()),
178       non_blocking);
179   variable.set_data(new_data);
180 }
181 
to_impl(const std::optional<at::Device> & device,const std::optional<at::ScalarType> & dtype,bool non_blocking)182 void Module::to_impl(
183     const std::optional<at::Device>& device,
184     const std::optional<at::ScalarType>& dtype,
185     bool non_blocking) {
186   for (at::Tensor e : parameters()) {
187     module_state_to(e, device, dtype, non_blocking);
188   }
189   for (at::Tensor e : buffers()) {
190     module_state_to(e, device, dtype, non_blocking);
191   }
192 }
193 
Method(ModulePtr owner,Function * function)194 Method::Method(ModulePtr owner, Function* function)
195     : owner_(std::move(owner)), function_(function) {}
196 
owner() const197 Module Method::owner() const {
198   return Module(owner_);
199 }
raw_owner() const200 ObjectPtr Method::raw_owner() const {
201   return owner_;
202 }
run(Stack & stack)203 void Method::run(Stack& stack) {
204   stack.insert(stack.begin(), owner()._ivalue()); // self
205   RECORD_TORCHSCRIPT_FUNCTION(name(), stack);
206   function_->run(stack);
207 }
208 
operator ()(std::vector<IValue> stack,const Kwargs & kwargs) const209 IValue Method::operator()(std::vector<IValue> stack, const Kwargs& kwargs)
210     const {
211   stack.insert(stack.begin(), owner()._ivalue()); // self
212   RECORD_TORCHSCRIPT_FUNCTION(name(), stack);
213   return (*function_)(std::move(stack), kwargs);
214 }
215 
run_async(std::vector<IValue> stack,const Kwargs & kwargs,TaskLauncher taskLauncher)216 c10::intrusive_ptr<c10::ivalue::Future> Method::run_async(
217     std::vector<IValue> stack,
218     const Kwargs& kwargs,
219     TaskLauncher taskLauncher) {
220   stack.insert(stack.begin(), owner()._ivalue());
221   RECORD_TORCHSCRIPT_FUNCTION(name(), stack);
222 
223   function_->getSchema().checkAndNormalizeInputs(stack, kwargs);
224   return function_->runAsync(stack, std::move(taskLauncher));
225 }
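
// Example (a sketch; `module` is a hypothetical scripted module): `Method`
// prepends `self` automatically, so callers pass only the visible arguments.
//   Method forward = module.get_method("forward");
//   IValue out = forward({torch::ones({2, 3})});
//   auto fut = forward.run_async({torch::ones({2, 3})});
//   fut->wait();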

void Method::setArgumentNames(
    std::vector<std::string>& argumentNamesOut) const {
  TORCH_INTERNAL_ASSERT(function_);
  auto& arguments = function_->getSchema().arguments();
  argumentNamesOut.reserve(arguments.size());
  for (auto& argument : arguments) {
    if (argument.name() == "self") {
      continue;
    }
    argumentNamesOut.push_back(argument.name());
  }
}

IValue Module::operator()(std::vector<IValue> inputs) {
  const auto& pre_forward_hooks = type()->getForwardPreHooks();
  const auto& forward_hooks = type()->getForwardHooks();

  // call forward pre_hooks
  for (const auto& pre_hook : pre_forward_hooks) {
    auto tuple_input = c10::ivalue::Tuple::create(inputs);
    IValue result = Method(_ivalue(), pre_hook)({tuple_input});
    if (!result.isNone()) {
      if (result.isTuple()) {
        inputs = result.toTupleRef().elements().vec();
      } else {
        inputs = {result};
      }
    }
  }

  // call forward
  auto outputs = forward(inputs);

  // call forward hooks
  for (const auto& hook : forward_hooks) {
    auto tuple_input = c10::ivalue::Tuple::create(inputs);
    auto hook_result = Method(_ivalue(), hook)({tuple_input, outputs});
    if (!hook_result.isNone()) {
      outputs = hook_result;
    }
  }
  return outputs;
}
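
// Example (a sketch of the call flow above): pre-hooks receive the inputs as a
// tuple and may replace them; forward hooks receive (inputs, output) and may
// replace the output.
//   IValue out = m({torch::ones({2, 3})});
//   // equivalent to: pre_hooks -> m.forward(inputs) -> hooks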

void Module::clone_method(
    const Module& orig,
    const Function& method,
    const std::unordered_map<TypePtr, TypePtr>& type_remap) {
  // type remapping - when we copy method implementations from one module
  // singleton to another, we need to update the types of the self arguments
  // to match the new module.
  // XXX - this only handles modules that occur as variables, not modules
  // that appear in aggregate types. Currently this works fine because
  // we restrict how modules can be used during the lowering step. Eventually,
  // we will need to decide what it means for us to 'copy' a module.
  // For instance, we can copy just the state (parameters, attributes),
  // but share the code. Or we can copy the code. If we choose to copy the
  // code, what should we do about aggregate types that contain a module?
  auto type_remap_fn = [&](TypePtr in) {
    auto it = type_remap.find(in);
    if (it == type_remap.end())
      return in;
    return it->second;
  };
  auto graph = toGraphFunction(method).graph()->copy();
  graph->remapTypes(type_remap_fn);
  auto schema = method.getSchema().cloneWithRemappedTypes(type_remap_fn);
  const auto this_method_name = getNameForMethod(method.name());
  auto copied =
      _ivalue()->compilation_unit()->create_function(this_method_name, graph);
  type()->addMethod(copied);
  copied->setSchema(std::move(schema));
}

void Module::clone_method(const Module& orig, const std::string& name) {
  std::unordered_map<TypePtr, TypePtr> type_remap;
  std::vector<std::pair<Module, Module>> to_scan = {{orig, *this}};
  while (!to_scan.empty()) {
    auto entry = to_scan.back();
    to_scan.pop_back();
    type_remap[entry.first._ivalue()->type()] = entry.second._ivalue()->type();
    for (const NameModule& s : entry.first.named_children()) {
      to_scan.emplace_back(
          s.value, Module(entry.second.attr(s.name).toObject()));
    }
  }
  return clone_method(orig, orig.get_method(name).function(), type_remap);
}

Module Module::copy() const {
  return Module(_ivalue()->copy());
}

Module Module::deepcopy(std::optional<at::Device> device) const {
  return Module(_ivalue()->deepcopy(device));
}

Module Module::clone(bool inplace) const {
  std::unordered_map<TypePtr, TypePtr> type_remap;
  IValue::HashIdentityIValueMap memo;
  const std::unordered_set<std::string> ignored_methods;
  const std::unordered_set<std::string> ignored_attributes;
  return clone_impl(
      type_remap, inplace, memo, ignored_methods, ignored_attributes);
}

Module Module::clone(
    bool inplace,
    const std::unordered_set<std::string>& ignored_methods,
    const std::unordered_set<std::string>& ignored_attributes) const {
  std::unordered_map<TypePtr, TypePtr> type_remap;
  IValue::HashIdentityIValueMap memo;
  return clone_impl(
      type_remap, inplace, memo, ignored_methods, ignored_attributes);
}
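
// Example (a sketch; the ignored names are hypothetical): `clone` copies both
// state and methods onto a fresh ClassType, while `copy`/`deepcopy` keep the
// original type and copy only the object state.
//   Module m2 = m.clone(/*inplace=*/false);
//   Module m3 = m.clone(
//       /*inplace=*/false,
//       /*ignored_methods=*/{"debug_helper"},
//       /*ignored_attributes=*/{"cached_stats"});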

Module Module::clone_impl(
    std::unordered_map<TypePtr, TypePtr>& type_remap,
    bool inplace,
    IValue::HashIdentityIValueMap memo,
    const std::unordered_set<std::string>& ignored_methods,
    const std::unordered_set<std::string>& ignored_attributes) const {
  // Create a new _ivalue in the same compilation unit.
  // Since ClassTypes can be shared, we need to preserve that sharing while
  // cloning: first check whether this type has already been cloned. If so,
  // create a new module with the cloned ClassType; if not, create both a new
  // module and a new ClassType.
  bool type_already_cloned = type_remap.find(type()) != type_remap.end();
  Module r;
  if (type_already_cloned) {
    // if we cloned the class type before, we'll reuse it
    Module new_module(
        _ivalue()->compilation_unit(), type_remap[type()]->cast<ClassType>());
    r = new_module;
  } else {
    Module new_module(*type()->name(), _ivalue()->compilation_unit(), true);
    r = new_module;
    type_remap[type()] = r.type();
  }

  // Copy slots. If a slot is a module - recursively clone it.
  size_t N = type()->numAttributes();
  for (const auto i : c10::irange(N)) {
    IValue s = _ivalue()->getSlot(i);
    std::string attr_name = type()->getAttributeName(i);

    // If this attribute is in the list of ignored attributes, skip it
    // (i.e. do not clone it).
    if (ignored_attributes.count(attr_name) != 0) {
      continue;
    }

    TypePtr attr_type = type()->getAttribute(i);
    if (attr_type->is_module()) {
      const Module& orig = Module(s.toObject());
      const std::unordered_set<std::string> empty_set;
      Module cloned =
          orig.clone_impl(type_remap, inplace, memo, empty_set, empty_set);
      type_remap[orig.type()] = cloned.type();
      // NOTE: Why do we manually setattr on the object instead of using
      // register_module here? Because the attr can be a module interface
      // type and still hold a Module object. register_module would not let us
      // correctly set up the type for this attr, so we have to do it manually.
      // If it is an interface type, the type will be shared by the newly
      // cloned instance in the same compilation unit, because it only contains
      // a list of FunctionSchemas.
      r.type()->addOrCheckAttribute(
          attr_name, attr_type->cast<ClassType>() ? cloned.type() : attr_type);
      r._ivalue()->setAttr(attr_name, cloned._ivalue());
    } else {
      // This adds a new slot and creates a new attribute on the underlying
      // type if the type has not been cloned yet; otherwise it only adds a
      // new slot and typechecks.
      r.register_attribute(
          type()->getAttributeName(i),
          attr_type,
          // we deepcopy the IValue unless cloning in place
          inplace ? s : s.deepcopy(memo),
          type()->is_parameter(i),
          type()->is_buffer(i));
    }
  }

  // only clone the methods if the ClassType has not been cloned before
  if (!type_already_cloned) {
    // clone constants
    for (size_t i = 0; i < type()->numConstants(); ++i) {
      r.type()->addConstant(type()->getConstantName(i), type()->getConstant(i));
    }
    // clone methods, remapping the types to the cloned ones.
    for (auto& fn : type()->methods()) {
      // If this method is not in the list of ignored methods, clone it.
      if (ignored_methods.count(fn->name()) == 0) {
        assert_ignored_methods_not_called(*fn, ignored_methods);
        assert_ignored_attributes_not_referenced(*fn, ignored_attributes);
        r.clone_method(*this, *fn, type_remap);
      }
    }

    // Execute __setstate__(__getstate__()) to initialize custom class members.
    if (auto setstate_method = r.find_method("__setstate__")) {
      auto getstate_method = r.find_method("__getstate__");
      TORCH_INTERNAL_ASSERT(getstate_method, "expect __getstate__");
      auto state = (*getstate_method)(Stack{});
      (*setstate_method)(Stack{state});
    }
  }
  return r;
}

void Module::train(bool on) {
  for (Module m : modules()) {
    if (auto slot = m._ivalue()->type()->findAttributeSlot("training")) {
      m._ivalue()->setSlot(*slot, on);
    } else {
      // FIXME[T110620981]: This assert was broken (never asserted), and once
      // fixed it triggers test failures.  Fix me!
      /* TORCH_INTERNAL_ASSERT(false, "'training' attribute not found"); */
    }
  }
}
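
// Example (a sketch): `train` toggles the `training` attribute recursively, so
// `m.train(false)` puts the module and all submodules in eval mode.
//   m.train(true);   // training mode
//   m.train(false);  // eval mode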

IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
  // Look up the class
  const auto classType =
      _ivalue()->compilation_unit()->get_class(c10::QualifiedName(name));
  if (!classType) {
    AT_ERROR(
        "Could not find class with name: '",
        name.qualifiedName(),
        "' in module.");
  }

  // Create a bare object with correct number of slots
  const size_t numAttrs = classType->numAttributes();
  auto obj = c10::ivalue::Object::create(
      c10::StrongTypePtr(_ivalue()->compilation_unit(), classType), numAttrs);

  // Invoke the `__init__()` of the class with the arguments provided.
  Stack stackWithSelf = {obj};
  for (auto& arg : stack) {
    stackWithSelf.push_back(std::move(arg));
  }
  // Note: following Python, `__init__()` modifies its first parameter in-place
  // and returns nothing.
  classType->getMethod("__init__").operator()(std::move(stackWithSelf));

  return obj;
}
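
// Example (a sketch; "__torch__.Foo" is a hypothetical class compiled into the
// module's CompilationUnit):
//   IValue obj = module.create_class(
//       c10::QualifiedName("__torch__.Foo"), Stack{IValue(1), IValue(2.0)});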

Module freeze(
    const Module& module,
    const std::optional<std::vector<std::string>>& preserved_attrs,
    bool optimize_numerics) {
  TORCH_CHECK(
      !module.hasattr("training") || !module.is_training(),
      "Freezing is currently only implemented for modules in eval mode. Please call .eval() before freezing");

  Module out_mod = freeze_module(
      module, preserved_attrs.value_or(std::vector<std::string>({})));
  auto graph = out_mod.get_method("forward").graph();
  OptimizeFrozenGraph(graph, optimize_numerics);
  return out_mod;
}
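
// Example (a sketch; "version" is a hypothetical attribute name): freezing
// requires eval mode, and attributes in `preserved_attrs` stay mutable instead
// of being folded into the graph.
//   module.eval();
//   Module frozen = freeze(
//       module, std::vector<std::string>{"version"}, /*optimize_numerics=*/true);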

namespace {
void optimize_for_inference(std::shared_ptr<Graph> graph) {
  FuseFrozenConvAddRelu(graph);
  ConvertFrozenOpsToMKLDNN(graph);
  FrozenLinearTranspose(graph);
}
} // namespace

Module optimize_for_inference(
    Module& module,
    const std::vector<std::string>& other_methods) {
  // Freeze the module first if it has not been frozen yet.
  Module frozen_mod;
  if (module._ivalue()->type()->hasAttribute("training")) {
    frozen_mod = freeze(module, {}, true);
  } else {
    frozen_mod = module;
  }
  if (auto method = frozen_mod.find_method("forward")) {
    optimize_for_inference(frozen_mod.get_method("forward").graph());
  }
  for (const auto& method : other_methods) {
    optimize_for_inference(frozen_mod.get_method(method).graph());
  }
  return frozen_mod;
}
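
// Example (a sketch; "encode" is a hypothetical extra method):
//   Module opt = optimize_for_inference(module, /*other_methods=*/{"encode"});
//   // The forward graph (and each listed method's graph) gets conv/add/relu
//   // fusion, MKLDNN conversion, and frozen linear-transpose applied.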

buffer_list Module::buffers(bool recurse) const {
  return buffer_list(*this, recurse, /*return_module=*/false);
}
named_buffer_list Module::named_buffers(bool recurse) const {
  return named_buffer_list(*this, recurse, /*return_module=*/false);
}

module_list Module::children() const {
  return module_list(*this, /*recurse=*/false, /*return_module=*/false);
}
named_module_list Module::named_children() const {
  return named_module_list(*this, /*recurse=*/false, /*return_module=*/false);
}
module_list Module::modules() const {
  return module_list(*this, /*recurse=*/true, /*return_module=*/true);
}
named_module_list Module::named_modules() const {
  return named_module_list(*this, /*recurse=*/true, /*return_module=*/true);
}

parameter_list Module::parameters(bool recurse) const {
  return parameter_list(*this, recurse, /*return_module=*/false);
}
named_parameter_list Module::named_parameters(bool recurse) const {
  return named_parameter_list(*this, recurse, /*return_module=*/false);
}

attribute_list Module::attributes(bool recurse) const {
  return attribute_list(*this, recurse, /*return_module=*/false);
}
named_attribute_list Module::named_attributes(bool recurse) const {
  return named_attribute_list(*this, recurse, /*return_module=*/false);
}

void Module::apply(const std::function<void(Module&)>& fn) {
  for (Module s : modules()) {
    fn(s);
  }
}

std::string Module::dump_to_str(
    bool print_method_bodies,
    bool print_attr_values,
    bool print_param_values) const {
  std::stringstream ss;
  std::stringstream parameters_ss;
  std::stringstream attributes_ss;
  std::stringstream methods_ss;
  std::stringstream submodules_ss;

  for (const NameTensor& p : named_parameters(/*recurse=*/false)) {
    parameters_ss << p.name << " = ";
    if (print_param_values) {
      parameters_ss << p.value << '\n';
    } else {
      parameters_ss << "..." << '\n';
    }
  }

  for (const NameValue& p : named_attributes(/*recurse=*/false)) {
    attributes_ss << p.name << " = ";
    if (!p.value.isTensor() || print_attr_values) {
      attributes_ss << p.value << '\n';
    } else {
      attributes_ss << "..." << '\n';
    }
  }

  for (const Method& method : get_methods()) {
    methods_ss << "  method " << method.name() << " {" << '\n';
    if (print_method_bodies) {
      methods_ss << torch::jit::jit_log_prefix(
                        "    ", method.graph()->toString())
                 << '\n';
    }
    methods_ss << "  }" << '\n';
  }

  ss << "module " << type()->name()->qualifiedName() << " {" << '\n';
  ss << "  parameters {" << '\n';
  ss << torch::jit::jit_log_prefix("    ", parameters_ss.str());
  ss << "  }" << '\n';
  ss << "  attributes {" << '\n';
  ss << torch::jit::jit_log_prefix("    ", attributes_ss.str());
  ss << "  }" << '\n';
  ss << "  methods {" << '\n';
  ss << torch::jit::jit_log_prefix("  ", methods_ss.str());
  ss << "  }" << '\n';
  ss << "  submodules {" << '\n';
  for (const NameModule& s : named_children()) {
    // We use 4 spaces here because one level of indentation comes from the
    // 'submodules' scope and the other comes from the specific submodule we
    // are printing.
    ss << torch::jit::jit_log_prefix(
        "    ",
        s.value.dump_to_str(
            print_method_bodies, print_attr_values, print_param_values));
  }
  ss << "  }" << '\n';
  ss << "}" << '\n';

  return ss.str();
}

void Module::dump(
    bool print_method_bodies = true,
    bool print_attr_values = true,
    bool print_param_values = true) const {
  std::cout << dump_to_str(
                   print_method_bodies, print_attr_values, print_param_values)
            << '\n';
}
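
// Example (a sketch): print a structural summary without dumping tensor data
// or method bodies.
//   std::cout << module.dump_to_str(
//       /*print_method_bodies=*/false,
//       /*print_attr_values=*/false,
//       /*print_param_values=*/false);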

} // namespace torch::jit

namespace c10 {

torch::jit::Module IValue::toModule() const {
  return torch::jit::Module(toObject());
}
bool IValue::isModule() const {
  return isObject() && toObjectRef().type()->is_module();
}
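
// Example (a sketch): check isModule() before converting, since toModule()
// assumes the IValue holds a module object.
//   if (ivalue.isModule()) {
//     torch::jit::Module m = ivalue.toModule();
//   }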

} // namespace c10