xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/container/functional.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <torch/csrc/utils/variadic.h>
5 #include <torch/nn/cloneable.h>
6 #include <torch/nn/pimpl.h>
7 #include <torch/types.h>
8 
9 #include <functional>
10 #include <utility>
11 
12 namespace torch {
13 namespace nn {
14 
15 /// Wraps a function in a `Module`.
16 ///
17 /// The `Functional` module allows wrapping an arbitrary function or function
18 /// object in an `nn::Module`. This is primarily handy for usage in
19 /// `Sequential`.
20 ///
21 /// \rst
22 /// .. code-block:: cpp
23 ///
24 ///   Sequential sequential(
25 ///     Linear(3, 4),
26 ///     Functional(torch::relu),
27 ///     BatchNorm1d(3),
28 ///     Functional(torch::elu, /*alpha=*/1));
29 /// \endrst
30 ///
31 /// While a `Functional` module only accepts a single `Tensor` as input, it is
32 /// possible for the wrapped function to accept further arguments. However,
33 /// these have to be bound *at construction time*. For example, if
34 /// you want to wrap `torch::leaky_relu`, which accepts a `slope` scalar as its
35 /// second argument, with a particular value for its `slope` in a `Functional`
36 /// module, you could write
37 ///
38 /// \rst
39 /// .. code-block:: cpp
40 ///
41 ///   Functional(torch::leaky_relu, /*slope=*/0.5)
42 /// \endrst
43 ///
44 /// The value of `0.5` is then stored within the `Functional` object and
45 /// supplied to the function call at invocation time. Note that such bound
46 /// values are evaluated eagerly and stored a single time. See the documentation
47 /// of [std::bind](https://en.cppreference.com/w/cpp/utility/functional/bind)
48 /// for more information on the semantics of argument binding.
49 ///
50 /// \rst
51 /// .. attention::
52 ///   After passing any bound arguments, the function must accept a single
53 ///   tensor and return a single tensor.
54 /// \endrst
55 ///
56 /// Note that `Functional` overloads the call operator (`operator()`) such that
57 /// you can invoke it with `my_func(...)`.
58 class TORCH_API FunctionalImpl : public torch::nn::Cloneable<FunctionalImpl> {
59  public:
60   using Function = std::function<Tensor(Tensor)>;
61 
62   /// Constructs a `Functional` from a function object.
63   explicit FunctionalImpl(Function function);
64 
65   template <
66       typename SomeFunction,
67       typename... Args,
68       typename = std::enable_if_t<(sizeof...(Args) > 0)>>
FunctionalImpl(SomeFunction original_function,Args &&...args)69   explicit FunctionalImpl(SomeFunction original_function, Args&&... args)
70       // NOLINTNEXTLINE(modernize-avoid-bind)
71       : function_(std::bind(
72             original_function,
73             /*input=*/std::placeholders::_1,
74             std::forward<Args>(args)...)) {
75     // std::bind is normally evil, but (1) gcc is broken w.r.t. handling
76     // parameter pack expansion in lambdas and (2) moving parameter packs into
77     // a lambda only works with C++14, so std::bind is the more move-aware
78     // solution here.
79   }
80 
81   void reset() override;
82 
83   /// Pretty prints the `Functional` module into the given `stream`.
84   void pretty_print(std::ostream& stream) const override;
85 
86   /// Forwards the `input` tensor to the underlying (bound) function object.
87   Tensor forward(Tensor input);
88 
89   /// Calls forward(input).
90   Tensor operator()(Tensor input);
91 
92   bool is_serializable() const override;
93 
94  private:
95   Function function_;
96 };
97 
98 /// A `ModuleHolder` subclass for `FunctionalImpl`.
99 /// See the documentation for `FunctionalImpl` class to learn what methods it
100 /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
101 /// module storage semantics.
102 TORCH_MODULE(Functional);
103 
104 } // namespace nn
105 } // namespace torch
106