1 #pragma once 2 3 #include <torch/csrc/Export.h> 4 #include <torch/csrc/utils/variadic.h> 5 #include <torch/nn/cloneable.h> 6 #include <torch/nn/pimpl.h> 7 #include <torch/types.h> 8 9 #include <functional> 10 #include <utility> 11 12 namespace torch { 13 namespace nn { 14 15 /// Wraps a function in a `Module`. 16 /// 17 /// The `Functional` module allows wrapping an arbitrary function or function 18 /// object in an `nn::Module`. This is primarily handy for usage in 19 /// `Sequential`. 20 /// 21 /// \rst 22 /// .. code-block:: cpp 23 /// 24 /// Sequential sequential( 25 /// Linear(3, 4), 26 /// Functional(torch::relu), 27 /// BatchNorm1d(3), 28 /// Functional(torch::elu, /*alpha=*/1)); 29 /// \endrst 30 /// 31 /// While a `Functional` module only accepts a single `Tensor` as input, it is 32 /// possible for the wrapped function to accept further arguments. However, 33 /// these have to be bound *at construction time*. For example, if 34 /// you want to wrap `torch::leaky_relu`, which accepts a `slope` scalar as its 35 /// second argument, with a particular value for its `slope` in a `Functional` 36 /// module, you could write 37 /// 38 /// \rst 39 /// .. code-block:: cpp 40 /// 41 /// Functional(torch::leaky_relu, /*slope=*/0.5) 42 /// \endrst 43 /// 44 /// The value of `0.5` is then stored within the `Functional` object and 45 /// supplied to the function call at invocation time. Note that such bound 46 /// values are evaluated eagerly and stored a single time. See the documentation 47 /// of [std::bind](https://en.cppreference.com/w/cpp/utility/functional/bind) 48 /// for more information on the semantics of argument binding. 49 /// 50 /// \rst 51 /// .. attention:: 52 /// After passing any bound arguments, the function must accept a single 53 /// tensor and return a single tensor. 54 /// \endrst 55 /// 56 /// Note that `Functional` overloads the call operator (`operator()`) such that 57 /// you can invoke it with `my_func(...)`. 58 class TORCH_API FunctionalImpl : public torch::nn::Cloneable<FunctionalImpl> { 59 public: 60 using Function = std::function<Tensor(Tensor)>; 61 62 /// Constructs a `Functional` from a function object. 63 explicit FunctionalImpl(Function function); 64 65 template < 66 typename SomeFunction, 67 typename... Args, 68 typename = std::enable_if_t<(sizeof...(Args) > 0)>> FunctionalImpl(SomeFunction original_function,Args &&...args)69 explicit FunctionalImpl(SomeFunction original_function, Args&&... args) 70 // NOLINTNEXTLINE(modernize-avoid-bind) 71 : function_(std::bind( 72 original_function, 73 /*input=*/std::placeholders::_1, 74 std::forward<Args>(args)...)) { 75 // std::bind is normally evil, but (1) gcc is broken w.r.t. handling 76 // parameter pack expansion in lambdas and (2) moving parameter packs into 77 // a lambda only works with C++14, so std::bind is the more move-aware 78 // solution here. 79 } 80 81 void reset() override; 82 83 /// Pretty prints the `Functional` module into the given `stream`. 84 void pretty_print(std::ostream& stream) const override; 85 86 /// Forwards the `input` tensor to the underlying (bound) function object. 87 Tensor forward(Tensor input); 88 89 /// Calls forward(input). 90 Tensor operator()(Tensor input); 91 92 bool is_serializable() const override; 93 94 private: 95 Function function_; 96 }; 97 98 /// A `ModuleHolder` subclass for `FunctionalImpl`. 99 /// See the documentation for `FunctionalImpl` class to learn what methods it 100 /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's 101 /// module storage semantics. 102 TORCH_MODULE(Functional); 103 104 } // namespace nn 105 } // namespace torch 106