// xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/python_hook.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/autograd/function_hook.h>
4 #include <torch/csrc/python_headers.h>
5 #include <torch/csrc/utils/object_ptr.h>
6 
7 namespace torch::dynamo::autograd {
8 class SwapSavedVariables;
9 } // namespace torch::dynamo::autograd
10 
11 namespace torch::autograd {
12 
13 struct PyFunctionTensorPreHook : public FunctionPreHook {
14   PyFunctionTensorPreHook(PyObject* dict, size_t value_idx);
15   ~PyFunctionTensorPreHook() override;
16   variable_list operator()(const variable_list& values) override;
17   void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
18   PyObject* dict;
19   size_t value_idx;
20 };
21 
22 struct PyFunctionPreHook : public FunctionPreHook {
23   PyFunctionPreHook(PyObject* dict);
24   ~PyFunctionPreHook() override;
25   variable_list operator()(const variable_list& values) override;
26   void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
27   PyObject* dict;
28 };
29 
30 struct PyFunctionPostHook : public FunctionPostHook {
31   PyFunctionPostHook(PyObject* dict);
32   ~PyFunctionPostHook() override;
33   variable_list operator()(
34       const variable_list& outputs,
35       const variable_list& inputs) override;
36   void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
37   PyObject* dict;
38 };
39 
40 // PyFunctionTensorPostAccGradHooks is a dictionary of PostAccumulateGradHooks,
41 // and it is understandable if you are confused by why it's a subclass. We are
42 // simply following the precedent of PyFunctionPreHook and PyFunctionPostHook
43 // above to easily enroll into existing infrastructure.
44 struct PyFunctionTensorPostAccGradHooks : public PostAccumulateGradHook {
45   PyFunctionTensorPostAccGradHooks(PyObject* dict);
46   ~PyFunctionTensorPostAccGradHooks() override;
47   void operator()(const Variable& tensor) override;
48   void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override;
49   void apply_with_saved(
50       Variable& tensor,
51       torch::dynamo::autograd::SwapSavedVariables& saved) override;
52   PyObject* dict;
53 };
54 
55 } // namespace torch::autograd
56