#pragma once

#include <torch/detail/static.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/container/any_module_holder.h>
#include <torch/nn/modules/container/any_value.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>

#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/variadic.h>

#include <ATen/Device.h>

#include <memory>
#include <optional>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>

namespace torch {
namespace nn {

/// Stores a type erased `Module`.
///
/// The PyTorch C++ API does not impose an interface on the signature of
/// `forward()` in `Module` subclasses. This gives you complete freedom to
/// design your `forward()` methods to your liking. However, this also means
/// there is no unified base type you could store in order to call `forward()`
/// polymorphically for any module. This is where the `AnyModule` comes in.
/// Instead of inheritance, it relies on type erasure for polymorphism.
///
/// An `AnyModule` can store any `nn::Module` subclass that provides a
/// `forward()` method. This `forward()` may accept any types and return any
/// type. Once stored in an `AnyModule`, you can invoke the underlying module's
/// `forward()` by calling `AnyModule::forward()` with the arguments you would
/// supply to the stored module (though see one important limitation below).
/// Example:
///
/// \rst
/// .. code-block:: cpp
///
///   struct GenericTrainer {
///     torch::nn::AnyModule module;
///
///     void train(torch::Tensor input) {
///       module.forward(input);
///     }
///   };
///
///   GenericTrainer trainer1{torch::nn::Linear(3, 4)};
///   GenericTrainer trainer2{torch::nn::Conv2d(3, 4, 2)};
/// \endrst
///
/// As `AnyModule` erases the static type of the stored module (and its
/// `forward()` method) to achieve polymorphism, type checking of arguments is
/// moved to runtime. That is, passing an argument with an incorrect type to an
/// `AnyModule` will compile, but throw an exception at runtime:
///
/// \rst
/// .. code-block:: cpp
///
///   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
///   // Linear takes a tensor as input, but we are passing an integer.
///   // This will compile, but throw a `torch::Error` exception at runtime.
///   module.forward(123);
/// \endrst
///
/// \rst
/// .. attention::
///   One noteworthy limitation of `AnyModule` is that its `forward()` method
///   does not support implicit conversion of argument types. For example, if
///   the stored module's `forward()` method accepts a `float` and you call
///   `any_module.forward(3.4)` (where `3.4` is a `double`), this will throw
///   an exception.
/// \endrst
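///
/// For illustration, a minimal sketch of how this limitation surfaces
/// (`FloatModule` here is a hypothetical module type, not part of the API):
///
/// \rst
/// .. code-block:: cpp
///
///   struct FloatModule : torch::nn::Module {
///     float forward(float x) { return x * 2; }
///   };
///
///   torch::nn::AnyModule any(FloatModule{});
///   float doubled = any.forward<float>(3.4f); // OK: argument is a `float`.
///   // any.forward<float>(3.4);               // Throws: `3.4` is a `double`,
///   //                                        // no implicit conversion occurs.
/// \endrst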
///
/// The return type of the `AnyModule`'s `forward()` method is controlled via
/// the first template argument to `AnyModule::forward()`. It defaults to
/// `torch::Tensor`. To change it, you can write `any_module.forward<int>()`,
/// for example.
///
/// \rst
/// .. code-block:: cpp
///
///   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
///   auto output = module.forward(torch::ones({2, 3}));
///
///   struct IntModule : torch::nn::Module {
///     int forward(int x) { return x; }
///   };
///   torch::nn::AnyModule int_module(IntModule{});
///   int result = int_module.forward<int>(5);
/// \endrst
///
/// The only other method an `AnyModule` provides access to on the stored
/// module is `clone()`. However, you may acquire a handle on the module via
/// `.ptr()`, which returns a `shared_ptr<nn::Module>`. Further, if you know
/// the concrete type of the stored module, you can get a concrete handle to it
/// using `.get<T>()` where `T` is the concrete module type.
///
/// \rst
/// .. code-block:: cpp
///
///   torch::nn::AnyModule module(torch::nn::Linear(3, 4));
///   std::shared_ptr<nn::Module> ptr = module.ptr();
///   torch::nn::Linear linear(module.get<torch::nn::Linear>());
/// \endrst
class AnyModule {
 public:
  /// A default-constructed `AnyModule` is in an empty state.
  AnyModule() = default;

  /// Constructs an `AnyModule` from a `shared_ptr` to a concrete module object.
  template <typename ModuleType>
  explicit AnyModule(std::shared_ptr<ModuleType> module);

  /// Constructs an `AnyModule` from a concrete module object.
  template <
      typename ModuleType,
      typename = torch::detail::enable_if_module_t<ModuleType>>
  explicit AnyModule(ModuleType&& module);

  /// Constructs an `AnyModule` from a module holder.
  template <typename ModuleType>
  explicit AnyModule(const ModuleHolder<ModuleType>& module_holder);

  /// Move construction and assignment are allowed, and follow the default
  /// move behavior of `std::unique_ptr`.
  AnyModule(AnyModule&&) = default;
  AnyModule& operator=(AnyModule&&) = default;

  /// Creates a shallow copy of an `AnyModule`; the copy shares the underlying
  /// module with the original.
  AnyModule(const AnyModule& other);
  AnyModule& operator=(const AnyModule& other);

  /// Creates a deep copy of the `AnyModule` if it contains a module, else
  /// returns an empty `AnyModule`.
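  ///
  /// For illustration, a small sketch contrasting copy (shallow) with
  /// `clone()` (deep); the variable names are illustrative only:
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   torch::nn::AnyModule original(torch::nn::Linear(3, 4));
  ///   torch::nn::AnyModule shallow = original;      // shares the stored module
  ///   torch::nn::AnyModule deep = original.clone(); // deep-copies parameters
  ///   // shallow.ptr() == original.ptr(), but deep.ptr() != original.ptr().
  /// \endrst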
  AnyModule clone(std::optional<Device> device = std::nullopt) const;

  /// Assigns a module to the `AnyModule` (to circumvent the explicit
  /// constructor).
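  ///
  /// A minimal usage sketch (assuming `torch::nn::LinearImpl` as the concrete
  /// module type; any `nn::Module` subclass with a suitable `forward()` works):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   torch::nn::AnyModule any;
  ///   any = std::make_shared<torch::nn::LinearImpl>(3, 4);
  /// \endrst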
  template <typename ModuleType>
  AnyModule& operator=(std::shared_ptr<ModuleType> module);

  /// Invokes `forward()` on the contained module with the given arguments, and
  /// returns the return value as an `AnyValue`. Use this method when chaining
  /// `AnyModule`s in a loop.
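  ///
  /// A minimal chaining sketch (assumes each stored module's `forward()`
  /// accepts the previous module's return value):
  ///
  /// \rst
  /// .. code-block:: cpp
  ///
  ///   std::vector<torch::nn::AnyModule> modules = {
  ///       torch::nn::AnyModule(torch::nn::Linear(3, 4)),
  ///       torch::nn::AnyModule(torch::nn::ReLU())};
  ///
  ///   torch::nn::AnyValue value(torch::ones({2, 3}));
  ///   for (auto& module : modules) {
  ///     value = module.any_forward(std::move(value));
  ///   }
  ///   auto output = value.get<torch::Tensor>();
  /// \endrst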
  template <typename... ArgumentTypes>
  AnyValue any_forward(ArgumentTypes&&... arguments);

  /// Invokes `forward()` on the contained module with the given arguments, and
  /// casts the returned `AnyValue` to the supplied `ReturnType` (which defaults
  /// to `torch::Tensor`).
  template <typename ReturnType = torch::Tensor, typename... ArgumentTypes>
  ReturnType forward(ArgumentTypes&&... arguments);

  /// Attempts to cast the underlying module to the given module type. Throws an
  /// exception if the types do not match.
  template <typename T, typename = torch::detail::enable_if_module_t<T>>
  T& get();

  /// Attempts to cast the underlying module to the given module type. Throws an
  /// exception if the types do not match.
  template <typename T, typename = torch::detail::enable_if_module_t<T>>
  const T& get() const;

  /// Returns the contained module in a `nn::ModuleHolder` subclass if possible
  /// (i.e. if `T` has a constructor for the underlying module type).
  template <typename T, typename ContainedType = typename T::ContainedType>
  T get() const;

  /// Returns a `std::shared_ptr` whose dynamic type is that of the underlying
  /// module.
  std::shared_ptr<Module> ptr() const;

  /// Like `ptr()`, but casts the pointer to the given type.
  template <typename T, typename = torch::detail::enable_if_module_t<T>>
  std::shared_ptr<T> ptr() const;

  /// Returns the `type_info` object of the contained value.
  const std::type_info& type_info() const;

  /// Returns true if the `AnyModule` does not contain a module.
  bool is_empty() const noexcept;

 private:
  /// Creates a `unique_ptr<AnyModulePlaceholder>` pointing to an
  /// `AnyModuleHolder` of the correct type. This method is used to deduce the
  /// arguments of the module's `forward()` method.
  template <
      typename ModuleType,
      typename Class,
      typename ReturnType,
      typename... ArgumentTypes>
  std::unique_ptr<AnyModulePlaceholder> make_holder(
      std::shared_ptr<ModuleType>&& module,
      ReturnType (Class::*)(ArgumentTypes...));

  /// Helper method invoked by const and non-const `get()`.
  template <typename ModuleType, typename ReturnType, typename... ArgumentTypes>
  ModuleType& get_(ReturnType (ModuleType::*)(ArgumentTypes...)) const;

  /// Helper method invoked by const and non-const `get()`.
  template <typename ModuleType>
  ModuleType& get_() const;

  /// The type erased module.
  std::unique_ptr<AnyModulePlaceholder> content_;
};

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <typename ModuleType>
AnyModule::AnyModule(std::shared_ptr<ModuleType> module)
    : content_(make_holder(
          std::move(module),
          &std::remove_reference<ModuleType>::type::forward)) {
  // `AnyModule` can only store an `nn::Module` subclass object that provides
  // a `forward()` method that has a non-templatized return type.
  // (e.g. `AnyModule` cannot store `nn::Sequential`, because `nn::Sequential`'s
  // `forward()` method has a templatized return type.)
  static_assert(
      torch::detail::is_module<ModuleType>::value,
      "Can only store object derived from nn::Module into AnyModule");
  static_assert(
      torch::detail::has_forward<ModuleType>::value,
      "Can only store module with a forward() method that has a non-templatized"
      " argument type and return type into AnyModule (e.g. we cannot store"
      " nn::Sequential into AnyModule, because its forward() method's argument type"
      " and return type are templatized."
      " If you need to use nn::Sequentials inside each other you can subclass "
      "nn::Sequential and write a non-templatized forward function for it. You can check out "
      "https://github.com/pytorch/vision/blob/2f46070f3cb1ea894d82578f3dc5677f82f34958/torchvision/csrc/models/mnasnet.cpp#L59 "
      "for an example on how to do this.).");
}

template <typename ModuleType, typename>
AnyModule::AnyModule(ModuleType&& module)
    : AnyModule(
          std::make_shared<ModuleType>(std::forward<ModuleType>(module))) {}

template <typename ModuleType>
AnyModule::AnyModule(const ModuleHolder<ModuleType>& module_holder)
    : AnyModule(module_holder.ptr()) {}

inline AnyModule::AnyModule(const AnyModule& other)
    : content_(other.content_ ? other.content_->copy() : nullptr) {}

inline AnyModule& AnyModule::operator=(const AnyModule& other) {
  if (this != &other) {
    content_ = other.content_ ? other.content_->copy() : nullptr;
  }
  return *this;
}

inline AnyModule AnyModule::clone(std::optional<Device> device) const {
  AnyModule clone;
  clone.content_ = content_ ? content_->clone_module(device) : nullptr;
  return clone;
}

template <typename ModuleType>
AnyModule& AnyModule::operator=(std::shared_ptr<ModuleType> module) {
  // NOLINTNEXTLINE(cppcoreguidelines-c-copy-assignment-signature)
  return (*this = AnyModule(std::move(module)));
}

template <typename... ArgumentTypes>
AnyValue AnyModule::any_forward(ArgumentTypes&&... arguments) {
  TORCH_CHECK(!is_empty(), "Cannot call forward() on an empty AnyModule");
  std::vector<AnyValue> values;
  values.reserve(sizeof...(ArgumentTypes));
  torch::apply(
      [&values](AnyValue&& value) { values.push_back(std::move(value)); },
      AnyValue(std::forward<ArgumentTypes>(arguments))...);
  return content_->forward(std::move(values));
}

template <typename ReturnType, typename... ArgumentTypes>
ReturnType AnyModule::forward(ArgumentTypes&&... arguments) {
  return any_forward(std::forward<ArgumentTypes>(arguments)...)
      .template get<ReturnType>();
}

template <typename T, typename>
T& AnyModule::get() {
  TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
  return get_<T>();
}

template <typename T, typename>
const T& AnyModule::get() const {
  TORCH_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
  return get_<T>();
}

template <typename T, typename ContainedType>
T AnyModule::get() const {
  return T(ptr<ContainedType>());
}

inline std::shared_ptr<Module> AnyModule::ptr() const {
  TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
  return content_->ptr();
}

template <typename T, typename>
std::shared_ptr<T> AnyModule::ptr() const {
  TORCH_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
  // Call get() but discard the value, just to do the type checking.
  get_<T>();
  return std::dynamic_pointer_cast<T>(ptr());
}

inline const std::type_info& AnyModule::type_info() const {
  TORCH_CHECK(!is_empty(), "Cannot call type_info() on an empty AnyModule");
  return content_->type_info;
}

inline bool AnyModule::is_empty() const noexcept {
  return content_ == nullptr;
}

// Private Methods

template <
    typename ModuleType,
    typename Class,
    typename ReturnType,
    typename... ArgumentTypes>
std::unique_ptr<AnyModulePlaceholder> AnyModule::make_holder(
    std::shared_ptr<ModuleType>&& module,
    ReturnType (Class::*)(ArgumentTypes...)) {
  static_assert(
      torch::detail::check_not_lvalue_references<ArgumentTypes...>(),
      "Modules stored inside AnyModule must not take references. "
      "Use pointers instead.");
  static_assert(
      !std::is_void<ReturnType>::value,
      "AnyModule cannot store modules that return void "
      "(you can return a dummy value).");
  return std::make_unique<
      AnyModuleHolder<std::decay_t<ModuleType>, ArgumentTypes...>>(
      std::move(module));
}

template <typename ModuleType>
ModuleType& AnyModule::get_() const {
  using M = typename std::remove_reference<ModuleType>::type;
  static_assert(
      torch::detail::has_forward<M>::value,
      "Can only call AnyModule::get<T> with a type T that has a forward method");
  return get_(&M::forward);
}

template <typename ModuleType, typename ReturnType, typename... ArgumentTypes>
ModuleType& AnyModule::get_(
    ReturnType (ModuleType::*)(ArgumentTypes...)) const {
  if (typeid(ModuleType).hash_code() == type_info().hash_code()) {
    return *static_cast<AnyModuleHolder<ModuleType, ArgumentTypes...>&>(
                *content_)
                .module;
  }
  AT_ERROR(
      "Attempted to cast module of type ",
      c10::demangle(type_info().name()),
      " to type ",
      c10::demangle(typeid(ModuleType).name()));
}

} // namespace nn
} // namespace torch