#pragma once

#include <iterator>
#include <utility>
#include <vector>

/// This macro enables a module with default arguments in its forward method
/// to be used in a Sequential module.
///
/// Example usage:
///
/// Let's say we have a module declared like this:
/// ```
/// struct MImpl : torch::nn::Module {
///  public:
///   explicit MImpl(int value_) : value(value_) {}
///   torch::Tensor forward(int a, int b = 2, double c = 3.0) {
///     return torch::tensor(a + b + c);
///   }
///  private:
///   int value;
/// };
/// TORCH_MODULE(M);
/// ```
///
/// If we try to use it in a Sequential module and run forward:
/// ```
/// torch::nn::Sequential seq(M(1));
/// seq->forward(1);
/// ```
///
/// We will receive the following error message:
/// ```
/// MImpl's forward() method expects 3 argument(s), but received 1.
/// If MImpl's forward() method has default arguments, please make sure
/// the forward() method is declared with a corresponding
/// `FORWARD_HAS_DEFAULT_ARGS` macro.
/// ```
///
/// The right way to fix this error is to use the `FORWARD_HAS_DEFAULT_ARGS`
/// macro when declaring the module:
/// ```
/// struct MImpl : torch::nn::Module {
///  public:
///   explicit MImpl(int value_) : value(value_) {}
///   torch::Tensor forward(int a, int b = 2, double c = 3.0) {
///     return torch::tensor(a + b + c);
///   }
///  protected:
///   /*
///   NOTE: looking at the argument list of `forward`:
///     `forward(int a, int b = 2, double c = 3.0)`
///   we see the following default arguments:
///   ----------------------------------------------------------------
///   0-based index of default |  Default value of arg
///   arg in forward arg list  |  (wrapped by `torch::nn::AnyValue()`)
///   ----------------------------------------------------------------
///              1             |  torch::nn::AnyValue(2)
///              2             |  torch::nn::AnyValue(3.0)
///   ----------------------------------------------------------------
///   Thus we pass the following arguments to the `FORWARD_HAS_DEFAULT_ARGS`
///   macro:
///   */
///   FORWARD_HAS_DEFAULT_ARGS({1, torch::nn::AnyValue(2)},
///                            {2, torch::nn::AnyValue(3.0)})
///  private:
///   int value;
/// };
/// TORCH_MODULE(M);
/// ```
///
/// Now, running the following would work:
/// ```
/// torch::nn::Sequential seq(M(1));
/// seq->forward(1);  // This correctly populates the default arguments
///                   // for `MImpl::forward`.
/// ```
///
/// Requirements on the macro arguments:
///   - Each argument is a `{index, torch::nn::AnyValue(default_value)}` pair,
///     where `index` is the 0-based position of a defaulted parameter of
///     `forward`.
///   - The pairs must list *every* defaulted parameter, in ascending order of
///     `index`, with no gaps. Because C++ only permits default arguments on a
///     trailing run of parameters, the last pair therefore names the last
///     parameter of `forward`. The expansion derives the total parameter
///     count from the last pair and the required-parameter count from the
///     first pair, so an out-of-order or incomplete list is a usage error
///     (caught at runtime by an internal assert).
///   - At least one pair is required; an empty list would declare a
///     zero-length array in the expansion and fail to compile.
///
/// The macro expands (inside the module's class body) to a friend
/// declaration granting `torch::nn::AnyModuleHolder` access, plus three
/// `override`s of virtual hooks declared on a base class elsewhere:
///   - `_forward_has_default_args()` returns `true`.
///   - `_forward_num_required_args()` returns the index of the first
///     defaulted parameter, i.e. how many leading parameters have no default.
///   - `_forward_populate_default_args(arguments)` takes the caller-supplied
///     arguments and returns them extended with the default value of every
///     trailing parameter the caller did not supply.
///
/// NOTE: comments inside the macro body use `/* */` because line splicing
/// (the trailing `\`) happens before comment removal, so a `//` comment
/// would swallow every following line of the macro.
#define FORWARD_HAS_DEFAULT_ARGS(...)                                         \
  template <typename ModuleType, typename... ArgumentTypes>                   \
  friend struct torch::nn::AnyModuleHolder;                                   \
  bool _forward_has_default_args() override {                                 \
    return true;                                                              \
  }                                                                           \
  unsigned int _forward_num_required_args() override {                        \
    /* The first defaulted index equals the number of required args. */       \
    std::pair<unsigned int, torch::nn::AnyValue> args_info[] = {__VA_ARGS__}; \
    return args_info[0].first;                                                \
  }                                                                           \
  std::vector<torch::nn::AnyValue> _forward_populate_default_args(            \
      std::vector<torch::nn::AnyValue>&& arguments) override {                \
    /* Rebuilt on every call so each default can be moved out below. */       \
    std::pair<unsigned int, torch::nn::AnyValue> args_info[] = {__VA_ARGS__}; \
    unsigned int num_all_args = std::rbegin(args_info)->first + 1;            \
    TORCH_INTERNAL_ASSERT(                                                    \
        arguments.size() >= _forward_num_required_args() &&                   \
        arguments.size() <= num_all_args);                                    \
    std::vector<torch::nn::AnyValue> ret = std::move(arguments);              \
    ret.reserve(num_all_args);                                                \
    for (auto& arg_info : args_info) {                                        \
      /* Append defaults for parameters the caller did not supply. Use */     \
      /* `>=` rather than `> ret.size() - 1`: the latter wraps around to */   \
      /* SIZE_MAX when `ret` is empty and then appends nothing at all. */     \
      if (arg_info.first >= ret.size()) {                                     \
        /* Holds only if the indices are ascending and contiguous. */         \
        TORCH_INTERNAL_ASSERT(                                                \
            arg_info.first == ret.size(),                                     \
            "FORWARD_HAS_DEFAULT_ARGS: default-argument indices must be "     \
            "listed in ascending order with no gaps");                        \
        ret.emplace_back(std::move(arg_info.second));                         \
      }                                                                       \
    }                                                                         \
    return ret;                                                               \
  }