1 #pragma once 2 #include <torch/csrc/autograd/function_hook.h> 3 #include <functional> 4 #include <memory> 5 6 namespace torch::autograd { 7 8 using hooks_list = 9 std::vector<std::function<at::TensorBase(const at::TensorBase&)>>; 10 11 struct CppFunctionTensorPreHook : public FunctionPreHook { 12 CppFunctionTensorPreHook(std::shared_ptr<hooks_list> hooks, size_t value_idx); 13 variable_list operator()(const variable_list& values) override; 14 15 std::shared_ptr<hooks_list> hooks_; 16 size_t value_idx_; 17 }; 18 19 struct CppFunctionSingleTensorPreHook : public FunctionPreHook { 20 CppFunctionSingleTensorPreHook( 21 std::function<at::TensorBase(const at::TensorBase&)> hook, 22 size_t value_idx); 23 variable_list operator()(const variable_list& values) override; 24 25 std::function<at::TensorBase(const at::TensorBase&)> hook_; 26 size_t value_idx_; 27 }; 28 29 } // namespace torch::autograd 30