1 #pragma once 2 3 #include <torch/csrc/autograd/function_hook.h> 4 #include <torch/csrc/dynamo/compiled_autograd.h> 5 6 namespace torch::autograd::utils { 7 8 // Turns lambda into a torch::autograd::FunctionPostHook. 9 class LambdaPostHook : public torch::autograd::FunctionPostHook { 10 using variable_list = std::vector<torch::autograd::Variable>; 11 using fn_type = 12 std::function<variable_list(const variable_list&, const variable_list&)>; 13 using compiled_fn_type = std::function<void(CompiledNodeArgs&)>; 14 15 public: 16 // The lambda function takes as arguments the outputs and inputs of the 17 // autograd function and can modify the outputs of the autograd function by 18 // returning a new output if needed. LambdaPostHook(fn_type fn)19 /* implicit */ LambdaPostHook(fn_type fn) : fn_(std::move(fn)) {} 20 LambdaPostHook(fn_type fn,compiled_fn_type compiled_fn)21 LambdaPostHook(fn_type fn, compiled_fn_type compiled_fn) 22 : fn_(std::move(fn)), compiled_fn_(std::move(compiled_fn)) {} 23 operator()24 variable_list operator()( 25 const variable_list& outputs, 26 const variable_list& inputs) override { 27 return fn_(outputs, inputs); 28 } 29 compiled_args(CompiledNodeArgs & args)30 void compiled_args(CompiledNodeArgs& args) override {} 31 32 protected: 33 std::function<variable_list(const variable_list&, const variable_list&)> fn_; 34 compiled_fn_type compiled_fn_{}; 35 }; 36 37 } // namespace torch::autograd::utils 38