#pragma once

#include <torch/nn/modules/container/any_module_holder.h>
#include <torch/nn/modules/container/any_value.h>
#include <torch/nn/pimpl.h>
#include <torch/ordered_dict.h>
#include <torch/serialize/archive.h>
#include <torch/types.h>

#include <ATen/ATen.h>

#include <functional>
#include <iosfwd>
#include <map>
#include <memory>
#include <string>
#include <type_traits>

namespace torch {
namespace nn {

/// The base class for all modules in PyTorch.
///
/// \rst
/// .. note::
///   The design and implementation of this class is largely based on the Python
///   API. You may want to consult the Python documentation for
///   :py:class:`pytorch:torch.nn.Module` for further clarification on certain
///   methods or behavior.
/// \endrst
///
/// A `Module` is an abstraction over the implementation of some function or
/// algorithm, possibly associated with some persistent data. A `Module` may
/// contain further `Module`s ("submodules"), each with their own
/// implementation, persistent data and further submodules. `Module`s can thus
/// be said to form a recursive tree structure. A `Module` is registered as a
/// submodule to another `Module` by calling `register_module()`, typically from
/// within a parent module's constructor.
///
/// A distinction is made between three kinds of persistent data that may be
/// associated with a `Module`:
///
/// 1. *Parameters*: tensors that record gradients, typically weights updated
///    during the backward step (e.g. the `weight` of a `Linear` module),
/// 2. *Buffers*: tensors that do not record gradients, typically updated during
///    the forward step, such as running statistics (e.g. `mean` and `variance`
///    in the `BatchNorm` module),
/// 3. Any additional state, not necessarily tensors, required for the
///    implementation or configuration of a `Module`.
///
/// The first two kinds of state are special in that they may be registered
/// with the `Module` system to allow convenient access and batch configuration.
/// For example, registered parameters in any `Module` may be iterated over via
/// the `parameters()` accessor. Further, changing the data type of a `Module`'s
/// registered parameters can be done conveniently via `Module::to()`, e.g.
/// `module->to(torch::kCUDA)` to move all parameters to GPU memory. Lastly,
/// registered parameters and buffers are handled specially during a `clone()`
/// operation, which performs a deepcopy of a cloneable `Module` hierarchy.
///
/// Parameters are registered with a `Module` via `register_parameter`. Buffers
/// are registered separately via `register_buffer`. These methods are part of
/// the public API of `Module` and are typically invoked from within a
/// concrete `Module`'s constructor.
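///
/// For example, a minimal custom module might look as follows (an illustrative
/// sketch; `MyModule`, `N` and `M` are placeholder names, not part of the API):
///
/// \rst
/// .. code-block:: cpp
///
///   struct MyModule : torch::nn::Module {
///     MyModule(int64_t N, int64_t M)
///         : linear(register_module("linear", torch::nn::Linear(N, M))) {
///       bias = register_parameter("bias", torch::randn(M));
///     }
///     torch::Tensor forward(torch::Tensor input) {
///       return linear(input) + bias;
///     }
///     torch::nn::Linear linear;
///     torch::Tensor bias;
///   };
/// \endrst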
class TORCH_API Module : public std::enable_shared_from_this<Module> {
 public:
  using ModuleApplyFunction = std::function<void(Module&)>;
  using ConstModuleApplyFunction = std::function<void(const Module&)>;
  using NamedModuleApplyFunction =
      std::function<void(const std::string&, Module&)>;
  using ConstNamedModuleApplyFunction =
      std::function<void(const std::string&, const Module&)>;
  using ModulePointerApplyFunction =
      std::function<void(const std::shared_ptr<Module>&)>;
  using NamedModulePointerApplyFunction =
      std::function<void(const std::string&, const std::shared_ptr<Module>&)>;

  /// Tells the base `Module` about the name of the submodule.
  explicit Module(std::string name);

  /// Constructs the module without immediate knowledge of the submodule's name.
  /// The name of the submodule is inferred via RTTI (if possible) the first
  /// time `.name()` is invoked.
  Module();
  Module(const Module&) = default;
  Module& operator=(const Module&) = default;
  Module(Module&&) noexcept = default;
  Module& operator=(Module&&) noexcept = default;

  virtual ~Module() = default;

  /// Returns the name of the `Module`.
  ///
  /// A `Module` has an associated `name`, which is a string representation of
  /// the kind of concrete `Module` it represents, such as `"Linear"` for the
  /// `Linear` module. Under most circumstances, this name is automatically
  /// inferred via runtime type information (RTTI). In the unusual circumstance
  /// that you have this feature disabled, you may want to manually name your
  /// `Module`s by passing the string name to the `Module` base class'
  /// constructor.
  const std::string& name() const noexcept;

  /// Performs a recursive deep copy of the module and all its registered
  /// parameters, buffers and submodules.
  ///
  /// Optionally, this method sets the current device
  /// to the one supplied before cloning. If no device is given, each
  /// parameter and buffer will be moved to the device of its source.
  ///
  /// \rst
  /// .. attention::
  ///   Attempting to call the `clone()` method inherited from the base `Module`
  ///   class (the one documented here) will fail. To inherit an actual
  ///   implementation of `clone()`, you must subclass `Cloneable`. `Cloneable`
  ///   is templatized on the concrete module type, and can thus properly copy a
  ///   `Module`. This method is provided on the base class' API solely for an
  ///   easier-to-use polymorphic interface.
  /// \endrst
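  ///
  /// For example (an illustrative sketch; assumes `MyModule` derives from
  /// `Cloneable<MyModule>` and is held in a `std::shared_ptr`):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   auto module = std::make_shared<MyModule>();
  ///   // Deep-copies parameters, buffers and submodules onto the given device.
  ///   std::shared_ptr<Module> copy =
  ///       module->clone(torch::Device(torch::kCUDA, 0));
  /// \endrst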
  virtual std::shared_ptr<Module> clone(
      const std::optional<Device>& device = std::nullopt) const;

  /// Applies the `function` to the `Module` and recursively to every submodule.
  /// The function must accept a `Module&`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule module;
  ///   module->apply([](nn::Module& module) {
  ///     std::cout << module.name() << std::endl;
  ///   });
  /// \endrst
  void apply(const ModuleApplyFunction& function);

  /// Applies the `function` to the `Module` and recursively to every submodule.
  /// The function must accept a `const Module&`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule module;
  ///   module->apply([](const nn::Module& module) {
  ///     std::cout << module.name() << std::endl;
  ///   });
  /// \endrst
  void apply(const ConstModuleApplyFunction& function) const;

  /// Applies the `function` to the `Module` and recursively to every submodule.
  /// The function must accept a `const std::string&` for the key of the module,
  /// and a `Module&`. The key of the module itself is the empty string. If
  /// `name_prefix` is given, it is prepended to every key as
  /// `<name_prefix>.<key>` (and just `name_prefix` for the module itself).
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule module;
  ///   module->apply([](const std::string& key, nn::Module& module) {
  ///     std::cout << key << ": " << module.name() << std::endl;
  ///   });
  /// \endrst
  void apply(
      const NamedModuleApplyFunction& function,
      const std::string& name_prefix = std::string());

  /// Applies the `function` to the `Module` and recursively to every submodule.
  /// The function must accept a `const std::string&` for the key of the module,
  /// and a `const Module&`. The key of the module itself is the empty string.
  /// If `name_prefix` is given, it is prepended to every key as
  /// `<name_prefix>.<key>` (and just `name_prefix` for the module itself).
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule module;
  ///   module->apply([](const std::string& key, const nn::Module& module) {
  ///     std::cout << key << ": " << module.name() << std::endl;
  ///   });
  /// \endrst
  void apply(
      const ConstNamedModuleApplyFunction& function,
      const std::string& name_prefix = std::string()) const;

  /// Applies the `function` to the `Module` and recursively to every submodule.
  /// The function must accept a `const std::shared_ptr<Module>&`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule module;
  ///   module->apply([](const std::shared_ptr<nn::Module>& module) {
  ///     std::cout << module->name() << std::endl;
  ///   });
  /// \endrst
  void apply(const ModulePointerApplyFunction& function) const;

  /// Applies the `function` to the `Module` and recursively to every submodule.
  /// The function must accept a `const std::string&` for the key of the module,
  /// and a `const std::shared_ptr<Module>&`. The key of the module itself is
  /// the empty string. If `name_prefix` is given, it is prepended to every key
  /// as `<name_prefix>.<key>` (and just `name_prefix` for the module itself).
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule module;
  ///   module->apply([](const std::string& key,
  ///                    const std::shared_ptr<nn::Module>& module) {
  ///     std::cout << key << ": " << module->name() << std::endl;
  ///   });
  /// \endrst
  void apply(
      const NamedModulePointerApplyFunction& function,
      const std::string& name_prefix = std::string()) const;

  /// Returns the parameters of this `Module` and if `recurse` is true, also
  /// recursively of every submodule.
  std::vector<Tensor> parameters(bool recurse = true) const;

  /// Returns an `OrderedDict` with the parameters of this `Module` along with
  /// their keys, and if `recurse` is true also recursively of every submodule.
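  ///
  /// For example (illustrative; assumes `model` is a registered module
  /// hierarchy held in a `ModuleHolder` or `std::shared_ptr`):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   for (const auto& pair : model->named_parameters()) {
  ///     std::cout << pair.key() << ": " << pair.value().sizes() << std::endl;
  ///   }
  /// \endrst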
  OrderedDict<std::string, Tensor> named_parameters(bool recurse = true) const;

  /// Returns the buffers of this `Module` and if `recurse` is true, also
  /// recursively of every submodule.
  std::vector<Tensor> buffers(bool recurse = true) const;

  /// Returns an `OrderedDict` with the buffers of this `Module` along with
  /// their keys, and if `recurse` is true also recursively of every submodule.
  OrderedDict<std::string, Tensor> named_buffers(bool recurse = true) const;

  /// Returns the submodules of this `Module` (the entire submodule hierarchy)
  /// and if `include_self` is true, also inserts a `shared_ptr` to this module
  /// in the first position.
  ///
  /// \rst
  /// .. warning::
  ///   Only pass `include_self` as `true` if this `Module` is stored in a
  ///   `shared_ptr`! Otherwise an exception will be thrown. You may still call
  ///   this method with `include_self` set to false if your `Module` is not
  ///   stored in a `shared_ptr`.
  /// \endrst
  std::vector<std::shared_ptr<Module>> modules(bool include_self = true) const;

  /// Returns an `OrderedDict` of the submodules of this `Module` (the entire
  /// submodule hierarchy) and their keys, and if `include_self` is true, also
  /// inserts a `shared_ptr` to this module in the first position. If
  /// `name_prefix` is given, it is prepended to every key as
  /// `<name_prefix>.<key>` (and just `name_prefix` for the module itself).
  ///
  /// \rst
  /// .. warning::
  ///   Only pass `include_self` as `true` if this `Module` is stored in a
  ///   `shared_ptr`! Otherwise an exception will be thrown. You may still call
  ///   this method with `include_self` set to false if your `Module` is not
  ///   stored in a `shared_ptr`.
  /// \endrst
  OrderedDict<std::string, std::shared_ptr<Module>> named_modules(
      const std::string& name_prefix = std::string(),
      bool include_self = true) const;

  /// Returns the direct submodules of this `Module`.
  std::vector<std::shared_ptr<Module>> children() const;

  /// Returns an `OrderedDict` of the direct submodules of this `Module` and
  /// their keys.
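  ///
  /// For example (illustrative; assumes `model` has registered submodules):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   for (const auto& child : model->named_children()) {
  ///     std::cout << child.key() << ": " << child.value()->name() << std::endl;
  ///   }
  /// \endrst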
  OrderedDict<std::string, std::shared_ptr<Module>> named_children() const;

  /// Enables "training" mode.
  virtual void train(bool on = true);

  /// Calls `train(false)` to enable "eval" mode.
  /// Do not override this method, override `train()` instead.
  void eval();

  /// True if the module is in training mode.
  ///
  /// Every `Module` has a boolean associated with it that determines whether
  /// the `Module` is currently in *training* mode (set via `.train()`) or in
  /// *evaluation* (inference) mode (set via `.eval()`). This property is
  /// exposed via `is_training()`, and may be used by the implementation of a
  /// concrete module to modify its runtime behavior. See the `BatchNorm` or
  /// `Dropout` modules for examples of `Module`s that use different code paths
  /// depending on this property.
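  ///
  /// For example, a concrete module's `forward()` might branch on this flag
  /// (an illustrative sketch; `MyDropoutLike` is a placeholder, not taken from
  /// any particular module):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   torch::Tensor MyDropoutLike::forward(torch::Tensor input) {
  ///     // Only perturb the input while training; pass it through in eval mode.
  ///     return is_training() ? torch::dropout(input, /*p=*/0.5, /*train=*/true)
  ///                          : input;
  ///   }
  /// \endrst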
  virtual bool is_training() const noexcept;

  /// Recursively casts all parameters to the given `dtype` and `device`.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
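  ///
  /// For example (illustrative):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   // Move all parameters and buffers to the GPU and convert them to half
  ///   // precision in one call.
  ///   model->to(torch::kCUDA, torch::kHalf);
  /// \endrst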
  virtual void to(
      torch::Device device,
      torch::Dtype dtype,
      bool non_blocking = false);

  /// Recursively casts all parameters to the given dtype.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
  virtual void to(torch::Dtype dtype, bool non_blocking = false);

  /// Recursively moves all parameters to the given device.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
  virtual void to(torch::Device device, bool non_blocking = false);

  /// Recursively zeros out the `grad` value of each registered parameter.
  virtual void zero_grad(bool set_to_none = true);

  /// Attempts to cast this `Module` to the given `ModuleType`.
  ///
  /// This method is useful when calling `apply()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   void initialize_weights(nn::Module& module) {
  ///     torch::NoGradGuard no_grad;
  ///     if (auto* linear = module.as<nn::Linear>()) {
  ///       linear->weight.normal_(0.0, 0.02);
  ///     }
  ///   }
  ///
  ///   MyModule module;
  ///   module->apply(initialize_weights);
  /// \endrst
  template <typename ModuleType>
  typename ModuleType::ContainedType* as() noexcept;

  /// Attempts to cast this `Module` to the given `ModuleType`.
  ///
  /// This method is useful when calling `apply()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   void initialize_weights(nn::Module& module) {
  ///     torch::NoGradGuard no_grad;
  ///     if (auto* linear = module.as<nn::Linear>()) {
  ///       linear->weight.normal_(0.0, 0.02);
  ///     }
  ///   }
  ///
  ///   MyModule module;
  ///   module->apply(initialize_weights);
  /// \endrst
  template <typename ModuleType>
  const typename ModuleType::ContainedType* as() const noexcept;

  /// Attempts to cast this `Module` to the given `ModuleType`.
  ///
  /// This method is useful when calling `apply()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   void initialize_weights(nn::Module& module) {
  ///     torch::NoGradGuard no_grad;
  ///     if (auto* linear = module.as<nn::Linear>()) {
  ///       linear->weight.normal_(0.0, 0.02);
  ///     }
  ///   }
  ///
  ///   MyModule module;
  ///   module.apply(initialize_weights);
  /// \endrst
  template <
      typename ModuleType,
      typename = torch::detail::disable_if_module_holder_t<ModuleType>>
  ModuleType* as() noexcept;

  /// Attempts to cast this `Module` to the given `ModuleType`.
  ///
  /// This method is useful when calling `apply()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   void initialize_weights(nn::Module& module) {
  ///     torch::NoGradGuard no_grad;
  ///     if (auto* linear = module.as<nn::Linear>()) {
  ///       linear->weight.normal_(0.0, 0.02);
  ///     }
  ///   }
  ///
  ///   MyModule module;
  ///   module.apply(initialize_weights);
  /// \endrst
  template <
      typename ModuleType,
      typename = torch::detail::disable_if_module_holder_t<ModuleType>>
  const ModuleType* as() const noexcept;

  /// Serializes the `Module` into the given `OutputArchive`.
  ///
  /// If the `Module` contains unserializable submodules (e.g.
  /// `nn::Functional`), those submodules are skipped when serializing.
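  ///
  /// For example (an illustrative sketch; `MyModule` and the file name are
  /// placeholders):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   auto model = std::make_shared<MyModule>();
  ///   serialize::OutputArchive archive;
  ///   model->save(archive);
  ///   archive.save_to("model.pt");
  /// \endrst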
  virtual void save(serialize::OutputArchive& archive) const;

  /// Deserializes the `Module` from the given `InputArchive`.
  ///
  /// If the `Module` contains unserializable submodules (e.g.
  /// `nn::Functional`), we don't check the existence of those submodules in the
  /// `InputArchive` when deserializing.
  virtual void load(serialize::InputArchive& archive);

  /// Streams a pretty representation of the `Module` into the given `stream`.
  /// By default, this representation will be the name of the module (taken from
  /// `name()`), followed by a recursive pretty print of all of the `Module`'s
  /// submodules.
  ///
  /// Override this method to change the pretty print. The input
  /// `stream` should be returned from the method, to allow easy chaining.
  virtual void pretty_print(std::ostream& stream) const;

  /// Returns whether the `Module` is serializable.
  virtual bool is_serializable() const;

  /// Registers a parameter with this `Module`.
  ///
  /// A parameter should be any gradient-recording tensor used in the
  /// implementation of your `Module`. Registering it makes it available to
  /// methods such as `parameters()`, `clone()` or `to()`.
  ///
  /// Note that registering an undefined Tensor (e.g.
  /// `module.register_parameter("param", Tensor())`) is allowed, and is
  /// equivalent to `module.register_parameter("param", None)` in the Python
  /// API.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule::MyModule() {
  ///     weight_ = register_parameter("weight", torch::randn({A, B}));
  ///   }
  /// \endrst
  Tensor& register_parameter(
      std::string name,
      Tensor tensor,
      bool requires_grad = true);

  /// Registers a buffer with this `Module`.
  ///
  /// A buffer is intended to be state in your module that does not record
  /// gradients, such as running statistics. Registering it makes it available
  /// to methods such as `buffers()`, `clone()` or `to()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule::MyModule() {
  ///     mean_ = register_buffer("mean", torch::empty({num_features_}));
  ///   }
  /// \endrst
  Tensor& register_buffer(std::string name, Tensor tensor);

  /// Registers a submodule with this `Module`.
  ///
  /// Registering a module makes it available to methods such as `modules()`,
  /// `clone()` or `to()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule::MyModule() {
  ///     submodule_ = register_module("linear", torch::nn::Linear(3, 4));
  ///   }
  /// \endrst
  template <typename ModuleType>
  std::shared_ptr<ModuleType> register_module(
      std::string name,
      std::shared_ptr<ModuleType> module);

  /// Registers a submodule with this `Module`.
  ///
  /// This method deals with `ModuleHolder`s.
  ///
  /// Registering a module makes it available to methods such as `modules()`,
  /// `clone()` or `to()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule::MyModule() {
  ///     submodule_ = register_module("linear", torch::nn::Linear(3, 4));
  ///   }
  /// \endrst
  template <typename ModuleType>
  std::shared_ptr<ModuleType> register_module(
      std::string name,
      ModuleHolder<ModuleType> module_holder);

  /// Replaces a submodule registered with this `Module` under the given name.
  ///
  /// This takes care of the registration; if you store the submodule in a
  /// member, you should also assign it the returned pointer, i.e. use it as
  /// `module->submodule_ = module->replace_module("linear",
  /// torch::nn::Linear(3, 4));`. It only works when a module of the given
  /// name is already registered.
  ///
  /// This is useful for replacing a module after initialization, e.g.
  /// for finetuning.
  template <typename ModuleType>
  std::shared_ptr<ModuleType> replace_module(
      const std::string& name,
      std::shared_ptr<ModuleType> module);

  /// Replaces a submodule registered with this `Module` under the given name.
  /// This method deals with `ModuleHolder`s.
  ///
  /// This takes care of the registration; if you store the submodule in a
  /// member, you should also assign it the returned pointer, i.e. use it as
  /// `module->submodule_ = module->replace_module("linear", linear_holder);`.
  /// It only works when a module of the given name is already registered.
  ///
  /// This is useful for replacing a module after initialization, e.g.
  /// for finetuning.
  template <typename ModuleType>
  std::shared_ptr<ModuleType> replace_module(
      const std::string& name,
      ModuleHolder<ModuleType> module_holder);

  /// Unregisters a submodule from this `Module`. If there is no such module
  /// with `name`, an exception is thrown.
  void unregister_module(const std::string& name);

 protected:
  /// The following three functions allow a module with default arguments in
  /// its forward method to be used in a Sequential module.
  /// You should NEVER override these functions manually. Instead, you should
  /// use the `FORWARD_HAS_DEFAULT_ARGS` macro.
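  ///
  /// An illustrative sketch of how such a module might look (the exact macro
  /// arguments below are an assumption; consult the macro's definition for the
  /// authoritative usage):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   struct MImpl : torch::nn::Module {
  ///     torch::Tensor forward(torch::Tensor x, int scale = 2) {
  ///       return x * scale;
  ///     }
  ///
  ///    protected:
  ///     // Declares that the argument at index 1 of `forward` defaults to 2,
  ///     // so `Sequential` can call `forward` with only the first argument.
  ///     FORWARD_HAS_DEFAULT_ARGS({1, torch::nn::AnyValue(2)})
  ///   };
  /// \endrst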
  virtual bool _forward_has_default_args() {
    return false;
  }

  virtual unsigned int _forward_num_required_args() {
    TORCH_CHECK(
        false,
        "torch::nn::Module subclass that has default arguments in `forward` method ",
        "must override `_forward_num_required_args` method. Please use ",
        "`FORWARD_HAS_DEFAULT_ARGS` macro to do so.");
  }

  virtual std::vector<AnyValue> _forward_populate_default_args(
      std::vector<AnyValue>&& arguments) {
    TORCH_CHECK(
        false,
        "torch::nn::Module subclass that has default arguments in `forward` method ",
        "must override `_forward_populate_default_args` method. Please use ",
        "`FORWARD_HAS_DEFAULT_ARGS` macro to do so.");
  }

  /// The registered parameters of this `Module`.
  /// Kept protected (rather than private) so that `ParameterDict` and
  /// `ParameterList` can access `parameters_` directly.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  OrderedDict<std::string, Tensor> parameters_;

 private:
  // Friend classes.

  template <typename Derived>
  friend class Cloneable;

  template <typename ModuleType, typename... ArgumentTypes>
  friend struct AnyModuleHolder;

  /// Pretty prints the given `Module` into the `ostream`.
  TORCH_API friend std::ostream& operator<<(
      std::ostream& stream,
      const nn::Module& module);

  // Data parallel uses this method to configure gradient edges during the
  // replicate step.
  template <typename ModuleType>
  friend void replicate_grad_edges(
      const std::shared_ptr<Module>& module,
      const std::vector<std::shared_ptr<ModuleType>>& replicas,
      const std::vector<Device>& devices);

  // Private methods.

  /// Used in the implementation of `Cloneable`.
  virtual void clone_(Module& other, const std::optional<Device>& device);

  /// The implementation of the various `to()` methods.
  template <typename... Ts>
  void to_impl(Ts&&... ts);

  /// Implements pretty printing the module hierarchy.
  void pretty_print_recursive(
      std::ostream& stream,
      const std::string& indentation) const;

  /// Applies the `function` to every submodule recursively, starting at this
  /// `Module`'s children (thus not including the module itself).
  void apply_to_submodules(
      const NamedModulePointerApplyFunction& function,
      const std::string& name_prefix = std::string()) const;

  /// Returns a shared_ptr to `this` in a safe (checked) way.
  std::shared_ptr<Module> shared_from_this_checked() const;

  /// The registered buffers of this `Module`.
  OrderedDict<std::string, Tensor> buffers_;

  /// The registered (direct) submodules of this `Module`.
  OrderedDict<std::string, std::shared_ptr<Module>> children_;

  /// The module's name (e.g. "LSTM").
  mutable std::optional<std::string> name_;

  /// Whether the module is in training mode.
  bool is_training_{true};
};

/// Serialize a `Module` pointer into an `OutputArchive`.
TORCH_API serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const std::shared_ptr<nn::Module>& module);

/// Deserializes a `Module` from an `InputArchive`.
TORCH_API serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    const std::shared_ptr<nn::Module>& module);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nn::Module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename ModuleType>
typename ModuleType::ContainedType* Module::as() noexcept {
  // Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for
  // `Linear`, since `LinearImpl` inherits `nn::Module`.
  return as<typename ModuleType::ContainedType>();
}

template <typename ModuleType>
const typename ModuleType::ContainedType* Module::as() const noexcept {
  // Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for
  // `Linear`, since `LinearImpl` inherits `nn::Module`.
  return as<typename ModuleType::ContainedType>();
}

template <typename ModuleType, typename>
ModuleType* Module::as() noexcept {
  return dynamic_cast<ModuleType*>(this);
}

template <typename ModuleType, typename>
const ModuleType* Module::as() const noexcept {
  return dynamic_cast<const ModuleType*>(this);
}

template <typename ModuleType>
std::shared_ptr<ModuleType> Module::register_module(
    std::string name,
    std::shared_ptr<ModuleType> module) {
  TORCH_CHECK(!name.empty(), "Submodule name must not be empty");
  TORCH_CHECK(
      name.find('.') == std::string::npos,
      "Submodule name must not contain a dot (got '",
      name,
      "')");
  auto& base_module = children_.insert(std::move(name), std::move(module));
  return std::dynamic_pointer_cast<ModuleType>(base_module);
}

template <typename ModuleType>
std::shared_ptr<ModuleType> Module::register_module(
    std::string name,
    ModuleHolder<ModuleType> module_holder) {
  return register_module(std::move(name), module_holder.ptr());
}

template <typename ModuleType>
std::shared_ptr<ModuleType> Module::replace_module(
    const std::string& name,
    std::shared_ptr<ModuleType> module) {
  auto& base_module = (children_[name] = std::move(module));
  return std::dynamic_pointer_cast<ModuleType>(base_module);
}

template <typename ModuleType>
std::shared_ptr<ModuleType> Module::replace_module(
    const std::string& name,
    ModuleHolder<ModuleType> module_holder) {
  return replace_module(name, module_holder.ptr());
}

template <typename... Ts>
void Module::to_impl(Ts&&... ts) {
  // First call `to()` on every child module.
  for (auto& child : children_) {
    child.value()->to(ts...);
  }
  // Then move every parameter to the new dtype/device.
  for (auto& parameter : named_parameters(/*recurse=*/false)) {
    parameter->set_data(autograd::Variable(*parameter).to(ts...));
  }
  // Then move every buffer to the new dtype/device.
  for (auto& buffer : named_buffers(/*recurse=*/false)) {
    buffer->set_data(autograd::Variable(*buffer).to(ts...));
  }
}

} // namespace nn
} // namespace torch