xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/container/moduledict.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/nn/cloneable.h>
4 #include <torch/nn/module.h>
5 #include <torch/ordered_dict.h>
6 #include <vector>
7 
8 namespace torch {
9 namespace nn {
10 
11 /// An OrderedDict of `Module`s that registers its elements by their `key`s.
12 ///
13 /// \rst
14 /// .. code-block:: cpp
15 ///
16 ///   torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
17 ///     {"linear", Linear(10, 3).ptr()},
18 ///     {"conv", Conv2d(1, 2, 3).ptr()},
19 ///     {"dropout", Dropout(0.5).ptr()},
20 ///   };
21 ///   torch::nn::ModuleDict dict1(ordereddict);
22 ///
23 ///   for (const auto &module : *dict1) {
24 ///     module->pretty_print(std::cout);
25 ///   }
26 ///
27 ///   std::vector<std::pair<std::string, std::shared_ptr<Module>>> list = {
28 ///     {"linear", Linear(10, 3).ptr()},
29 ///     {"conv", Conv2d(1, 2, 3).ptr()},
30 ///     {"dropout", Dropout(0.5).ptr()},
31 ///   };
32 ///   torch::nn::ModuleDict dict2(list);
33 ///
34 ///   for (const auto &module : *dict2) {
35 ///     module->pretty_print(std::cout);
36 ///   }
37 ///
38 /// \endrst
39 ///
40 /// Why should you use `ModuleDict` instead of a simple `map` or `OrderedDict`?
41 /// The value a `ModuleDict` provides over manually calling an ordered map of
42 /// modules is that it allows treating the whole container *as a single module*,
43 /// such that performing a transformation on the `ModuleDict` applies to each of
44 /// the modules it stores (which are each a registered submodule of the
45 /// `ModuleDict`). For example, calling `.to(torch::kCUDA)` on a `ModuleDict`
46 /// will move each module in the map to CUDA memory. For example:
47 ///
48 /// \rst
49 /// .. code-block:: cpp
50 ///
51 ///   torch::OrderedDict<std::string, std::shared_ptr<Module>> ordereddict = {
52 ///     {"linear", Linear(10, 3).ptr()},
53 ///     {"conv", Conv2d(1, 2, 3).ptr()},
54 ///     {"dropout", Dropout(0.5).ptr()},
55 ///   };
56 ///   torch::nn::ModuleDict dict(ordereddict);
57 ///
58 ///   // Convert all modules to CUDA.
59 ///   dict->to(torch::kCUDA);
60 ///
61 /// \endrst
62 ///
63 /// Finally, `ModuleDict` provides a lightweight container API, such as allowing
64 /// iteration over submodules, positional access, adding new modules from a
65 /// vector of key-module pairs or an `OrderedDict` or another `ModuleDict` after
66 /// construction via `update`.
67 class ModuleDictImpl : public Cloneable<ModuleDictImpl> {
68  public:
69   using Iterator =
70       torch::OrderedDict<std::string, std::shared_ptr<Module>>::Iterator;
71   using ConstIterator =
72       torch::OrderedDict<std::string, std::shared_ptr<Module>>::ConstIterator;
73 
74   ModuleDictImpl() = default;
75 
76   /// Constructs the `ModuleDict` from a list of string-Module pairs.
ModuleDictImpl(const std::vector<std::pair<std::string,std::shared_ptr<Module>>> & modules)77   explicit ModuleDictImpl(
78       const std::vector<std::pair<std::string, std::shared_ptr<Module>>>&
79           modules) {
80     update(modules);
81   }
82 
83   /// Constructs the `ModuleDict` from an `OrderedDict`.
ModuleDictImpl(const torch::OrderedDict<std::string,std::shared_ptr<Module>> & modules)84   explicit ModuleDictImpl(
85       const torch::OrderedDict<std::string, std::shared_ptr<Module>>& modules) {
86     update(modules);
87   }
88 
89   /// Return the items in the `ModuleDict`.
items()90   std::vector<std::pair<std::string, std::shared_ptr<Module>>> items() const {
91     return modules_.pairs();
92   }
93 
94   /// Return the keys in the `ModuleDict`.
keys()95   std::vector<std::string> keys() const {
96     return modules_.keys();
97   }
98 
99   /// Return the values in the `ModuleDict`.
values()100   std::vector<std::shared_ptr<Module>> values() const {
101     return modules_.values();
102   }
103 
104   /// Return an iterator to the start of `ModuleDict`.
begin()105   Iterator begin() {
106     return modules_.begin();
107   }
108 
109   /// Return a const iterator to the start of `ModuleDict`.
begin()110   ConstIterator begin() const {
111     return modules_.begin();
112   }
113 
114   /// Return an iterator to the end of `ModuleDict`.
end()115   Iterator end() {
116     return modules_.end();
117   }
118 
119   /// Return a const iterator to the end of `ModuleDict`.
end()120   ConstIterator end() const {
121     return modules_.end();
122   }
123 
124   /// Return the number of items currently stored in the `ModuleDict`.
size()125   size_t size() const noexcept {
126     return modules_.size();
127   }
128 
129   /// Return true if the `ModuleDict` is empty, otherwise return false.
empty()130   bool empty() const noexcept {
131     return modules_.is_empty();
132   }
133 
134   /// Check if the centain parameter with the key in the `ModuleDict`.
contains(const std::string & key)135   bool contains(const std::string& key) const noexcept {
136     return modules_.contains(key);
137   }
138 
139   /// Remove all items from the `ModuleDict`.
clear()140   void clear() {
141     // Not remove the registration of modules to make it consistent with python
142     // version.
143     modules_.clear();
144   }
145 
146   /// Special cloning function for `ModuleDict` because it does not use
147   /// `reset()`.
148   std::shared_ptr<Module> clone(
149       const std::optional<Device>& device = std::nullopt) const override {
150     auto clone = std::make_shared<ModuleDictImpl>();
151     for (const auto& module : modules_) {
152       clone->insert(module.key(), module.value()->clone(device));
153     }
154     return clone;
155   }
156 
157   /// `reset()` is empty for `ModuleDict`, since it does not have parameters of
158   /// its own.
reset()159   void reset() override {}
160 
161   /// Pretty prints the `ModuleDict` into the given `stream`.
pretty_print(std::ostream & stream)162   void pretty_print(std::ostream& stream) const override {
163     stream << "torch::nn::ModuleDict";
164   }
165 
166   /// Attempts to returns the `Module` associated with the given `key`. Throws
167   /// an exception if no such `key` is stored in the `ModuleDict`. Check
168   /// contains(key) before for a non-throwing way of access.
169   std::shared_ptr<Module> operator[](const std::string& key) const {
170     return modules_[key];
171   }
172 
173   /// Attempts to return the module at the given key as the requested type.
174   /// Throws an exception if no such `key` is stored in the `ModuleDict`.
175   /// Check contains(key) before for a non-throwing way of access.
176   template <typename T>
at(const std::string & key)177   T& at(const std::string& key) {
178     static_assert(
179         torch::detail::is_module<T>::value,
180         "Can only call ModuleList::at with an nn::Module type");
181     auto module = modules_[key]->as<T>();
182     TORCH_CHECK(
183         module,
184         "Unable to cast module[",
185         key,
186         "] to ",
187         c10::demangle(typeid(T).name()));
188     return *module;
189   }
190 
191   /// Attempts to return the module at the given key as the requested type.
192   /// Throws an exception if no such `key` is stored in the `ModuleDict`.
193   /// Check contains(key) before for a non-throwing way of access.
194   template <typename T>
at(const std::string & key)195   const T& at(const std::string& key) const {
196     static_assert(
197         torch::detail::is_module<T>::value,
198         "Can only call ModuleList::at with an nn::Module type");
199     const auto module = modules_[key]->as<T>();
200     TORCH_CHECK(
201         module,
202         "Unable to cast module[",
203         key,
204         "] to ",
205         c10::demangle(typeid(T).name()));
206     return *module;
207   }
208 
209   /// Removes and returns the `Module` associated with the given `key`.
210   /// Throws an exception if no such `key` is stored in the `ModuleDict`.
211   /// Check contains(key) before for a non-throwing way of access.
pop(const std::string & key)212   std::shared_ptr<Module> pop(const std::string& key) {
213     auto module = modules_[key];
214     modules_.erase(key);
215     // Not remove the registration of the module to make it consistent with
216     // python version.
217     return module;
218   }
219 
220   /// Updated the `ModuleDict` with a vector of key-module pairs.
update(const std::vector<std::pair<std::string,std::shared_ptr<Module>>> & modules)221   void update(
222       const std::vector<std::pair<std::string, std::shared_ptr<Module>>>&
223           modules) {
224     for (auto& item : modules) {
225       insert(item.first, item.second);
226     }
227   }
228 
229   /// Updated the `ModuleDict` with key-value pairs from `OrderedDict` or
230   /// `ModuleDict`.
231   template <typename Container>
update(const Container & container)232   void update(const Container& container) {
233     for (auto& item : container) {
234       insert(item.key(), item.value());
235     }
236   }
237 
238  private:
239   /// Private `OrderedDict` holding the key-Module pairs.
240   torch::OrderedDict<std::string, std::shared_ptr<Module>> modules_;
241 
242   /// Insert a key-module pair by overwriting existing keys,
243   /// and register or replace the `Module`.
insert(const std::string & key,std::shared_ptr<Module> module)244   void insert(const std::string& key, std::shared_ptr<Module> module) {
245     if (contains(key)) {
246       modules_[key] = std::move(module);
247       replace_module(key, modules_[key]);
248     } else {
249       modules_.insert(key, std::move(module));
250       register_module(key, modules_.back().value());
251     }
252   }
253 };
254 
255 /// A `ModuleHolder` subclass for `ModuleDictImpl`.
256 /// See the documentation for `ModuleDictImpl` class to learn what methods it
257 /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
258 /// module storage semantics.
259 TORCH_MODULE(ModuleDict);
260 
261 } // namespace nn
262 } // namespace torch
263