xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/cpp_hook.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
#include <torch/csrc/autograd/function_hook.h>
#include <functional>
#include <memory>
#include <vector>
5 
6 namespace torch::autograd {
7 
8 using hooks_list =
9     std::vector<std::function<at::TensorBase(const at::TensorBase&)>>;
10 
11 struct CppFunctionTensorPreHook : public FunctionPreHook {
12   CppFunctionTensorPreHook(std::shared_ptr<hooks_list> hooks, size_t value_idx);
13   variable_list operator()(const variable_list& values) override;
14 
15   std::shared_ptr<hooks_list> hooks_;
16   size_t value_idx_;
17 };
18 
19 struct CppFunctionSingleTensorPreHook : public FunctionPreHook {
20   CppFunctionSingleTensorPreHook(
21       std::function<at::TensorBase(const at::TensorBase&)> hook,
22       size_t value_idx);
23   variable_list operator()(const variable_list& values) override;
24 
25   std::function<at::TensorBase(const at::TensorBase&)> hook_;
26   size_t value_idx_;
27 };
28 
29 } // namespace torch::autograd
30