xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/cloneable.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <torch/nn/module.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>

#include <memory>
#include <optional>
#include <utility>
12 
13 namespace torch {
14 namespace nn {
15 /// The `clone()` method in the base `Module` class does not have knowledge of
16 /// the concrete runtime type of its subclasses. Therefore, `clone()` must
17 /// either be called from within the subclass, or from a base class that has
18 /// knowledge of the concrete type. `Cloneable` uses the CRTP to gain
19 /// knowledge of the subclass' static type and provide an implementation of the
20 /// `clone()` method. We do not want to use this pattern in the base class,
21 /// because then storing a module would always require templatizing it.
22 template <typename Derived>
23 // NOLINTNEXTLINE(bugprone-exception-escape)
24 class Cloneable : public Module {
25  public:
26   using Module::Module;
27 
28   /// `reset()` must perform initialization of all members with reference
29   /// semantics, most importantly parameters, buffers and submodules.
30   virtual void reset() = 0;
31 
32   /// Performs a recursive "deep copy" of the `Module`, such that all parameters
33   /// and submodules in the cloned module are different from those in the
34   /// original module.
35   std::shared_ptr<Module> clone(
36       const std::optional<Device>& device = std::nullopt) const override {
37     NoGradGuard no_grad;
38 
39     const auto& self = static_cast<const Derived&>(*this);
40     auto copy = std::make_shared<Derived>(self);
41     copy->parameters_.clear();
42     copy->buffers_.clear();
43     copy->children_.clear();
44     copy->reset();
45     TORCH_CHECK(
46         copy->parameters_.size() == parameters_.size(),
47         "The cloned module does not have the same number of "
48         "parameters as the original module after calling reset(). "
49         "Are you sure you called register_parameter() inside reset() "
50         "and not the constructor?");
51     for (const auto& parameter : named_parameters(/*recurse=*/false)) {
52       auto& tensor = *parameter;
53       auto data = device && tensor.device() != *device
54           ? tensor.to(*device)
55           : autograd::Variable(tensor).clone();
56       copy->parameters_[parameter.key()].set_data(data);
57     }
58     TORCH_CHECK(
59         copy->buffers_.size() == buffers_.size(),
60         "The cloned module does not have the same number of "
61         "buffers as the original module after calling reset(). "
62         "Are you sure you called register_buffer() inside reset() "
63         "and not the constructor?");
64     for (const auto& buffer : named_buffers(/*recurse=*/false)) {
65       auto& tensor = *buffer;
66       auto data = device && tensor.device() != *device
67           ? tensor.to(*device)
68           : autograd::Variable(tensor).clone();
69       copy->buffers_[buffer.key()].set_data(data);
70     }
71     TORCH_CHECK(
72         copy->children_.size() == children_.size(),
73         "The cloned module does not have the same number of "
74         "child modules as the original module after calling reset(). "
75         "Are you sure you called register_module() inside reset() "
76         "and not the constructor?");
77     for (const auto& child : children_) {
78       copy->children_[child.key()]->clone_(*child.value(), device);
79     }
80     return copy;
81   }
82 
83  private:
clone_(Module & other,const std::optional<Device> & device)84   void clone_(Module& other, const std::optional<Device>& device) final {
85     // Here we are *pretty* certain that `other's` type is `Derived` (because it
86     // was registered under the same name as `this`), but you never know what
87     // crazy things `reset()` does, so `dynamic_cast` just to be safe.
88     auto clone = std::dynamic_pointer_cast<Derived>(other.clone(device));
89     TORCH_CHECK(
90         clone != nullptr,
91         "Attempted to clone submodule, but it is of a "
92         "different type than the submodule it was to be cloned into");
93     static_cast<Derived&>(*this) = *clone;
94   }
95 };
96 
97 } // namespace nn
98 } // namespace torch
99