#pragma once

#include <torch/nn/modules/container/any_module_holder.h>
#include <torch/nn/modules/container/any_value.h>
#include <torch/nn/pimpl.h>
#include <torch/ordered_dict.h>
#include <torch/serialize/archive.h>
#include <torch/types.h>

#include <ATen/ATen.h>

#include <functional>
#include <iosfwd>
#include <map>
#include <memory>
#include <string>
#include <type_traits>

namespace torch {
namespace nn {

/// The base class for all modules in PyTorch.
///
/// \rst
/// .. note::
///   The design and implementation of this class is largely based on the Python
///   API. You may want to consult the python documentation for
///   :py:class:`pytorch:torch.nn.Module` for further clarification on certain
///   methods or behavior.
/// \endrst
///
/// A `Module` is an abstraction over the implementation of some function or
/// algorithm, possibly associated with some persistent data. A `Module` may
/// contain further `Module`s ("submodules"), each with their own
/// implementation, persistent data and further submodules. `Module`s can thus
/// be said to form a recursive tree structure. A `Module` is registered as a
/// submodule to another `Module` by calling `register_module()`, typically from
/// within a parent module's constructor.
///
/// A distinction is made between three kinds of persistent data that may be
/// associated with a `Module`:
///
/// 1. *Parameters*: tensors that record gradients, typically weights updated
///    during the backward step (e.g. the `weight` of a `Linear` module),
/// 2. *Buffers*: tensors that do not record gradients, typically updated during
///    the forward step, such as running statistics (e.g. `mean` and `variance`
///    in the `BatchNorm` module),
/// 3. Any additional state, not necessarily tensors, required for the
///    implementation or configuration of a `Module`.
///
/// The first two kinds of state are special in that they may be registered
/// with the `Module` system to allow convenient access and batch configuration.
/// For example, registered parameters in any `Module` may be iterated over via
/// the `parameters()` accessor. Further, changing the data type of a `Module`'s
/// registered parameters can be done conveniently via `Module::to()`, e.g.
/// `module->to(torch::kCUDA)` to move all parameters to GPU memory. Lastly,
/// registered parameters and buffers are handled specially during a `clone()`
/// operation, which performs a deepcopy of a cloneable `Module` hierarchy.
///
/// Parameters are registered with a `Module` via `register_parameter`. Buffers
/// are registered separately via `register_buffer`. These methods are part of
/// the public API of `Module` and are typically invoked from within a
/// concrete `Module`'s constructor.
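///
/// As a rough sketch of how these pieces fit together (the type `MyNet` and
/// its members are purely illustrative, not part of the library):
///
/// \rst
/// .. code-block:: cpp
///
///   struct MyNet : torch::nn::Module {
///     MyNet() {
///       // Register a submodule, a parameter and a buffer so that they are
///       // visible to `parameters()`, `buffers()`, `to()`, `clone()` etc.
///       fc_ = register_module("fc", torch::nn::Linear(4, 3));
///       bias_ = register_parameter("bias", torch::zeros({3}));
///       steps_ = register_buffer("steps", torch::zeros({1}));
///     }
///
///     torch::Tensor forward(torch::Tensor x) {
///       return fc_(x) + bias_;
///     }
///
///     torch::nn::Linear fc_{nullptr};
///     torch::Tensor bias_, steps_;
///   };
/// \endrst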
class TORCH_API Module : public std::enable_shared_from_this<Module> {
 public:
  using ModuleApplyFunction = std::function<void(Module&)>;
  using ConstModuleApplyFunction = std::function<void(const Module&)>;
  using NamedModuleApplyFunction =
      std::function<void(const std::string&, Module&)>;
  using ConstNamedModuleApplyFunction =
      std::function<void(const std::string&, const Module&)>;
  using ModulePointerApplyFunction =
      std::function<void(const std::shared_ptr<Module>&)>;
  using NamedModulePointerApplyFunction =
      std::function<void(const std::string&, const std::shared_ptr<Module>&)>;

  /// Tells the base `Module` about the name of the submodule.
  explicit Module(std::string name);

  /// Constructs the module without immediate knowledge of the submodule's name.
  /// The name of the submodule is inferred via RTTI (if possible) the first
  /// time `.name()` is invoked.
  Module();
  Module(const Module&) = default;
  Module& operator=(const Module&) = default;
  Module(Module&&) noexcept = default;
  Module& operator=(Module&&) noexcept = default;

  virtual ~Module() = default;

  /// Returns the name of the `Module`.
  ///
  /// A `Module` has an associated `name`, which is a string representation of
  /// the kind of concrete `Module` it represents, such as `"Linear"` for the
  /// `Linear` module. Under most circumstances, this name is automatically
  /// inferred via runtime type information (RTTI). In the unusual circumstance
  /// that you have this feature disabled, you may want to manually name your
  /// `Module`s by passing the string name to the `Module` base class'
  /// constructor.
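  ///
  /// A brief sketch of naming a module explicitly, e.g. when RTTI is
  /// unavailable (the type `MyModuleImpl` is illustrative):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   struct MyModuleImpl : torch::nn::Module {
  ///     MyModuleImpl() : torch::nn::Module("MyModule") {}
  ///   };
  /// \endrst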
  const std::string& name() const noexcept;

  /// Performs a recursive deep copy of the module and all its registered
  /// parameters, buffers and submodules.
  ///
  /// Optionally, this method sets the current device
  /// to the one supplied before cloning. If no device is given, each
  /// parameter and buffer will be moved to the device of its source.
  ///
  /// \rst
  /// .. attention::
  ///   Attempting to call the `clone()` method inherited from the base `Module`
  ///   class (the one documented here) will fail. To inherit an actual
  ///   implementation of `clone()`, you must subclass `Cloneable`. `Cloneable`
  ///   is templatized on the concrete module type, and can thus properly copy a
  ///   `Module`. This method is provided on the base class' API solely for an
  ///   easier-to-use polymorphic interface.
  /// \endrst
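  ///
  /// A brief sketch of a module that supports `clone()` by deriving from
  /// `Cloneable` (the type `MyCloneableModule` is illustrative; `reset()`
  /// must (re)create all parameters, buffers and submodules):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   struct MyCloneableModule : torch::nn::Cloneable<MyCloneableModule> {
  ///     MyCloneableModule() { reset(); }
  ///     void reset() override {
  ///       weight = register_parameter("weight", torch::randn({4, 4}));
  ///     }
  ///     torch::Tensor weight;
  ///   };
  ///
  ///   auto module = std::make_shared<MyCloneableModule>();
  ///   auto copy = module->clone(); // returns std::shared_ptr<Module>
  /// \endrst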
  virtual std::shared_ptr<Module> clone(
      const std::optional<Device>& device = std::nullopt) const;

  /// Applies the `function` to the `Module` and recursively to every submodule.
  /// The function must accept a `Module&`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///   MyModule module;
  ///   module->apply([](nn::Module& module) {
  ///     std::cout << module.name() << std::endl;
  ///   });
  /// \endrst
  void apply(const ModuleApplyFunction& function);

  /// Applies the `function` to the `Module` and recursively to every submodule.
  /// The function must accept a `const Module&`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///   MyModule module;
  ///   module->apply([](const nn::Module& module) {
  ///     std::cout << module.name() << std::endl;
  ///   });
  /// \endrst
  void apply(const ConstModuleApplyFunction& function) const;

  /// Applies the `function` to the `Module` and recursively to every submodule.
  /// The function must accept a `const std::string&` for the key of the module,
  /// and a `Module&`. The key of the module itself is the empty string. If
  /// `name_prefix` is given, it is prepended to every key as
  /// `<name_prefix>.<key>` (and just `name_prefix` for the module itself).
  ///
  /// \rst
  /// .. code-block:: cpp
  ///   MyModule module;
  ///   module->apply([](const std::string& key, nn::Module& module) {
  ///     std::cout << key << ": " << module.name() << std::endl;
  ///   });
  /// \endrst
  void apply(
      const NamedModuleApplyFunction& function,
      const std::string& name_prefix = std::string());

  /// Applies the `function` to the `Module` and recursively to every submodule.
  /// The function must accept a `const std::string&` for the key of the module,
  /// and a `const Module&`. The key of the module itself is the empty string.
  /// If `name_prefix` is given, it is prepended to every key as
  /// `<name_prefix>.<key>` (and just `name_prefix` for the module itself).
  ///
  /// \rst
  /// .. code-block:: cpp
  ///   MyModule module;
  ///   module->apply([](const std::string& key, const nn::Module& module) {
  ///     std::cout << key << ": " << module.name() << std::endl;
  ///   });
  /// \endrst
  void apply(
      const ConstNamedModuleApplyFunction& function,
      const std::string& name_prefix = std::string()) const;

  /// Applies the `function` to the `Module` and recursively to every submodule.
  /// The function must accept a `const std::shared_ptr<Module>&`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///   MyModule module;
  ///   module->apply([](const std::shared_ptr<nn::Module>& module) {
  ///     std::cout << module->name() << std::endl;
  ///   });
  /// \endrst
  void apply(const ModulePointerApplyFunction& function) const;

  /// Applies the `function` to the `Module` and recursively to every submodule.
  /// The function must accept a `const std::string&` for the key of the module,
  /// and a `const std::shared_ptr<Module>&`. The key of the module itself is
  /// the empty string. If `name_prefix` is given, it is prepended to every key
  /// as `<name_prefix>.<key>` (and just `name_prefix` for the module itself).
  ///
  /// \rst
  /// .. code-block:: cpp
  ///   MyModule module;
  ///   module->apply([](const std::string& key,
  ///                    const std::shared_ptr<nn::Module>& module) {
  ///     std::cout << key << ": " << module->name() << std::endl;
  ///   });
  /// \endrst
  void apply(
      const NamedModulePointerApplyFunction& function,
      const std::string& name_prefix = std::string()) const;

  /// Returns the parameters of this `Module` and if `recurse` is true, also
  /// recursively of every submodule.
  std::vector<Tensor> parameters(bool recurse = true) const;

  /// Returns an `OrderedDict` with the parameters of this `Module` along with
  /// their keys, and if `recurse` is true also recursively of every submodule.
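  ///
  /// A brief usage sketch (assuming `model` is a concrete module instance;
  /// the name is illustrative):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   for (const auto& pair : model.named_parameters()) {
  ///     std::cout << pair.key() << ": " << pair.value().sizes() << std::endl;
  ///   }
  /// \endrst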
  OrderedDict<std::string, Tensor> named_parameters(bool recurse = true) const;

  /// Returns the buffers of this `Module` and if `recurse` is true, also
  /// recursively of every submodule.
  std::vector<Tensor> buffers(bool recurse = true) const;

  /// Returns an `OrderedDict` with the buffers of this `Module` along with
  /// their keys, and if `recurse` is true also recursively of every submodule.
  OrderedDict<std::string, Tensor> named_buffers(bool recurse = true) const;

  /// Returns the submodules of this `Module` (the entire submodule hierarchy)
  /// and if `include_self` is true, also inserts a `shared_ptr` to this module
  /// in the first position.
  ///
  /// \rst
  /// .. warning::
  ///   Only pass `include_self` as `true` if this `Module` is stored in a
  ///   `shared_ptr`! Otherwise an exception will be thrown. You may still call
  ///   this method with `include_self` set to false if your `Module` is not
  ///   stored in a `shared_ptr`.
  /// \endrst
  std::vector<std::shared_ptr<Module>> modules(bool include_self = true) const;

  /// Returns an `OrderedDict` of the submodules of this `Module` (the entire
  /// submodule hierarchy) and their keys, and if `include_self` is true, also
  /// inserts a `shared_ptr` to this module in the first position. If
  /// `name_prefix` is given, it is prepended to every key as
  /// `<name_prefix>.<key>` (and just `name_prefix` for the module itself).
  ///
  /// \rst
  /// .. warning::
  ///   Only pass `include_self` as `true` if this `Module` is stored in a
  ///   `shared_ptr`! Otherwise an exception will be thrown. You may still call
  ///   this method with `include_self` set to false if your `Module` is not
  ///   stored in a `shared_ptr`.
  /// \endrst
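  ///
  /// A brief usage sketch (assuming `model` is a `std::shared_ptr` to a
  /// concrete module; `MyModuleImpl` is an illustrative type name):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   auto model = std::make_shared<MyModuleImpl>();
  ///   for (const auto& pair : model->named_modules("model")) {
  ///     std::cout << pair.key() << ": " << pair.value()->name() << std::endl;
  ///   }
  /// \endrst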
  OrderedDict<std::string, std::shared_ptr<Module>> named_modules(
      const std::string& name_prefix = std::string(),
      bool include_self = true) const;

  /// Returns the direct submodules of this `Module`.
  std::vector<std::shared_ptr<Module>> children() const;

  /// Returns an `OrderedDict` of the direct submodules of this `Module` and
  /// their keys.
  OrderedDict<std::string, std::shared_ptr<Module>> named_children() const;

  /// Enables "training" mode.
  virtual void train(bool on = true);

  /// Calls train(false) to enable "eval" mode.
  /// Do not override this method, override `train()` instead.
  void eval();

  /// True if the module is in training mode.
  ///
  /// Every `Module` has a boolean associated with it that determines whether
  /// the `Module` is currently in *training* mode (set via `.train()`) or in
  /// *evaluation* (inference) mode (set via `.eval()`). This property is
  /// exposed via `is_training()`, and may be used by the implementation of a
  /// concrete module to modify its runtime behavior. See the `BatchNorm` or
  /// `Dropout` modules for examples of `Module`s that use different code paths
  /// depending on this property.
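  ///
  /// A brief sketch of how a concrete module might branch on this flag
  /// (the type `MyDropoutLike` and the scaling logic are illustrative only):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   torch::Tensor MyDropoutLike::forward(torch::Tensor x) {
  ///     if (is_training()) {
  ///       // Apply the stochastic behavior only in training mode.
  ///       return x * torch::bernoulli(torch::full_like(x, 0.5)) * 2.0;
  ///     }
  ///     return x; // Identity in eval mode.
  ///   }
  ///
  ///   model.train(); // enable training mode
  ///   model.eval();  // switch to evaluation mode
  /// \endrst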
  virtual bool is_training() const noexcept;

  /// Recursively casts all parameters to the given `dtype` and `device`.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
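  ///
  /// A brief usage sketch (assuming a CUDA device is available):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule module;
  ///   module->to(torch::kCUDA, torch::kHalf); // move and cast all parameters
  /// \endrst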
  virtual void to(
      torch::Device device,
      torch::Dtype dtype,
      bool non_blocking = false);

  /// Recursively casts all parameters to the given dtype.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
  virtual void to(torch::Dtype dtype, bool non_blocking = false);

  /// Recursively moves all parameters to the given device.
  ///
  /// If `non_blocking` is true and the source is in pinned memory and
  /// destination is on the GPU or vice versa, the copy is performed
  /// asynchronously with respect to the host. Otherwise, the argument has no
  /// effect.
  virtual void to(torch::Device device, bool non_blocking = false);

  /// Recursively zeros out the `grad` value of each registered parameter.
  virtual void zero_grad(bool set_to_none = true);

  /// Attempts to cast this `Module` to the given `ModuleType`.
  ///
  /// This method is useful when calling `apply()`.
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   void initialize_weights(nn::Module& module) {
  ///     torch::NoGradGuard no_grad;
  ///     if (auto* linear = module.as<nn::Linear>()) {
  ///       linear->weight.normal_(0.0, 0.02);
  ///     }
  ///   }
  ///
  ///   MyModule module;
  ///   module->apply(initialize_weights);
  /// \endrst
  template <typename ModuleType>
  typename ModuleType::ContainedType* as() noexcept;

  /// Attempts to cast this `Module` to the given `ModuleType`.
  ///
  /// This method is useful when calling `apply()`.
  /// \rst
  /// .. code-block:: cpp
  ///   void initialize_weights(nn::Module& module) {
  ///     torch::NoGradGuard no_grad;
  ///     if (auto* linear = module.as<nn::Linear>()) {
  ///       linear->weight.normal_(0.0, 0.02);
  ///     }
  ///   }
  ///
  ///   MyModule module;
  ///   module->apply(initialize_weights);
  /// \endrst
  template <typename ModuleType>
  const typename ModuleType::ContainedType* as() const noexcept;

  /// Attempts to cast this `Module` to the given `ModuleType`.
  ///
  /// This method is useful when calling `apply()`.
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   void initialize_weights(nn::Module& module) {
  ///     torch::NoGradGuard no_grad;
  ///     if (auto* linear = module.as<nn::Linear>()) {
  ///       linear->weight.normal_(0.0, 0.02);
  ///     }
  ///   }
  ///
  ///   MyModule module;
  ///   module.apply(initialize_weights);
  /// \endrst
  template <
      typename ModuleType,
      typename = torch::detail::disable_if_module_holder_t<ModuleType>>
  ModuleType* as() noexcept;

  /// Attempts to cast this `Module` to the given `ModuleType`.
  ///
  /// This method is useful when calling `apply()`.
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   void initialize_weights(nn::Module& module) {
  ///     torch::NoGradGuard no_grad;
  ///     if (auto* linear = module.as<nn::Linear>()) {
  ///       linear->weight.normal_(0.0, 0.02);
  ///     }
  ///   }
  ///
  ///   MyModule module;
  ///   module.apply(initialize_weights);
  /// \endrst
  template <
      typename ModuleType,
      typename = torch::detail::disable_if_module_holder_t<ModuleType>>
  const ModuleType* as() const noexcept;

  /// Serializes the `Module` into the given `OutputArchive`.
  ///
  /// If the `Module` contains unserializable submodules (e.g.
  /// `nn::Functional`), those submodules are skipped when serializing.
  virtual void save(serialize::OutputArchive& archive) const;

  /// Deserializes the `Module` from the given `InputArchive`.
  ///
  /// If the `Module` contains unserializable submodules (e.g.
  /// `nn::Functional`), we don't check the existence of those submodules in the
  /// `InputArchive` when deserializing.
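  ///
  /// A brief round-trip sketch using `save()` and `load()` directly (the file
  /// name "module.pt" is illustrative; `module` is any registered module):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   torch::serialize::OutputArchive output;
  ///   module->save(output);
  ///   output.save_to("module.pt");
  ///
  ///   torch::serialize::InputArchive input;
  ///   input.load_from("module.pt");
  ///   module->load(input);
  /// \endrst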
  virtual void load(serialize::InputArchive& archive);

  /// Streams a pretty representation of the `Module` into the given `stream`.
  /// By default, this representation will be the name of the module (taken from
  /// `name()`), followed by a recursive pretty print of all of the `Module`'s
  /// submodules.
  ///
  /// Override this method to change the pretty print. Write the custom
  /// representation to the given `stream`; the method itself returns nothing.
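  ///
  /// A brief sketch of an override (the type `MyModuleImpl` and its
  /// `features_` member are illustrative):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   void MyModuleImpl::pretty_print(std::ostream& stream) const {
  ///     stream << "MyModule(features=" << features_ << ")";
  ///   }
  /// \endrst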
  virtual void pretty_print(std::ostream& stream) const;

  /// Returns whether the `Module` is serializable.
  virtual bool is_serializable() const;

  /// Registers a parameter with this `Module`.
  ///
  /// A parameter should be any gradient-recording tensor used in the
  /// implementation of your `Module`. Registering it makes it available to
  /// methods such as `parameters()`, `clone()` or `to()`.
  ///
  /// Note that registering an undefined Tensor (e.g.
  /// `module.register_parameter("param", Tensor())`) is allowed, and is
  /// equivalent to `module.register_parameter("param", None)` in the Python API.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule::MyModule() {
  ///     weight_ = register_parameter("weight", torch::randn({A, B}));
  ///   }
  /// \endrst
  Tensor& register_parameter(
      std::string name,
      Tensor tensor,
      bool requires_grad = true);

  /// Registers a buffer with this `Module`.
  ///
  /// A buffer is intended to be state in your module that does not record
  /// gradients, such as running statistics. Registering it makes it available
  /// to methods such as `buffers()`, `clone()` or `to()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule::MyModule() {
  ///     mean_ = register_buffer("mean", torch::empty({num_features_}));
  ///   }
  /// \endrst
  Tensor& register_buffer(std::string name, Tensor tensor);

  /// Registers a submodule with this `Module`.
  ///
  /// Registering a module makes it available to methods such as `modules()`,
  /// `clone()` or `to()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule::MyModule() {
  ///     submodule_ = register_module("linear", torch::nn::Linear(3, 4));
  ///   }
  /// \endrst
  template <typename ModuleType>
  std::shared_ptr<ModuleType> register_module(
      std::string name,
      std::shared_ptr<ModuleType> module);

  /// Registers a submodule with this `Module`.
  ///
  /// This method deals with `ModuleHolder`s.
  ///
  /// Registering a module makes it available to methods such as `modules()`,
  /// `clone()` or `to()`.
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   MyModule::MyModule() {
  ///     submodule_ = register_module("linear", torch::nn::Linear(3, 4));
  ///   }
  /// \endrst
  template <typename ModuleType>
  std::shared_ptr<ModuleType> register_module(
      std::string name,
      ModuleHolder<ModuleType> module_holder);
  /// Replaces a registered submodule of this `Module`.
  ///
  /// This takes care of the registration. If you use submodule members, you
  /// should also assign the returned pointer to the member, i.e. use it as
  ///     module->submodule_ = module->replace_module("linear",
  ///     torch::nn::Linear(3, 4));
  /// It only works when a module of the given name is already registered.
  ///
  /// This is useful for replacing a module after initialization, e.g.
  /// for finetuning.
  template <typename ModuleType>
  std::shared_ptr<ModuleType> replace_module(
      const std::string& name,
      std::shared_ptr<ModuleType> module);

  /// Replaces a registered submodule of this `Module`.
  /// This method deals with `ModuleHolder`s.
  ///
  /// This takes care of the registration. If you use submodule members, you
  /// should also assign the returned pointer to the member, i.e. use it as
  ///     module->submodule_ = module->replace_module("linear", linear_holder);
  /// It only works when a module of the given name is already registered.
  ///
  /// This is useful for replacing a module after initialization, e.g.
  /// for finetuning.
  template <typename ModuleType>
  std::shared_ptr<ModuleType> replace_module(
      const std::string& name,
      ModuleHolder<ModuleType> module_holder);

  /// Unregisters a submodule from this `Module`. If there is no such module
  /// with `name` an exception is thrown.
  void unregister_module(const std::string& name);

 protected:
  /// The following three functions allow a module with default arguments in its
  /// forward method to be used in a Sequential module.
  /// You should NEVER override these functions manually. Instead, you should
  /// use the `FORWARD_HAS_DEFAULT_ARGS` macro.
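  ///
  /// A brief sketch of the intended usage (the type `ScaleImpl`, its argument
  /// and the default value are illustrative only):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   struct ScaleImpl : torch::nn::Module {
  ///     torch::Tensor forward(torch::Tensor x, int scale = 2) {
  ///       return x * scale;
  ///     }
  ///
  ///    protected:
  ///     // Argument at index 1 (`scale`) has a default value of 2.
  ///     FORWARD_HAS_DEFAULT_ARGS({1, torch::nn::AnyValue(2)})
  ///   };
  /// \endrst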
  virtual bool _forward_has_default_args() {
    return false;
  }

  virtual unsigned int _forward_num_required_args() {
    TORCH_CHECK(
        false,
        "torch::nn::Module subclass that has default arguments in `forward` method ",
        "must override `_forward_num_required_args` method. Please use ",
        "`FORWARD_HAS_DEFAULT_ARGS` macro to do so.");
  }

  virtual std::vector<AnyValue> _forward_populate_default_args(
      std::vector<AnyValue>&& arguments) {
    TORCH_CHECK(
        false,
        "torch::nn::Module subclass that has default arguments in `forward` method ",
        "must override `_forward_populate_default_args` method. Please use ",
        "`FORWARD_HAS_DEFAULT_ARGS` macro to do so.");
  }

  /// The registered parameters of this `Module`.
  /// Kept protected so that `ParameterDict` and `ParameterList` can access
  /// `parameters_` directly.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  OrderedDict<std::string, Tensor> parameters_;

 private:
  // Friend classes.

  template <typename Derived>
  friend class Cloneable;

  template <typename ModuleType, typename... ArgumentTypes>
  friend struct AnyModuleHolder;

  /// Pretty prints the given `Module` into the `ostream`.
  TORCH_API friend std::ostream& operator<<(
      std::ostream& stream,
      const nn::Module& module);

  // Data parallelism uses this method to configure gradient edges during the
  // replicate step.
  template <typename ModuleType>
  friend void replicate_grad_edges(
      const std::shared_ptr<Module>& module,
      const std::vector<std::shared_ptr<ModuleType>>& replicas,
      const std::vector<Device>& devices);

  // Private methods.

  /// Used in the implementation of `Cloneable`.
  virtual void clone_(Module& other, const std::optional<Device>& device);

  /// The implementation of the various `to()` methods.
  template <typename... Ts>
  void to_impl(Ts&&... ts);

  /// Implements pretty printing the module hierarchy.
  void pretty_print_recursive(
      std::ostream& stream,
      const std::string& indentation) const;

  /// Applies the `function` to every submodule recursively, starting at this
  /// `Module`'s children (thus not including the module itself).
  void apply_to_submodules(
      const NamedModulePointerApplyFunction& function,
      const std::string& name_prefix = std::string()) const;

  /// Returns a shared_ptr to `this` in a safe (checked) way.
  std::shared_ptr<Module> shared_from_this_checked() const;

  /// The registered buffers of this `Module`.
  OrderedDict<std::string, Tensor> buffers_;

  /// The registered (direct) submodules of this `Module`.
  OrderedDict<std::string, std::shared_ptr<Module>> children_;

  /// The module's name (e.g. "LSTM").
  mutable std::optional<std::string> name_;

  /// Whether the module is in training mode.
  bool is_training_{true};
};

/// Serializes a `Module` pointer into an `OutputArchive`.
TORCH_API serialize::OutputArchive& operator<<(
    serialize::OutputArchive& archive,
    const std::shared_ptr<nn::Module>& module);

/// Deserializes a `Module` from an `InputArchive`.
TORCH_API serialize::InputArchive& operator>>(
    serialize::InputArchive& archive,
    const std::shared_ptr<nn::Module>& module);

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nn::Module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename ModuleType>
typename ModuleType::ContainedType* Module::as() noexcept {
  // Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for
  // `Linear`, since `LinearImpl` inherits `nn::Module`.
  return as<typename ModuleType::ContainedType>();
}

template <typename ModuleType>
const typename ModuleType::ContainedType* Module::as() const noexcept {
  // Use the contained type of the `ModuleHolder`, e.g. `LinearImpl` for
  // `Linear`, since `LinearImpl` inherits `nn::Module`.
  return as<typename ModuleType::ContainedType>();
}

template <typename ModuleType, typename>
ModuleType* Module::as() noexcept {
  return dynamic_cast<ModuleType*>(this);
}

template <typename ModuleType, typename>
const ModuleType* Module::as() const noexcept {
  return dynamic_cast<const ModuleType*>(this);
}

template <typename ModuleType>
std::shared_ptr<ModuleType> Module::register_module(
    std::string name,
    std::shared_ptr<ModuleType> module) {
  TORCH_CHECK(!name.empty(), "Submodule name must not be empty");
  TORCH_CHECK(
      name.find('.') == std::string::npos,
      "Submodule name must not contain a dot (got '",
      name,
      "')");
  auto& base_module = children_.insert(std::move(name), std::move(module));
  return std::dynamic_pointer_cast<ModuleType>(base_module);
}

template <typename ModuleType>
std::shared_ptr<ModuleType> Module::register_module(
    std::string name,
    ModuleHolder<ModuleType> module_holder) {
  return register_module(std::move(name), module_holder.ptr());
}

template <typename ModuleType>
std::shared_ptr<ModuleType> Module::replace_module(
    const std::string& name,
    std::shared_ptr<ModuleType> module) {
  auto& base_module = (children_[name] = std::move(module));
  return std::dynamic_pointer_cast<ModuleType>(base_module);
}

template <typename ModuleType>
std::shared_ptr<ModuleType> Module::replace_module(
    const std::string& name,
    ModuleHolder<ModuleType> module_holder) {
  return replace_module(name, module_holder.ptr());
}

template <typename... Ts>
void Module::to_impl(Ts&&... ts) {
  // First call `to()` on every child module.
  for (auto& child : children_) {
    child.value()->to(ts...);
  }
  // Then move every parameter to the new dtype/device.
  for (auto& parameter : named_parameters(/*recurse=*/false)) {
    parameter->set_data(autograd::Variable(*parameter).to(ts...));
  }
  // Then move every buffer to the new dtype/device.
  for (auto& buffer : named_buffers(/*recurse=*/false)) {
    buffer->set_data(autograd::Variable(*buffer).to(ts...));
  }
}

} // namespace nn
} // namespace torch