xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/utils/lambda_post_hook.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/autograd/function_hook.h>
4 #include <torch/csrc/dynamo/compiled_autograd.h>
5 
6 namespace torch::autograd::utils {
7 
8 // Turns lambda into a torch::autograd::FunctionPostHook.
9 class LambdaPostHook : public torch::autograd::FunctionPostHook {
10   using variable_list = std::vector<torch::autograd::Variable>;
11   using fn_type =
12       std::function<variable_list(const variable_list&, const variable_list&)>;
13   using compiled_fn_type = std::function<void(CompiledNodeArgs&)>;
14 
15  public:
16   // The lambda function takes as arguments the outputs and inputs of the
17   // autograd function and can modify the outputs of the autograd function by
18   // returning a new output if needed.
LambdaPostHook(fn_type fn)19   /* implicit */ LambdaPostHook(fn_type fn) : fn_(std::move(fn)) {}
20 
LambdaPostHook(fn_type fn,compiled_fn_type compiled_fn)21   LambdaPostHook(fn_type fn, compiled_fn_type compiled_fn)
22       : fn_(std::move(fn)), compiled_fn_(std::move(compiled_fn)) {}
23 
operator()24   variable_list operator()(
25       const variable_list& outputs,
26       const variable_list& inputs) override {
27     return fn_(outputs, inputs);
28   }
29 
compiled_args(CompiledNodeArgs & args)30   void compiled_args(CompiledNodeArgs& args) override {}
31 
32  protected:
33   std::function<variable_list(const variable_list&, const variable_list&)> fn_;
34   compiled_fn_type compiled_fn_{};
35 };
36 
37 } // namespace torch::autograd::utils
38