1 #pragma once 2 3 #include <torch/csrc/autograd/function_hook.h> 4 #include <torch/csrc/python_headers.h> 5 #include <torch/csrc/utils/object_ptr.h> 6 7 namespace torch::dynamo::autograd { 8 class SwapSavedVariables; 9 } // namespace torch::dynamo::autograd 10 11 namespace torch::autograd { 12 13 struct PyFunctionTensorPreHook : public FunctionPreHook { 14 PyFunctionTensorPreHook(PyObject* dict, size_t value_idx); 15 ~PyFunctionTensorPreHook() override; 16 variable_list operator()(const variable_list& values) override; 17 void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override; 18 PyObject* dict; 19 size_t value_idx; 20 }; 21 22 struct PyFunctionPreHook : public FunctionPreHook { 23 PyFunctionPreHook(PyObject* dict); 24 ~PyFunctionPreHook() override; 25 variable_list operator()(const variable_list& values) override; 26 void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override; 27 PyObject* dict; 28 }; 29 30 struct PyFunctionPostHook : public FunctionPostHook { 31 PyFunctionPostHook(PyObject* dict); 32 ~PyFunctionPostHook() override; 33 variable_list operator()( 34 const variable_list& outputs, 35 const variable_list& inputs) override; 36 void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override; 37 PyObject* dict; 38 }; 39 40 // PyFunctionTensorPostAccGradHooks is a dictionary of PostAccumulateGradHooks, 41 // and it is understandable if you are confused by why it's a subclass. We are 42 // simply following the precedent of PyFunctionPreHook and PyFunctionPostHook 43 // above to easily enroll into existing infrastructure. 44 struct PyFunctionTensorPostAccGradHooks : public PostAccumulateGradHook { 45 PyFunctionTensorPostAccGradHooks(PyObject* dict); 46 ~PyFunctionTensorPostAccGradHooks() override; 47 void operator()(const Variable& tensor) override; 48 void compiled_args(torch::dynamo::autograd::CompiledNodeArgs& args) override; 49 void apply_with_saved( 50 Variable& tensor, 51 torch::dynamo::autograd::SwapSavedVariables& saved) override; 52 PyObject* dict; 53 }; 54 55 } // namespace torch::autograd 56