1 #pragma once 2 3 #include <torch/nn/cloneable.h> 4 #include <torch/nn/module.h> 5 #include <torch/ordered_dict.h> 6 #include <vector> 7 8 namespace torch { 9 namespace nn { 10 11 /// An OrderedDict of `Module`s that registers its elements by their `key`s. 12 /// 13 /// \rst 14 /// .. code-block:: cpp 15 /// 16 /// torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { 17 /// {"linear", Linear(10, 3).ptr()}, 18 /// {"conv", Conv2d(1, 2, 3).ptr()}, 19 /// {"dropout", Dropout(0.5).ptr()}, 20 /// }; 21 /// torch::nn::ModuleDict dict1(ordereddict); 22 /// 23 /// for (const auto &module : *dict1) { 24 /// module->pretty_print(std::cout); 25 /// } 26 /// 27 /// std::vector<std::pair<std::string, std::shared_ptr<Module>>> list = { 28 /// {"linear", Linear(10, 3).ptr()}, 29 /// {"conv", Conv2d(1, 2, 3).ptr()}, 30 /// {"dropout", Dropout(0.5).ptr()}, 31 /// }; 32 /// torch::nn::ModuleDict dict2(list); 33 /// 34 /// for (const auto &module : *dict2) { 35 /// module->pretty_print(std::cout); 36 /// } 37 /// 38 /// \endrst 39 /// 40 /// Why should you use `ModuleDict` instead of a simple `map` or `OrderedDict`? 41 /// The value a `ModuleDict` provides over manually calling an ordered map of 42 /// modules is that it allows treating the whole container *as a single module*, 43 /// such that performing a transformation on the `ModuleDict` applies to each of 44 /// the modules it stores (which are each a registered submodule of the 45 /// `ModuleDict`). For example, calling `.to(torch::kCUDA)` on a `ModuleDict` 46 /// will move each module in the map to CUDA memory. For example: 47 /// 48 /// \rst 49 /// .. code-block:: cpp 50 /// 51 /// torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = { 52 /// {"linear", Linear(10, 3).ptr()}, 53 /// {"conv", Conv2d(1, 2, 3).ptr()}, 54 /// {"dropout", Dropout(0.5).ptr()}, 55 /// }; 56 /// torch::nn::ModuleDict dict(ordereddict); 57 /// 58 /// // Convert all modules to CUDA. 59 /// dict->to(torch::kCUDA); 60 /// 61 /// \endrst 62 /// 63 /// Finally, `ModuleDict` provides a lightweight container API, such as allowing 64 /// iteration over submodules, positional access, adding new modules from a 65 /// vector of key-module pairs or an `OrderedDict` or another `ModuleDict` after 66 /// construction via `update`. 67 class ModuleDictImpl : public Cloneable<ModuleDictImpl> { 68 public: 69 using Iterator = 70 torch::OrderedDict<std::string, std::shared_ptr<Module>>::Iterator; 71 using ConstIterator = 72 torch::OrderedDict<std::string, std::shared_ptr<Module>>::ConstIterator; 73 74 ModuleDictImpl() = default; 75 76 /// Constructs the `ModuleDict` from a list of string-Module pairs. ModuleDictImpl(const std::vector<std::pair<std::string,std::shared_ptr<Module>>> & modules)77 explicit ModuleDictImpl( 78 const std::vector<std::pair<std::string, std::shared_ptr<Module>>>& 79 modules) { 80 update(modules); 81 } 82 83 /// Constructs the `ModuleDict` from an `OrderedDict`. ModuleDictImpl(const torch::OrderedDict<std::string,std::shared_ptr<Module>> & modules)84 explicit ModuleDictImpl( 85 const torch::OrderedDict<std::string, std::shared_ptr<Module>>& modules) { 86 update(modules); 87 } 88 89 /// Return the items in the `ModuleDict`. items()90 std::vector<std::pair<std::string, std::shared_ptr<Module>>> items() const { 91 return modules_.pairs(); 92 } 93 94 /// Return the keys in the `ModuleDict`. keys()95 std::vector<std::string> keys() const { 96 return modules_.keys(); 97 } 98 99 /// Return the values in the `ModuleDict`. values()100 std::vector<std::shared_ptr<Module>> values() const { 101 return modules_.values(); 102 } 103 104 /// Return an iterator to the start of `ModuleDict`. begin()105 Iterator begin() { 106 return modules_.begin(); 107 } 108 109 /// Return a const iterator to the start of `ModuleDict`. begin()110 ConstIterator begin() const { 111 return modules_.begin(); 112 } 113 114 /// Return an iterator to the end of `ModuleDict`. end()115 Iterator end() { 116 return modules_.end(); 117 } 118 119 /// Return a const iterator to the end of `ModuleDict`. end()120 ConstIterator end() const { 121 return modules_.end(); 122 } 123 124 /// Return the number of items currently stored in the `ModuleDict`. size()125 size_t size() const noexcept { 126 return modules_.size(); 127 } 128 129 /// Return true if the `ModuleDict` is empty, otherwise return false. empty()130 bool empty() const noexcept { 131 return modules_.is_empty(); 132 } 133 134 /// Check if the centain parameter with the key in the `ModuleDict`. contains(const std::string & key)135 bool contains(const std::string& key) const noexcept { 136 return modules_.contains(key); 137 } 138 139 /// Remove all items from the `ModuleDict`. clear()140 void clear() { 141 // Not remove the registration of modules to make it consistent with python 142 // version. 143 modules_.clear(); 144 } 145 146 /// Special cloning function for `ModuleDict` because it does not use 147 /// `reset()`. 148 std::shared_ptr<Module> clone( 149 const std::optional<Device>& device = std::nullopt) const override { 150 auto clone = std::make_shared<ModuleDictImpl>(); 151 for (const auto& module : modules_) { 152 clone->insert(module.key(), module.value()->clone(device)); 153 } 154 return clone; 155 } 156 157 /// `reset()` is empty for `ModuleDict`, since it does not have parameters of 158 /// its own. reset()159 void reset() override {} 160 161 /// Pretty prints the `ModuleDict` into the given `stream`. pretty_print(std::ostream & stream)162 void pretty_print(std::ostream& stream) const override { 163 stream << "torch::nn::ModuleDict"; 164 } 165 166 /// Attempts to returns the `Module` associated with the given `key`. Throws 167 /// an exception if no such `key` is stored in the `ModuleDict`. Check 168 /// contains(key) before for a non-throwing way of access. 169 std::shared_ptr<Module> operator[](const std::string& key) const { 170 return modules_[key]; 171 } 172 173 /// Attempts to return the module at the given key as the requested type. 174 /// Throws an exception if no such `key` is stored in the `ModuleDict`. 175 /// Check contains(key) before for a non-throwing way of access. 176 template <typename T> at(const std::string & key)177 T& at(const std::string& key) { 178 static_assert( 179 torch::detail::is_module<T>::value, 180 "Can only call ModuleList::at with an nn::Module type"); 181 auto module = modules_[key]->as<T>(); 182 TORCH_CHECK( 183 module, 184 "Unable to cast module[", 185 key, 186 "] to ", 187 c10::demangle(typeid(T).name())); 188 return *module; 189 } 190 191 /// Attempts to return the module at the given key as the requested type. 192 /// Throws an exception if no such `key` is stored in the `ModuleDict`. 193 /// Check contains(key) before for a non-throwing way of access. 194 template <typename T> at(const std::string & key)195 const T& at(const std::string& key) const { 196 static_assert( 197 torch::detail::is_module<T>::value, 198 "Can only call ModuleList::at with an nn::Module type"); 199 const auto module = modules_[key]->as<T>(); 200 TORCH_CHECK( 201 module, 202 "Unable to cast module[", 203 key, 204 "] to ", 205 c10::demangle(typeid(T).name())); 206 return *module; 207 } 208 209 /// Removes and returns the `Module` associated with the given `key`. 210 /// Throws an exception if no such `key` is stored in the `ModuleDict`. 211 /// Check contains(key) before for a non-throwing way of access. pop(const std::string & key)212 std::shared_ptr<Module> pop(const std::string& key) { 213 auto module = modules_[key]; 214 modules_.erase(key); 215 // Not remove the registration of the module to make it consistent with 216 // python version. 217 return module; 218 } 219 220 /// Updated the `ModuleDict` with a vector of key-module pairs. update(const std::vector<std::pair<std::string,std::shared_ptr<Module>>> & modules)221 void update( 222 const std::vector<std::pair<std::string, std::shared_ptr<Module>>>& 223 modules) { 224 for (auto& item : modules) { 225 insert(item.first, item.second); 226 } 227 } 228 229 /// Updated the `ModuleDict` with key-value pairs from `OrderedDict` or 230 /// `ModuleDict`. 231 template <typename Container> update(const Container & container)232 void update(const Container& container) { 233 for (auto& item : container) { 234 insert(item.key(), item.value()); 235 } 236 } 237 238 private: 239 /// Private `OrderedDict` holding the key-Module pairs. 240 torch::OrderedDict<std::string, std::shared_ptr<Module>> modules_; 241 242 /// Insert a key-module pair by overwriting existing keys, 243 /// and register or replace the `Module`. insert(const std::string & key,std::shared_ptr<Module> module)244 void insert(const std::string& key, std::shared_ptr<Module> module) { 245 if (contains(key)) { 246 modules_[key] = std::move(module); 247 replace_module(key, modules_[key]); 248 } else { 249 modules_.insert(key, std::move(module)); 250 register_module(key, modules_.back().value()); 251 } 252 } 253 }; 254 255 /// A `ModuleHolder` subclass for `ModuleDictImpl`. 256 /// See the documentation for `ModuleDictImpl` class to learn what methods it 257 /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's 258 /// module storage semantics. 259 TORCH_MODULE(ModuleDict); 260 261 } // namespace nn 262 } // namespace torch 263