#pragma once

#include <torch/nn/module.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <c10/core/TensorOptions.h>
#include <c10/util/Exception.h>

#include <memory>
#include <utility>

namespace torch {
namespace nn {
/// The `clone()` method in the base `Module` class does not have knowledge of
/// the concrete runtime type of its subclasses. Therefore, `clone()` must
/// either be called from within the subclass, or from a base class that has
/// knowledge of the concrete type. `Cloneable` uses the CRTP to gain
/// knowledge of the subclass's static type and provide an implementation of
/// the `clone()` method. We do not want to use this pattern in the base
/// class, because then storing a module would always require templatizing it.
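///
/// A minimal usage sketch (`MyModule` is a hypothetical example, not part of
/// this header): all members with reference semantics are registered in
/// `reset()` rather than in the constructor, so that `clone()` can re-create
/// them on the copy.
///
///   struct MyModule : torch::nn::Cloneable<MyModule> {
///     MyModule() { reset(); }
///     void reset() override {
///       weight = register_parameter("weight", torch::randn({2, 2}));
///     }
///     torch::Tensor weight;
///   };
///
///   auto module = std::make_shared<MyModule>();
///   auto copy = module->clone(); // deep copy: `copy`'s weight is a new tensor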
template <typename Derived>
// NOLINTNEXTLINE(bugprone-exception-escape)
class Cloneable : public Module {
 public:
  using Module::Module;

  /// `reset()` must perform initialization of all members with reference
  /// semantics, most importantly parameters, buffers and submodules.
  virtual void reset() = 0;

  /// Performs a recursive "deep copy" of the `Module`, such that all
  /// parameters and submodules in the cloned module are different from those
  /// in the original module.
  std::shared_ptr<Module> clone(
      const std::optional<Device>& device = std::nullopt) const override {
    NoGradGuard no_grad;

    const auto& self = static_cast<const Derived&>(*this);
    auto copy = std::make_shared<Derived>(self);
    // Discard the copied state and re-create it via `reset()`, so that all
    // members with reference semantics are freshly registered on the copy.
    copy->parameters_.clear();
    copy->buffers_.clear();
    copy->children_.clear();
    copy->reset();
    TORCH_CHECK(
        copy->parameters_.size() == parameters_.size(),
        "The cloned module does not have the same number of "
        "parameters as the original module after calling reset(). "
        "Are you sure you called register_parameter() inside reset() "
        "and not the constructor?");
    for (const auto& parameter : named_parameters(/*recurse=*/false)) {
      auto& tensor = *parameter;
      // Moving to a different device already produces a copy; otherwise,
      // clone the tensor in place.
      auto data = device && tensor.device() != *device
          ? tensor.to(*device)
          : autograd::Variable(tensor).clone();
      copy->parameters_[parameter.key()].set_data(data);
    }
    TORCH_CHECK(
        copy->buffers_.size() == buffers_.size(),
        "The cloned module does not have the same number of "
        "buffers as the original module after calling reset(). "
        "Are you sure you called register_buffer() inside reset() "
        "and not the constructor?");
    for (const auto& buffer : named_buffers(/*recurse=*/false)) {
      auto& tensor = *buffer;
      auto data = device && tensor.device() != *device
          ? tensor.to(*device)
          : autograd::Variable(tensor).clone();
      copy->buffers_[buffer.key()].set_data(data);
    }
    TORCH_CHECK(
        copy->children_.size() == children_.size(),
        "The cloned module does not have the same number of "
        "child modules as the original module after calling reset(). "
        "Are you sure you called register_module() inside reset() "
        "and not the constructor?");
    for (const auto& child : children_) {
      copy->children_[child.key()]->clone_(*child.value(), device);
    }
    return copy;
  }

 private:
  void clone_(Module& other, const std::optional<Device>& device) final {
    // Here we are *pretty* certain that `other`'s type is `Derived` (because
    // it was registered under the same name as `this`), but you never know
    // what crazy things `reset()` does, so `dynamic_cast` just to be safe.
    auto clone = std::dynamic_pointer_cast<Derived>(other.clone(device));
    TORCH_CHECK(
        clone != nullptr,
        "Attempted to clone submodule, but it is of a "
        "different type than the submodule it was to be cloned into");
    static_cast<Derived&>(*this) = *clone;
  }
};

} // namespace nn
} // namespace torch
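// A hedged usage sketch (assumes the hypothetical `MyModule` from the class
// documentation above and a CUDA-enabled build): passing a device to `clone()`
// places the cloned parameters and buffers on that device instead of
// deep-copying them on the original device.
//
//   auto module = std::make_shared<MyModule>();
//   auto gpu_copy = module->clone(torch::Device(torch::kCUDA, 0));
//   // `gpu_copy`'s tensors live on CUDA:0; `module`'s are left untouched.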