xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/module.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/nn/module.h>
2 
3 #include <torch/ordered_dict.h>
4 
5 #include <torch/csrc/autograd/generated/VariableType.h>
6 
7 #include <c10/util/Exception.h>
8 
9 #include <algorithm>
10 #include <functional>
11 #include <map>
12 #include <ostream>
13 #include <string>
14 #include <typeinfo>
15 
16 namespace torch {
17 namespace nn {
18 namespace {
/// Builds a hierarchical, dotted name: "<name_prefix>.<name>" when
/// `name_prefix` is non-empty, otherwise just "<name>".
std::string join_name(const std::string& name_prefix, const std::string& name) {
  if (name_prefix.empty()) {
    return name;
  }
  std::string joined;
  // Reserve the exact final size up front to avoid reallocation.
  joined.reserve(name_prefix.size() + 1 + name.size());
  joined.append(name_prefix);
  joined.push_back('.');
  joined.append(name);
  return joined;
}
35 } // namespace
36 
Module()37 Module::Module()
38     : parameters_("Parameter"), buffers_("Buffer"), children_("Submodule") {}
39 
Module(std::string name)40 Module::Module(std::string name) : Module() {
41   name_ = std::move(name);
42 }
43 
name() const44 const std::string& Module::name() const noexcept {
45   // If the name optional is empty at this point, we grab the name of the
46   // dynamic type via RTTI. Note that we cannot do this in the constructor,
47   // because in the constructor of a base class `this` always refers to the base
48   // type. Inheritance effectively does not work in constructors. Also this note
49   // from http://en.cppreference.com/w/cpp/language/typeid:
50   // If typeid is used on an object under construction or destruction (in a
51   // destructor or in a constructor, including constructor's initializer list
52   // or default member initializers), then the std::type_info object referred
53   // to by this typeid represents the class that is being constructed or
54   // destroyed even if it is not the most-derived class.
55   if (!name_.has_value()) {
56     name_ = c10::demangle(typeid(*this).name());
57 #if defined(_WIN32)
58     // Windows adds "struct" or "class" as a prefix.
59     if (name_->find("struct ") == 0) {
60       name_->erase(name_->begin(), name_->begin() + 7);
61     } else if (name_->find("class ") == 0) {
62       name_->erase(name_->begin(), name_->begin() + 6);
63     }
64 #endif // defined(_WIN32)
65   }
66   return *name_;
67 }
68 
clone(const std::optional<Device> & device) const69 std::shared_ptr<Module> Module::clone(
70     const std::optional<Device>& device) const {
71   AT_ERROR(
72       "clone() has not been implemented for ",
73       name(),
74       ". Subclass torch::nn::Cloneable<",
75       name(),
76       "> instead of torch::nn::Module to inherit the ability to clone.");
77 }
78 
apply(const ModuleApplyFunction & function)79 void Module::apply(const ModuleApplyFunction& function) {
80   function(*this);
81   apply_to_submodules(
82       [&function](const std::string&, const std::shared_ptr<Module>& module) {
83         function(*module);
84       });
85 }
86 
apply(const ConstModuleApplyFunction & function) const87 void Module::apply(const ConstModuleApplyFunction& function) const {
88   function(*this);
89   apply_to_submodules(
90       [&function](const std::string&, const std::shared_ptr<Module>& module) {
91         function(*module);
92       });
93 }
94 
apply(const NamedModuleApplyFunction & function,const std::string & name_prefix)95 void Module::apply(
96     const NamedModuleApplyFunction& function,
97     const std::string& name_prefix) {
98   function(/*name=*/name_prefix, *this);
99   apply_to_submodules(
100       [&function](
101           const std::string& name, const std::shared_ptr<Module>& module) {
102         function(name, *module);
103       },
104       name_prefix);
105 }
106 
apply(const ConstNamedModuleApplyFunction & function,const std::string & name_prefix) const107 void Module::apply(
108     const ConstNamedModuleApplyFunction& function,
109     const std::string& name_prefix) const {
110   function(/*name=*/name_prefix, *this);
111   apply_to_submodules(
112       [&function](
113           const std::string& name, const std::shared_ptr<Module>& module) {
114         function(name, *module);
115       },
116       name_prefix);
117 }
118 
apply(const ModulePointerApplyFunction & function) const119 void Module::apply(const ModulePointerApplyFunction& function) const {
120   function(shared_from_this_checked());
121   apply_to_submodules(
122       [&function](const std::string&, const std::shared_ptr<Module>& module) {
123         function(module);
124       });
125 }
126 
apply(const NamedModulePointerApplyFunction & function,const std::string & name_prefix) const127 void Module::apply(
128     const NamedModulePointerApplyFunction& function,
129     const std::string& name_prefix) const {
130   function(
131       /*name=*/name_prefix, shared_from_this_checked());
132   apply_to_submodules(function, name_prefix);
133 }
134 
parameters(bool recurse) const135 std::vector<Tensor> Module::parameters(bool recurse) const {
136   return named_parameters(recurse).values();
137 }
138 
named_parameters(bool recurse) const139 OrderedDict<std::string, Tensor> Module::named_parameters(bool recurse) const {
140   OrderedDict<std::string, Tensor> result;
141   if (!recurse) {
142     for (const auto& parameter : parameters_) {
143       if (parameter.value().defined()) {
144         result.insert(parameter.key(), parameter.value());
145       }
146     }
147   } else {
148     apply([&result](const std::string& name, const Module& module) {
149       for (const auto& parameter : module.named_parameters(/*recurse=*/false)) {
150         TORCH_INTERNAL_ASSERT(parameter.value().defined());
151         result.insert(join_name(name, parameter.key()), parameter.value());
152       }
153     });
154   }
155   return result;
156 }
157 
buffers(bool recurse) const158 std::vector<Tensor> Module::buffers(bool recurse) const {
159   return named_buffers(recurse).values();
160 }
161 
named_buffers(bool recurse) const162 OrderedDict<std::string, Tensor> Module::named_buffers(bool recurse) const {
163   OrderedDict<std::string, Tensor> result;
164   if (!recurse) {
165     for (const auto& buffer : buffers_) {
166       if (buffer.value().defined()) {
167         result.insert(buffer.key(), buffer.value());
168       }
169     }
170   } else {
171     apply([&result](const std::string& name, const Module& module) {
172       for (const auto& buffer : module.named_buffers(/*recurse=*/false)) {
173         TORCH_INTERNAL_ASSERT(buffer.value().defined());
174         result.insert(join_name(name, buffer.key()), buffer.value());
175       }
176     });
177   }
178   return result;
179 }
180 
modules(bool include_self) const181 std::vector<std::shared_ptr<Module>> Module::modules(bool include_self) const {
182   std::vector<std::shared_ptr<Module>> result;
183   if (include_self) {
184     apply([&result](const std::shared_ptr<Module>& module) {
185       result.push_back(module);
186     });
187   } else {
188     apply_to_submodules(
189         [&result](const std::string&, const std::shared_ptr<Module>& module) {
190           result.push_back(module);
191         });
192   }
193   return result;
194 }
195 
named_modules(const std::string & name_prefix,bool include_self) const196 OrderedDict<std::string, std::shared_ptr<Module>> Module::named_modules(
197     const std::string& name_prefix,
198     bool include_self) const {
199   OrderedDict<std::string, std::shared_ptr<Module>> result;
200   if (include_self) {
201     apply(
202         [&result](
203             const std::string& key, const std::shared_ptr<Module>& module) {
204           result.insert(key, module);
205         },
206         name_prefix);
207   } else {
208     apply_to_submodules(
209         [&result](
210             const std::string& key, const std::shared_ptr<Module>& module) {
211           result.insert(key, module);
212         },
213         name_prefix);
214   }
215   return result;
216 }
217 
children() const218 std::vector<std::shared_ptr<Module>> Module::children() const {
219   return children_.values();
220 }
221 
named_children() const222 OrderedDict<std::string, std::shared_ptr<Module>> Module::named_children()
223     const {
224   return children_;
225 }
226 
train(bool on)227 void Module::train(bool on) {
228   for (auto& child : children_) {
229     child.value()->train(on);
230   }
231   is_training_ = on;
232 }
233 
eval()234 void Module::eval() {
235   train(/*on=*/false);
236 }
237 
/// Device-and-dtype overload of to(); delegates to the to_impl() helper
/// (declared in the header) with both a target device and dtype.
void Module::to(torch::Device device, torch::Dtype dtype, bool non_blocking) {
  to_impl(device, dtype, non_blocking);
}
241 
/// Dtype-only overload of to(); delegates to the to_impl() helper (declared
/// in the header) with just a target dtype.
void Module::to(torch::Dtype dtype, bool non_blocking) {
  to_impl(dtype, non_blocking);
}
245 
/// Device-only overload of to(); delegates to the to_impl() helper (declared
/// in the header) with just a target device.
void Module::to(torch::Device device, bool non_blocking) {
  to_impl(device, non_blocking);
}
249 
/// Whether this module is in training mode (the flag toggled by train() /
/// eval(); children carry their own copies of the flag).
bool Module::is_training() const noexcept {
  return is_training_;
}
253 
zero_grad(bool set_to_none)254 void Module::zero_grad(bool set_to_none) {
255   for (auto& child : children_) {
256     child.value()->zero_grad(set_to_none);
257   }
258   for (auto& parameter : named_parameters(/*recurse=*/false)) {
259     auto& grad = parameter->mutable_grad();
260     if (grad.defined()) {
261       grad = grad.detach();
262 
263       if (set_to_none)
264         grad.reset();
265       else
266         grad.zero_();
267     }
268   }
269 }
270 
save(serialize::OutputArchive & archive) const271 void Module::save(serialize::OutputArchive& archive) const {
272   for (const auto& parameter : named_parameters(/*recurse=*/false)) {
273     archive.write(parameter.key(), parameter.value());
274   }
275   for (const auto& buffer : named_buffers(/*recurse=*/false)) {
276     archive.write(buffer.key(), buffer.value(), /*is_buffer=*/true);
277   }
278   for (const auto& child : children_) {
279     if (child.value()->is_serializable()) {
280       serialize::OutputArchive child_archive(archive.compilation_unit());
281       child.value()->save(child_archive);
282       archive.write(child.key(), child_archive);
283     }
284   }
285 }
286 
load(serialize::InputArchive & archive)287 void Module::load(serialize::InputArchive& archive) {
288   for (auto& parameter : named_parameters(/*recurse=*/false)) {
289     archive.read(parameter.key(), parameter.value());
290   }
291   for (auto& buffer : named_buffers(/*recurse=*/false)) {
292     archive.read(buffer.key(), buffer.value(), /*is_buffer=*/true);
293   }
294   for (const auto& child : children_) {
295     if (child.value()->is_serializable()) {
296       serialize::InputArchive child_archive;
297       archive.read(child.key(), child_archive);
298       child.value()->load(child_archive);
299     }
300   }
301 }
302 
/// Whether this module participates in serialization. The base Module always
/// does; save()/load() skip any child for which this returns false.
bool Module::is_serializable() const {
  return true;
}
306 
register_parameter(std::string name,Tensor tensor,bool requires_grad)307 Tensor& Module::register_parameter(
308     std::string name,
309     Tensor tensor,
310     bool requires_grad) {
311   TORCH_CHECK(!name.empty(), "Parameter name must not be empty");
312   TORCH_CHECK(
313       name.find('.') == std::string::npos,
314       "Parameter name must not contain a dot (got '",
315       name,
316       "')");
317   if (!tensor.defined()) {
318     if (requires_grad) {
319       TORCH_WARN(
320           "An undefined tensor cannot require grad. ",
321           "Ignoring the `requires_grad=true` function parameter.");
322     }
323   } else {
324     tensor.set_requires_grad(requires_grad);
325   }
326   return parameters_.insert(std::move(name), std::move(tensor));
327 }
328 
register_buffer(std::string name,Tensor tensor)329 Tensor& Module::register_buffer(std::string name, Tensor tensor) {
330   TORCH_CHECK(!name.empty(), "Buffer name must not be empty");
331   TORCH_CHECK(
332       name.find('.') == std::string::npos,
333       "Buffer name must not contain a dot (got '",
334       name,
335       "')");
336   return buffers_.insert(std::move(name), std::move(tensor));
337 }
338 
unregister_module(const std::string & name)339 void Module::unregister_module(const std::string& name) {
340   TORCH_CHECK(
341       children_.contains(name),
342       "No Module with name `",
343       name,
344       "` is registered");
345   children_.erase(name);
346 }
347 
/// Default single-module printer: emits just the module's name. Used by
/// pretty_print_recursive() for each node in the tree.
void Module::pretty_print(std::ostream& stream) const {
  stream << name();
}
351 
pretty_print_recursive(std::ostream & stream,const std::string & indentation) const352 void Module::pretty_print_recursive(
353     std::ostream& stream,
354     const std::string& indentation) const {
355   pretty_print(stream);
356   if (!children_.is_empty()) {
357     stream << "(\n";
358     const std::string next_indentation = indentation + "  ";
359     for (const auto& child : children_) {
360       stream << next_indentation << "(" << child.key() << "): ";
361       child.value()->pretty_print_recursive(stream, next_indentation);
362       stream << '\n';
363     }
364     stream << indentation << ")";
365   }
366 }
367 
clone_(Module & other,const std::optional<Device> & device)368 void Module::clone_(Module& other, const std::optional<Device>& device) {}
369 
apply_to_submodules(const NamedModulePointerApplyFunction & function,const std::string & name_prefix) const370 void Module::apply_to_submodules(
371     const NamedModulePointerApplyFunction& function,
372     const std::string& name_prefix) const {
373   for (const auto& child : children_) {
374     auto qualified_name = join_name(name_prefix, child.key());
375     function(qualified_name, child.value());
376     child.value()->apply_to_submodules(function, qualified_name);
377   }
378 }
379 
shared_from_this_checked() const380 std::shared_ptr<Module> Module::shared_from_this_checked() const {
381   std::shared_ptr<const Module> ptr;
382   try {
383     ptr = shared_from_this();
384   } catch (const std::bad_weak_ptr&) {
385     AT_ERROR(
386         "It looks like you attempted to retrieve your top-level module "
387         "as a shared_ptr, but it is not stored in a shared_ptr. "
388         "Use std::make_shared<",
389         name(),
390         "> instead of creating your module on "
391         "the stack, or alternatively do not try to access your top-level "
392         "module at all by passing /*include_self=*/false "
393         "to modules() or named_modules()");
394   }
395   return std::const_pointer_cast<Module>(ptr);
396 }
397 
operator <<(std::ostream & stream,const nn::Module & module)398 std::ostream& operator<<(std::ostream& stream, const nn::Module& module) {
399   module.pretty_print_recursive(stream, "");
400   return stream;
401 }
402 
operator <<(serialize::OutputArchive & archive,const std::shared_ptr<nn::Module> & module)403 serialize::OutputArchive& operator<<(
404     serialize::OutputArchive& archive,
405     const std::shared_ptr<nn::Module>& module) {
406   TORCH_CHECK(module != nullptr, "Cannot serialize empty module");
407   module->save(archive);
408   return archive;
409 }
410 
operator >>(serialize::InputArchive & archive,const std::shared_ptr<nn::Module> & module)411 serialize::InputArchive& operator>>(
412     serialize::InputArchive& archive,
413     const std::shared_ptr<nn::Module>& module) {
414   TORCH_CHECK(module != nullptr, "Cannot deserialize empty module");
415   module->load(archive);
416   return archive;
417 }
418 } // namespace nn
419 } // namespace torch
420