#include <torch/nn/module.h>

#include <torch/ordered_dict.h>

#include <torch/csrc/autograd/generated/VariableType.h>

#include <c10/util/Exception.h>

#include <algorithm>
#include <functional>
#include <map>
#include <ostream>
#include <string>
#include <typeinfo>

namespace torch {
namespace nn {
namespace {
/// Joins names hierarchically: "name_prefix.name" if `name_prefix` is
/// non-empty, else just "name".
std::string join_name(const std::string& name_prefix, const std::string& name) {
  size_t total_size = name.size();
  if (!name_prefix.empty()) {
    total_size += name_prefix.size() + 1;
  }
  std::string full_name;
  full_name.reserve(total_size);
  if (!name_prefix.empty()) {
    full_name += name_prefix;
    full_name.push_back('.');
  }
  full_name += name;
  return full_name;
}
} // namespace

Module::Module()
    : parameters_("Parameter"), buffers_("Buffer"), children_("Submodule") {}

Module::Module(std::string name) : Module() {
  name_ = std::move(name);
}

const std::string& Module::name() const noexcept {
  // If the name optional is empty at this point, we grab the name of the
  // dynamic type via RTTI. Note that we cannot do this in the constructor,
  // because in the constructor of a base class `this` always refers to the
  // base type; inheritance effectively does not work in constructors. Also
  // note this passage from http://en.cppreference.com/w/cpp/language/typeid:
  //   If typeid is used on an object under construction or destruction (in a
  //   destructor or in a constructor, including constructor's initializer
  //   list or default member initializers), then the std::type_info object
  //   referred to by this typeid represents the class that is being
  //   constructed or destroyed even if it is not the most-derived class.
  if (!name_.has_value()) {
    name_ = c10::demangle(typeid(*this).name());
#if defined(_WIN32)
    // Windows adds "struct" or "class" as a prefix.
    if (name_->find("struct ") == 0) {
      name_->erase(name_->begin(), name_->begin() + 7);
    } else if (name_->find("class ") == 0) {
      name_->erase(name_->begin(), name_->begin() + 6);
    }
#endif // defined(_WIN32)
  }
  return *name_;
}

std::shared_ptr<Module> Module::clone(
    const std::optional<Device>& device) const {
  AT_ERROR(
      "clone() has not been implemented for ",
      name(),
      ". Subclass torch::nn::Cloneable<",
      name(),
      "> instead of torch::nn::Module to inherit the ability to clone.");
}
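
// Usage sketch for clone() (illustrative, not part of this translation unit):
// deriving from torch::nn::Cloneable instead of torch::nn::Module is what
// makes clone() work. `Net` and its layer sizes are hypothetical.
//
//   struct Net : torch::nn::Cloneable<Net> {
//     Net() { reset(); }
//     void reset() override {
//       fc = register_module("fc", torch::nn::Linear(8, 4));
//     }
//     torch::nn::Linear fc{nullptr};
//   };
//
//   auto net = std::make_shared<Net>();
//   auto copy = net->clone();  // deep copy; no AT_ERROR here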

void Module::apply(const ModuleApplyFunction& function) {
  function(*this);
  apply_to_submodules(
      [&function](const std::string&, const std::shared_ptr<Module>& module) {
        function(*module);
      });
}

void Module::apply(const ConstModuleApplyFunction& function) const {
  function(*this);
  apply_to_submodules(
      [&function](const std::string&, const std::shared_ptr<Module>& module) {
        function(*module);
      });
}

void Module::apply(
    const NamedModuleApplyFunction& function,
    const std::string& name_prefix) {
  function(/*name=*/name_prefix, *this);
  apply_to_submodules(
      [&function](
          const std::string& name, const std::shared_ptr<Module>& module) {
        function(name, *module);
      },
      name_prefix);
}

void Module::apply(
    const ConstNamedModuleApplyFunction& function,
    const std::string& name_prefix) const {
  function(/*name=*/name_prefix, *this);
  apply_to_submodules(
      [&function](
          const std::string& name, const std::shared_ptr<Module>& module) {
        function(name, *module);
      },
      name_prefix);
}

void Module::apply(const ModulePointerApplyFunction& function) const {
  function(shared_from_this_checked());
  apply_to_submodules(
      [&function](const std::string&, const std::shared_ptr<Module>& module) {
        function(module);
      });
}

void Module::apply(
    const NamedModulePointerApplyFunction& function,
    const std::string& name_prefix) const {
  function(/*name=*/name_prefix, shared_from_this_checked());
  apply_to_submodules(function, name_prefix);
}
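
// Usage sketch for the apply() overloads (illustrative): walking the module
// tree with the named overload. `net` is assumed to be a std::shared_ptr to
// some user-defined module, as in the Cloneable sketch above.
//
//   net->apply([](const std::string& name, const torch::nn::Module& m) {
//     // The root is visited with the (empty) name prefix, children with
//     // their dot-qualified names, e.g. "fc".
//     std::cout << (name.empty() ? "<root>" : name) << ": " << m.name() << '\n';
//   });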

std::vector<Tensor> Module::parameters(bool recurse) const {
  return named_parameters(recurse).values();
}

OrderedDict<std::string, Tensor> Module::named_parameters(bool recurse) const {
  OrderedDict<std::string, Tensor> result;
  if (!recurse) {
    for (const auto& parameter : parameters_) {
      if (parameter.value().defined()) {
        result.insert(parameter.key(), parameter.value());
      }
    }
  } else {
    apply([&result](const std::string& name, const Module& module) {
      for (const auto& parameter : module.named_parameters(/*recurse=*/false)) {
        TORCH_INTERNAL_ASSERT(parameter.value().defined());
        result.insert(join_name(name, parameter.key()), parameter.value());
      }
    });
  }
  return result;
}
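
// Usage sketch for named_parameters() (illustrative): with recurse == true,
// keys are dot-joined via join_name(), so a child registered as "fc"
// contributes "fc.weight" and "fc.bias".
//
//   for (const auto& pair : net->named_parameters(/*recurse=*/true)) {
//     std::cout << pair.key() << ' ' << pair.value().sizes() << '\n';
//   }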

std::vector<Tensor> Module::buffers(bool recurse) const {
  return named_buffers(recurse).values();
}

OrderedDict<std::string, Tensor> Module::named_buffers(bool recurse) const {
  OrderedDict<std::string, Tensor> result;
  if (!recurse) {
    for (const auto& buffer : buffers_) {
      if (buffer.value().defined()) {
        result.insert(buffer.key(), buffer.value());
      }
    }
  } else {
    apply([&result](const std::string& name, const Module& module) {
      for (const auto& buffer : module.named_buffers(/*recurse=*/false)) {
        TORCH_INTERNAL_ASSERT(buffer.value().defined());
        result.insert(join_name(name, buffer.key()), buffer.value());
      }
    });
  }
  return result;
}

std::vector<std::shared_ptr<Module>> Module::modules(bool include_self) const {
  std::vector<std::shared_ptr<Module>> result;
  if (include_self) {
    apply([&result](const std::shared_ptr<Module>& module) {
      result.push_back(module);
    });
  } else {
    apply_to_submodules(
        [&result](const std::string&, const std::shared_ptr<Module>& module) {
          result.push_back(module);
        });
  }
  return result;
}

OrderedDict<std::string, std::shared_ptr<Module>> Module::named_modules(
    const std::string& name_prefix,
    bool include_self) const {
  OrderedDict<std::string, std::shared_ptr<Module>> result;
  if (include_self) {
    apply(
        [&result](
            const std::string& key, const std::shared_ptr<Module>& module) {
          result.insert(key, module);
        },
        name_prefix);
  } else {
    apply_to_submodules(
        [&result](
            const std::string& key, const std::shared_ptr<Module>& module) {
          result.insert(key, module);
        },
        name_prefix);
  }
  return result;
}

std::vector<std::shared_ptr<Module>> Module::children() const {
  return children_.values();
}

OrderedDict<std::string, std::shared_ptr<Module>> Module::named_children()
    const {
  return children_;
}

void Module::train(bool on) {
  for (auto& child : children_) {
    child.value()->train(on);
  }
  is_training_ = on;
}

void Module::eval() {
  train(/*on=*/false);
}
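
// Usage sketch for train()/eval() (illustrative): both propagate recursively
// and only flip is_training(); layers such as dropout or batch norm consult
// that flag in their forward() implementations.
//
//   net->train();   // net->is_training() == true for the whole tree
//   net->eval();    // equivalent to net->train(/*on=*/false)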

void Module::to(torch::Device device, torch::Dtype dtype, bool non_blocking) {
  to_impl(device, dtype, non_blocking);
}

void Module::to(torch::Dtype dtype, bool non_blocking) {
  to_impl(dtype, non_blocking);
}

void Module::to(torch::Device device, bool non_blocking) {
  to_impl(device, non_blocking);
}
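
// Usage sketch for the to() overloads (illustrative): device-only, dtype-only,
// and combined moves. Availability of a CUDA device is an assumption here.
//
//   net->to(torch::kFloat64);                                    // dtype only
//   net->to(torch::kCUDA);                                       // device only
//   net->to(torch::kCUDA, torch::kHalf, /*non_blocking=*/true);  // both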

bool Module::is_training() const noexcept {
  return is_training_;
}

void Module::zero_grad(bool set_to_none) {
  for (auto& child : children_) {
    child.value()->zero_grad(set_to_none);
  }
  for (auto& parameter : named_parameters(/*recurse=*/false)) {
    auto& grad = parameter->mutable_grad();
    if (grad.defined()) {
      grad = grad.detach();
      if (set_to_none) {
        grad.reset();
      } else {
        grad.zero_();
      }
    }
  }
}
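
// Usage sketch for zero_grad() (illustrative): after a hypothetical
// loss.backward(), gradients can either be reset to undefined tensors or
// zero-filled in place.
//
//   loss.backward();
//   net->zero_grad(/*set_to_none=*/true);   // grads become undefined tensors
//   net->zero_grad(/*set_to_none=*/false);  // grads kept, zeroed in place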

void Module::save(serialize::OutputArchive& archive) const {
  for (const auto& parameter : named_parameters(/*recurse=*/false)) {
    archive.write(parameter.key(), parameter.value());
  }
  for (const auto& buffer : named_buffers(/*recurse=*/false)) {
    archive.write(buffer.key(), buffer.value(), /*is_buffer=*/true);
  }
  for (const auto& child : children_) {
    if (child.value()->is_serializable()) {
      serialize::OutputArchive child_archive(archive.compilation_unit());
      child.value()->save(child_archive);
      archive.write(child.key(), child_archive);
    }
  }
}

void Module::load(serialize::InputArchive& archive) {
  for (auto& parameter : named_parameters(/*recurse=*/false)) {
    archive.read(parameter.key(), parameter.value());
  }
  for (auto& buffer : named_buffers(/*recurse=*/false)) {
    archive.read(buffer.key(), buffer.value(), /*is_buffer=*/true);
  }
  for (const auto& child : children_) {
    if (child.value()->is_serializable()) {
      serialize::InputArchive child_archive;
      archive.read(child.key(), child_archive);
      child.value()->load(child_archive);
    }
  }
}

bool Module::is_serializable() const {
  return true;
}
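
// Usage sketch for save()/load() (illustrative): they are normally reached
// through the archive operators at the bottom of this file, e.g. via
// torch::save / torch::load with a shared_ptr-held module. "net.pt" is a
// placeholder path and `Net` is the hypothetical module from the Cloneable
// sketch above.
//
//   auto net = std::make_shared<Net>();
//   torch::save(net, "net.pt");
//   torch::load(net, "net.pt");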

Tensor& Module::register_parameter(
    std::string name,
    Tensor tensor,
    bool requires_grad) {
  TORCH_CHECK(!name.empty(), "Parameter name must not be empty");
  TORCH_CHECK(
      name.find('.') == std::string::npos,
      "Parameter name must not contain a dot (got '",
      name,
      "')");
  if (!tensor.defined()) {
    if (requires_grad) {
      TORCH_WARN(
          "An undefined tensor cannot require grad. ",
          "Ignoring the `requires_grad=true` function parameter.");
    }
  } else {
    tensor.set_requires_grad(requires_grad);
  }
  return parameters_.insert(std::move(name), std::move(tensor));
}

Tensor& Module::register_buffer(std::string name, Tensor tensor) {
  TORCH_CHECK(!name.empty(), "Buffer name must not be empty");
  TORCH_CHECK(
      name.find('.') == std::string::npos,
      "Buffer name must not contain a dot (got '",
      name,
      "')");
  return buffers_.insert(std::move(name), std::move(tensor));
}
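
// Usage sketch for register_parameter()/register_buffer() (illustrative): a
// hypothetical module registering one trainable parameter and one
// non-trainable buffer in its constructor.
//
//   struct Scaler : torch::nn::Module {
//     Scaler()
//         : weight(register_parameter("weight", torch::ones(4))),
//           running_mean(register_buffer("running_mean", torch::zeros(4))) {}
//     torch::Tensor weight, running_mean;
//   };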

void Module::unregister_module(const std::string& name) {
  TORCH_CHECK(
      children_.contains(name),
      "No Module with name `",
      name,
      "` is registered");
  children_.erase(name);
}

void Module::pretty_print(std::ostream& stream) const {
  stream << name();
}

void Module::pretty_print_recursive(
    std::ostream& stream,
    const std::string& indentation) const {
  pretty_print(stream);
  if (!children_.is_empty()) {
    stream << "(\n";
    const std::string next_indentation = indentation + "  ";
    for (const auto& child : children_) {
      stream << next_indentation << "(" << child.key() << "): ";
      child.value()->pretty_print_recursive(stream, next_indentation);
      stream << '\n';
    }
    stream << indentation << ")";
  }
}

void Module::clone_(Module& other, const std::optional<Device>& device) {}

void Module::apply_to_submodules(
    const NamedModulePointerApplyFunction& function,
    const std::string& name_prefix) const {
  for (const auto& child : children_) {
    auto qualified_name = join_name(name_prefix, child.key());
    function(qualified_name, child.value());
    child.value()->apply_to_submodules(function, qualified_name);
  }
}

std::shared_ptr<Module> Module::shared_from_this_checked() const {
  std::shared_ptr<const Module> ptr;
  try {
    ptr = shared_from_this();
  } catch (const std::bad_weak_ptr&) {
    AT_ERROR(
        "It looks like you attempted to retrieve your top-level module "
        "as a shared_ptr, but it is not stored in a shared_ptr. "
        "Use std::make_shared<",
        name(),
        "> instead of creating your module on "
        "the stack, or alternatively do not try to access your top-level "
        "module at all by passing /*include_self=*/false "
        "to modules() or named_modules()");
  }
  return std::const_pointer_cast<Module>(ptr);
}
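
// Usage sketch for shared_from_this_checked() (illustrative): the check above
// fires for stack-allocated modules, so a top-level module that should appear
// in its own modules()/named_modules() must be held in a std::shared_ptr.
//
//   auto ok = std::make_shared<Net>();   // `Net` as above, hypothetical
//   auto all = ok->modules();            // fine, includes `ok` itself
//
//   Net on_stack;
//   // on_stack.modules() would throw; on_stack.modules(/*include_self=*/false)
//   // is the supported alternative.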

std::ostream& operator<<(std::ostream& stream, const nn::Module& module) {
  module.pretty_print_recursive(stream, "");
  return stream;
}
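
// Usage sketch for operator<< (illustrative): streaming a module prints its
// pretty_print() output followed by an indented, parenthesized list of its
// children; the exact rendering below is only an approximation.
//
//   std::cout << *net << '\n';
//   // Net(
//   //   (fc): torch::nn::Linear(in_features=8, out_features=4, bias=true)
//   // )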

serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const std::shared_ptr<nn::Module>& module) {
  TORCH_CHECK(module != nullptr, "Cannot serialize empty module");
  module->save(archive);
  return archive;
}

serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    const std::shared_ptr<nn::Module>& module) {
  TORCH_CHECK(module != nullptr, "Cannot deserialize empty module");
  module->load(archive);
  return archive;
}
} // namespace nn
} // namespace torch