# xref: /aosp_15_r20/external/pytorch/torch/_dynamo/callback.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
class CompilationCallbackHandler:
    """
    Registry of zero-argument callbacks meant to run when a compilation
    starts and when it ends.

    Start and end callbacks are kept in two separate lists and are invoked in
    registration order by ``run_start_callbacks`` / ``run_end_callbacks``.
    """

    def __init__(self):
        # Stored in registration order; registering the same callable twice
        # makes it run twice per invocation.
        self.start_callbacks = []
        self.end_callbacks = []

    def register_start_callback(self, callback):
        """
        Register a callback function to be called when the compilation starts.

        Args:
            callback (callable): Function taking no arguments.

        Returns:
            The same ``callback`` object, unchanged.
        """
        self.start_callbacks.append(callback)
        return callback

    def register_end_callback(self, callback):
        """
        Register a callback function to be called when the compilation ends.

        Args:
            callback (callable): Function taking no arguments.

        Returns:
            The same ``callback`` object, unchanged.
        """
        self.end_callbacks.append(callback)
        return callback

    def remove_start_callback(self, callback):
        """
        Remove a registered start callback function.

        Only the first matching registration is removed.

        Args:
            callback (callable): The callback function to remove.

        Raises:
            ValueError: If ``callback`` is not currently registered.
        """
        self.start_callbacks.remove(callback)

    def remove_end_callback(self, callback):
        """
        Remove a registered end callback function.

        Only the first matching registration is removed.

        Args:
            callback (callable): The callback function to remove.

        Raises:
            ValueError: If ``callback`` is not currently registered.
        """
        self.end_callbacks.remove(callback)

    def run_start_callbacks(self):
        """
        Execute all registered start callbacks, in registration order.

        A callback may register or remove start callbacks (including itself)
        while running. Such changes take effect on the next run and never
        cause a sibling callback to be skipped. An exception raised by a
        callback propagates, and the remaining callbacks are not invoked.
        """
        # Iterate over a snapshot: mutating a list while a ``for`` loop walks
        # it shifts indices and silently skips the element after a removal.
        for callback in tuple(self.start_callbacks):
            callback()

    def run_end_callbacks(self):
        """
        Execute all registered end callbacks, in registration order.

        A callback may register or remove end callbacks (including itself)
        while running. Such changes take effect on the next run and never
        cause a sibling callback to be skipped. An exception raised by a
        callback propagates, and the remaining callbacks are not invoked.
        """
        # Snapshot for the same reason as in ``run_start_callbacks``.
        for callback in tuple(self.end_callbacks):
            callback()

    def clear(self):
        """
        Clear all registered start and end callbacks.
        """
        self.start_callbacks.clear()
        self.end_callbacks.clear()


# Module-level handler used by the ``on_compile_start`` / ``on_compile_end``
# decorators in this module.
callback_handler = CompilationCallbackHandler()
68
69
def on_compile_start(callback):
    """
    Decorator that registers ``callback`` to run when compilation starts.

    The function is added to the module-level ``callback_handler`` and is
    handed back unchanged, so the decorated name still refers to it.
    """
    # register_start_callback returns its argument, so pass it straight through.
    return callback_handler.register_start_callback(callback)
76
77
def on_compile_end(callback):
    """
    Decorator that registers ``callback`` to run when compilation ends.

    The function is added to the module-level ``callback_handler`` and is
    handed back unchanged, so the decorated name still refers to it.
    """
    # register_end_callback returns its argument, so pass it straight through.
    return callback_handler.register_end_callback(callback)
84