#pragma once

#include <torch/detail/static.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/container/any_module_holder.h>
#include <torch/nn/modules/container/any_value.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/variadic.h>

#include <ATen/Device.h>

#include <memory>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>

namespace torch {
namespace nn {

/// Stores a type erased `Module`.
///
/// The PyTorch C++ API does not impose an interface on the signature of
/// `forward()` in `Module` subclasses. This gives you complete freedom to
/// design your `forward()` methods to your liking. However, this also means
/// there is no unified base type you could store in order to call `forward()`
/// polymorphically for any module. This is where the `AnyModule` comes in.
/// Instead of inheritance, it relies on type erasure for polymorphism.
///
/// An `AnyModule` can store any `nn::Module` subclass that provides a
/// `forward()` method. This `forward()` may accept any types and return any
/// type. Once stored in an `AnyModule`, you can invoke the underlying module's
/// `forward()` by calling `AnyModule::forward()` with the arguments you would
/// supply to the stored module (though see one important limitation below).
/// Example:
///
/// \rst
/// .. code-block:: cpp
///
///   struct GenericTrainer {
///     torch::nn::AnyModule module;
///
///     void train(torch::Tensor input) {
///       module.forward(input);
///     }
///   };
///
///   GenericTrainer trainer1{torch::nn::Linear(3, 4)};
///   GenericTrainer trainer2{torch::nn::Conv2d(3, 4, 2)};
/// \endrst
///
/// As `AnyModule` erases the static type of the stored module (and its
/// `forward()` method) to achieve polymorphism, type checking of arguments is
/// moved to runtime. That is, passing an argument with an incorrect type to an
/// `AnyModule` will compile, but throw an exception at runtime:
///
/// \rst
/// .. code-block:: cpp
///
///   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
///   // Linear takes a tensor as input, but we are passing an integer.
///   // This will compile, but throw a `torch::Error` exception at runtime.
///   module.forward(123);
/// \endrst
///
/// \rst
/// .. attention::
///   One noteworthy limitation of `AnyModule` is that its `forward()` method
///   does not support implicit conversion of argument types. For example, if
///   the stored module's `forward()` method accepts a `float` and you call
///   `any_module.forward(3.4)` (where `3.4` is a `double`), this will throw
///   an exception.
/// \endrst
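///
/// For instance (a minimal illustration; `M` is a hypothetical module type):
///
/// \rst
/// .. code-block:: cpp
///
///   struct M : torch::nn::Module {
///     float forward(float x) { return x; }
///   };
///   torch::nn::AnyModule module(M{});
///   module.forward<float>(3.4f); // OK: the argument is a `float`.
///   module.forward<float>(3.4);  // Throws: `3.4` is a `double`.
/// \endrst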
///
/// The return type of the `AnyModule`'s `forward()` method is controlled via
/// the first template argument to `AnyModule::forward()`. It defaults to
/// `torch::Tensor`. To change it, you can write `any_module.forward<int>()`,
/// for example.
///
/// \rst
/// .. code-block:: cpp
///
///   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
///   auto output = module.forward(torch::ones({2, 3}));
///
///   struct IntModule : torch::nn::Module {
///     int forward(int x) { return x; }
///   };
///   torch::nn::AnyModule int_module(IntModule{});
///   int int_output = int_module.forward<int>(5);
/// \endrst
///
/// The only other method an `AnyModule` provides access to on the stored
/// module is `clone()`. However, you may acquire a handle on the module via
/// `.ptr()`, which returns a `shared_ptr<nn::Module>`. Further, if you know
/// the concrete type of the stored module, you can get a concrete handle to it
/// using `.get<T>()` where `T` is the concrete module type.
///
/// \rst
/// .. code-block:: cpp
///
///   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
///   std::shared_ptr<torch::nn::Module> ptr = module.ptr();
///   torch::nn::Linear linear(module.get<torch::nn::Linear>());
/// \endrst
class AnyModule {
 public:
  /// A default-constructed `AnyModule` is in an empty state.
  AnyModule() = default;

  /// Constructs an `AnyModule` from a `shared_ptr` to a concrete module
  /// object.
  template <typename ModuleType>
  explicit AnyModule(std::shared_ptr<ModuleType> module);

  /// Constructs an `AnyModule` from a concrete module object.
  template <
      typename ModuleType,
      typename = torch::detail::enable_if_module_t<ModuleType>>
  explicit AnyModule(ModuleType&& module);

  /// Constructs an `AnyModule` from a module holder.
  template <typename ModuleType>
  explicit AnyModule(const ModuleHolder<ModuleType>& module_holder);

  /// Move construction and assignment are allowed, and follow the default
  /// behavior of move for `std::unique_ptr`.
  AnyModule(AnyModule&&) = default;
  AnyModule& operator=(AnyModule&&) = default;

  /// Creates a shallow copy of an `AnyModule`.
  AnyModule(const AnyModule& other);
  AnyModule& operator=(const AnyModule& other);

  /// Creates a deep copy of the `AnyModule` if it contains a module, else
  /// returns an empty `AnyModule`.
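  ///
  /// A minimal sketch of the difference from the copy constructor: copying an
  /// `AnyModule` shares the underlying module, while `clone()` duplicates it
  /// (optionally onto the given device).
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
  ///   torch::nn::AnyModule shallow = module;       // shares the stored module
  ///   torch::nn::AnyModule deep = module.clone();  // owns an independent copy
  /// \endrst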
  AnyModule clone(std::optional<Device> device = std::nullopt) const;

  /// Assigns a module to the `AnyModule` (to circumvent the explicit
  /// constructor).
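  ///
  /// A minimal illustration of assigning into an existing (here empty)
  /// `AnyModule`:
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   torch::nn::AnyModule module;
  ///   module = std::make_shared<torch::nn::LinearImpl>(3, 4);
  /// \endrst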
  template <typename ModuleType>
  AnyModule& operator=(std::shared_ptr<ModuleType> module);

  /// Invokes `forward()` on the contained module with the given arguments, and
  /// returns the return value as an `AnyValue`. Use this method when chaining
  /// `AnyModule`s in a loop.
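  ///
  /// A minimal sketch of such a chain (the particular modules are
  /// illustrative):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   std::vector<torch::nn::AnyModule> modules = {
  ///       torch::nn::AnyModule(torch::nn::Linear(3, 4)),
  ///       torch::nn::AnyModule(torch::nn::ReLU())};
  ///   torch::nn::AnyValue value(torch::ones({2, 3}));
  ///   for (auto& module : modules) {
  ///     // Feed each module's output to the next module's input.
  ///     value = module.any_forward(std::move(value));
  ///   }
  ///   auto output = value.get<torch::Tensor>();
  /// \endrst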
  template <typename... ArgumentTypes>
  AnyValue any_forward(ArgumentTypes&&... arguments);

  /// Invokes `forward()` on the contained module with the given arguments, and
  /// casts the returned `AnyValue` to the supplied `ReturnType` (which defaults
  /// to `torch::Tensor`).
  template <typename ReturnType = torch::Tensor, typename... ArgumentTypes>
  ReturnType forward(ArgumentTypes&&... arguments);

  /// Attempts to cast the underlying module to the given module type. Throws an
  /// exception if the types do not match.
  template <typename T, typename = torch::detail::enable_if_module_t<T>>
  T& get();

  /// Attempts to cast the underlying module to the given module type. Throws an
  /// exception if the types do not match.
  template <typename T, typename = torch::detail::enable_if_module_t<T>>
  const T& get() const;

  /// Returns the contained module in a `nn::ModuleHolder` subclass if possible
  /// (i.e. if `T` has a constructor for the underlying module type).
  template <typename T, typename ContainedType = typename T::ContainedType>
  T get() const;

  /// Returns a `std::shared_ptr` whose dynamic type is that of the underlying
  /// module.
  std::shared_ptr<Module> ptr() const;

  /// Like `ptr()`, but casts the pointer to the given type.
  template <typename T, typename = torch::detail::enable_if_module_t<T>>
  std::shared_ptr<T> ptr() const;

  /// Returns the `type_info` object of the contained module.
  const std::type_info& type_info() const;

  /// Returns true if the `AnyModule` does not contain a module.
  bool is_empty() const noexcept;

 private:
  /// Creates a `unique_ptr<AnyModulePlaceholder>` pointing to an
  /// `AnyModuleHolder` of the correct type. This method is used to deduce the
  /// arguments of the module's `forward()` method.
  template <
      typename ModuleType,
      typename Class,
      typename ReturnType,
      typename... ArgumentTypes>
  std::unique_ptr<AnyModulePlaceholder> make_holder(
      std::shared_ptr<ModuleType>&& module,
      ReturnType (Class::*)(ArgumentTypes...));

  /// Helper method invoked by const and non-const `get()`.
  template <typename ModuleType, typename ReturnType, typename... ArgumentTypes>
  ModuleType& get_(ReturnType (ModuleType::*)(ArgumentTypes...)) const;

  /// Helper method invoked by const and non-const `get()`.
  template <typename ModuleType>
  ModuleType& get_() const;

  /// The type erased module.
  std::unique_ptr<AnyModulePlaceholder> content_;
};

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename ModuleType>
AnyModule::AnyModule(std::shared_ptr<ModuleType> module)
    : content_(make_holder(
          std::move(module),
          &std::remove_reference<ModuleType>::type::forward)) {
  // `AnyModule` can only store an `nn::Module` subclass object that provides
  // a `forward()` method that has a non-templatized return type.
  // (e.g. `AnyModule` cannot store `nn::Sequential`, because `nn::Sequential`'s
  // `forward()` method has a templatized return type.)
  static_assert(
      torch::detail::is_module<ModuleType>::value,
      "Can only store object derived from nn::Module into AnyModule");
  static_assert(
      torch::detail::has_forward<ModuleType>::value,
      "Can only store module with a forward() method that has a non-templatized"
      " argument type and return type into AnyModule (e.g. we cannot store nn::Sequential"
      " into AnyModule, because its forward() method's argument type and return type are templatized."
      " If you need to use nn::Sequentials inside each other you can subclass "
      "nn::Sequential and write a non-templatized forward function for it. You can check out "
      "https://github.com/pytorch/vision/blob/2f46070f3cb1ea894d82578f3dc5677f82f34958/torchvision/csrc/models/mnasnet.cpp#L59 "
      "for an example of how to do this.)");
}
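
// A minimal sketch of the workaround mentioned in the static_assert message
// above (the `MySequentialImpl` / `MySequential` names are illustrative, not
// part of the API): derive from `nn::SequentialImpl` and expose a `forward()`
// with a fixed signature, so the resulting module can be stored in an
// `AnyModule`.
//
//   struct MySequentialImpl : torch::nn::SequentialImpl {
//     using torch::nn::SequentialImpl::SequentialImpl;
//     torch::Tensor forward(torch::Tensor x) {
//       return torch::nn::SequentialImpl::forward(x);
//     }
//   };
//   TORCH_MODULE(MySequential);
//
//   torch::nn::AnyModule any(
//       MySequential(torch::nn::Linear(3, 4), torch::nn::ReLU()));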

template <typename ModuleType, typename>
AnyModule::AnyModule(ModuleType&& module)
    : AnyModule(
          std::make_shared<ModuleType>(std::forward<ModuleType>(module))) {}

template <typename ModuleType>
AnyModule::AnyModule(const ModuleHolder<ModuleType>& module_holder)
    : AnyModule(module_holder.ptr()) {}

inline AnyModule::AnyModule(const AnyModule& other)
    : content_(other.content_ ? other.content_->copy() : nullptr) {}

inline AnyModule& AnyModule::operator=(const AnyModule& other) {
  if (this != &other) {
    content_ = other.content_ ? other.content_->copy() : nullptr;
  }
  return *this;
}

inline AnyModule AnyModule::clone(std::optional<Device> device) const {
  AnyModule clone;
  clone.content_ = content_ ? content_->clone_module(device) : nullptr;
  return clone;
}

template <typename ModuleType>
AnyModule& AnyModule::operator=(std::shared_ptr<ModuleType> module) {
  // NOLINTNEXTLINE(cppcoreguidelines-c-copy-assignment-signature)
  return (*this = AnyModule(std::move(module)));
}

template <typename... ArgumentTypes>
AnyValue AnyModule::any_forward(ArgumentTypes&&... arguments) {
  TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty AnyModule");
  std::vector<AnyValue> values;
  values.reserve(sizeof...(ArgumentTypes));
  torch::apply(
      [&values](AnyValue&& value) { values.push_back(std::move(value)); },
      AnyValue(std::forward<ArgumentTypes>(arguments))...);
  return content_->forward(std::move(values));
}

template <typename ReturnType, typename... ArgumentTypes>
ReturnType AnyModule::forward(ArgumentTypes&&... arguments) {
  return any_forward(std::forward<ArgumentTypes>(arguments)...)
      .template get<ReturnType>();
}

template <typename T, typename>
T& AnyModule::get() {
  TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
  return get_<T>();
}

template <typename T, typename>
const T& AnyModule::get() const {
  TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
  return get_<T>();
}

template <typename T, typename ContainedType>
T AnyModule::get() const {
  return T(ptr<ContainedType>());
}

inline std::shared_ptr<Module> AnyModule::ptr() const {
  TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
  return content_->ptr();
}

template <typename T, typename>
std::shared_ptr<T> AnyModule::ptr() const {
  TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
  // Call get() but discard the value, just to do the type checking.
  get_<T>();
  return std::dynamic_pointer_cast<T>(ptr());
}

inline const std::type_info& AnyModule::type_info() const {
  TORCH_CHECK(!is_empty(), "Cannot call type_info() on an empty AnyModule");
  return content_->type_info;
}

inline bool AnyModule::is_empty() const noexcept {
  return content_ == nullptr;
}

// Private Methods

template <
    typename ModuleType,
    typename Class,
    typename ReturnType,
    typename... ArgumentTypes>
std::unique_ptr<AnyModulePlaceholder> AnyModule::make_holder(
    std::shared_ptr<ModuleType>&& module,
    ReturnType (Class::*)(ArgumentTypes...)) {
  static_assert(
      torch::detail::check_not_lvalue_references<ArgumentTypes...>(),
      "Modules stored inside AnyModule must not take references. "
      "Use pointers instead.");
  static_assert(
      !std::is_void<ReturnType>::value,
      "AnyModule cannot store modules that return void "
      "(you can return a dummy value).");
  return std::make_unique<
      AnyModuleHolder<std::decay_t<ModuleType>, ArgumentTypes...>>(
      std::move(module));
}

template <typename ModuleType>
ModuleType& AnyModule::get_() const {
  using M = typename std::remove_reference<ModuleType>::type;
  static_assert(
      torch::detail::has_forward<M>::value,
      "Can only call AnyModule::get<T> with a type T that has a forward method");
  return get_(&M::forward);
}

template <typename ModuleType, typename ReturnType, typename... ArgumentTypes>
ModuleType& AnyModule::get_(
    ReturnType (ModuleType::*)(ArgumentTypes...)) const {
  if (typeid(ModuleType).hash_code() == type_info().hash_code()) {
    return *static_cast<AnyModuleHolder<ModuleType, ArgumentTypes...>&>(
                *content_)
                .module;
  }
  AT_ERROR(
      "Attempted to cast module of type ",
      c10::demangle(type_info().name()),
      " to type ",
      c10::demangle(typeid(ModuleType).name()));
}

} // namespace nn
} // namespace torch