xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/common.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 /// This macro enables a module with default arguments in its forward method
4 /// to be used in a Sequential module.
5 ///
6 /// Example usage:
7 ///
8 /// Let's say we have a module declared like this:
9 /// ```
10 /// struct MImpl : torch::nn::Module {
11 ///  public:
12 ///   explicit MImpl(int value_) : value(value_) {}
13 ///   torch::Tensor forward(int a, int b = 2, double c = 3.0) {
14 ///     return torch::tensor(a + b + c);
15 ///   }
16 ///  private:
17 ///   int value;
18 /// };
19 /// TORCH_MODULE(M);
20 /// ```
21 ///
22 /// If we try to use it in a Sequential module and run forward:
23 /// ```
24 /// torch::nn::Sequential seq(M(1));
25 /// seq->forward(1);
26 /// ```
27 ///
28 /// We will receive the following error message:
29 /// ```
30 /// MImpl's forward() method expects 3 argument(s), but received 1.
31 /// If MImpl's forward() method has default arguments, please make sure
32 /// the forward() method is declared with a corresponding
33 /// `FORWARD_HAS_DEFAULT_ARGS` macro.
34 /// ```
35 ///
36 /// The right way to fix this error is to use the `FORWARD_HAS_DEFAULT_ARGS`
37 /// macro when declaring the module:
38 /// ```
39 /// struct MImpl : torch::nn::Module {
40 ///  public:
41 ///   explicit MImpl(int value_) : value(value_) {}
42 ///   torch::Tensor forward(int a, int b = 2, double c = 3.0) {
43 ///     return torch::tensor(a + b + c);
44 ///   }
45 ///  protected:
46 ///   /*
47 ///   NOTE: looking at the argument list of `forward`:
48 ///   `forward(int a, int b = 2, double c = 3.0)`
49 ///   we see the following default arguments:
50 ///   ----------------------------------------------------------------
51 ///   0-based index of default |         Default value of arg
52 ///   arg in forward arg list  |  (wrapped by `torch::nn::AnyValue()`)
53 ///   ----------------------------------------------------------------
54 ///               1            |       torch::nn::AnyValue(2)
55 ///               2            |       torch::nn::AnyValue(3.0)
56 ///   ----------------------------------------------------------------
57 ///   Thus we pass the following arguments to the `FORWARD_HAS_DEFAULT_ARGS`
58 ///   macro:
59 ///   */
60 ///   FORWARD_HAS_DEFAULT_ARGS({1, torch::nn::AnyValue(2)}, {2,
61 ///   torch::nn::AnyValue(3.0)})
62 ///  private:
63 ///   int value;
64 /// };
65 /// TORCH_MODULE(M);
66 /// ```
67 /// Now, running the following would work:
68 /// ```
69 /// torch::nn::Sequential seq(M(1));
70 /// seq->forward(1);  // This correctly populates the default arguments for
71 /// `MImpl::forward`
72 /// ```
// Implementation notes (comments kept outside the macro: a `//` comment on a
// line-continued macro line would comment out the trailing backslash):
// - `torch::nn::AnyModuleHolder` is befriended so the type-erasure machinery
//   can invoke the three private hooks below when dispatching `forward`
//   through `AnyModule`/`Sequential`.
// - `_forward_num_required_args` relies on the macro arguments listing the
//   *trailing* defaulted parameters in ascending index order, so the first
//   entry's index equals the number of required (non-defaulted) arguments.
// - `_forward_populate_default_args` appends the default value for every
//   argument the caller omitted. The guard uses `arg_info.first >= ret.size()`
//   rather than `> ret.size() - 1`: when every forward() argument is defaulted
//   the caller may legally pass zero arguments (the assert allows it), and
//   `ret.size() - 1` would wrap around on the unsigned size type, silently
//   skipping every default.
#define FORWARD_HAS_DEFAULT_ARGS(...)                                         \
  template <typename ModuleType, typename... ArgumentTypes>                   \
  friend struct torch::nn::AnyModuleHolder;                                   \
  bool _forward_has_default_args() override {                                 \
    return true;                                                              \
  }                                                                           \
  unsigned int _forward_num_required_args() override {                        \
    std::pair<unsigned int, torch::nn::AnyValue> args_info[] = {__VA_ARGS__}; \
    return args_info[0].first;                                                \
  }                                                                           \
  std::vector<torch::nn::AnyValue> _forward_populate_default_args(            \
      std::vector<torch::nn::AnyValue>&& arguments) override {                \
    std::pair<unsigned int, torch::nn::AnyValue> args_info[] = {__VA_ARGS__}; \
    unsigned int num_all_args = std::rbegin(args_info)->first + 1;            \
    TORCH_INTERNAL_ASSERT(                                                    \
        arguments.size() >= _forward_num_required_args() &&                   \
        arguments.size() <= num_all_args);                                    \
    std::vector<torch::nn::AnyValue> ret = std::move(arguments);              \
    ret.reserve(num_all_args);                                                \
    for (auto& arg_info : args_info) {                                        \
      if (arg_info.first >= ret.size())                                       \
        ret.emplace_back(std::move(arg_info.second));                         \
    }                                                                         \
    return ret;                                                               \
  }
98