xref: /aosp_15_r20/external/pytorch/torch/xpu/_gpu_trace.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from typing import Callable
2
3from torch._utils import CallbackRegistry
4
5
6EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
7    "XPU event creation"
8)
9EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
10    "XPU event deletion"
11)
12EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
13    "XPU event record"
14)
15EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
16    "XPU event wait"
17)
18MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
19    "XPU memory allocation"
20)
21MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
22    "XPU memory deallocation"
23)
24StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
25    "XPU stream creation"
26)
27DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
28    "XPU device synchronization"
29)
30StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
31    "XPU stream synchronization"
32)
33EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
34    "XPU event synchronization"
35)
36
37
def register_callback_for_event_creation(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the "XPU event creation" callback registry.

    Args:
        cb: Callable taking a single ``int``. Its meaning is defined by the
            code that fires ``EventCreationCallbacks`` (presumably an event
            handle — confirm against the caller).
    """
    EventCreationCallbacks.add_callback(cb)
40
41
def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the "XPU event deletion" callback registry.

    Args:
        cb: Callable taking a single ``int``. Its meaning is defined by the
            code that fires ``EventDeletionCallbacks`` (presumably an event
            handle — confirm against the caller).
    """
    EventDeletionCallbacks.add_callback(cb)
44
45
def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None:
    """Add ``cb`` to the "XPU event record" callback registry.

    Args:
        cb: Callable taking two ``int`` arguments. Their meaning is defined
            by the code that fires ``EventRecordCallbacks`` (presumably an
            event handle and a stream handle — confirm against the caller).
    """
    EventRecordCallbacks.add_callback(cb)
48
49
def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None:
    """Add ``cb`` to the "XPU event wait" callback registry.

    Args:
        cb: Callable taking two ``int`` arguments. Their meaning is defined
            by the code that fires ``EventWaitCallbacks`` (presumably an
            event handle and a stream handle — confirm against the caller).
    """
    EventWaitCallbacks.add_callback(cb)
52
53
def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the "XPU memory allocation" callback registry.

    Args:
        cb: Callable taking a single ``int``. Its meaning is defined by the
            code that fires ``MemoryAllocationCallbacks`` (presumably the
            allocated pointer value — confirm against the caller).
    """
    MemoryAllocationCallbacks.add_callback(cb)
56
57
def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the "XPU memory deallocation" callback registry.

    Args:
        cb: Callable taking a single ``int``. Its meaning is defined by the
            code that fires ``MemoryDeallocationCallbacks`` (presumably the
            freed pointer value — confirm against the caller).
    """
    MemoryDeallocationCallbacks.add_callback(cb)
60
61
def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the "XPU stream creation" callback registry.

    Args:
        cb: Callable taking a single ``int``. Its meaning is defined by the
            code that fires ``StreamCreationCallbacks`` (presumably a stream
            handle — confirm against the caller).
    """
    StreamCreationCallbacks.add_callback(cb)
64
65
def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None:
    """Add ``cb`` to the "XPU device synchronization" callback registry.

    Args:
        cb: Zero-argument callable, matching the ``CallbackRegistry[[]]``
            annotation on ``DeviceSynchronizationCallbacks``.
    """
    DeviceSynchronizationCallbacks.add_callback(cb)
68
69
def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the "XPU stream synchronization" callback registry.

    Args:
        cb: Callable taking a single ``int``. Its meaning is defined by the
            code that fires ``StreamSynchronizationCallbacks`` (presumably a
            stream handle — confirm against the caller).
    """
    StreamSynchronizationCallbacks.add_callback(cb)
72
73
def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None:
    """Add ``cb`` to the "XPU event synchronization" callback registry.

    Args:
        cb: Callable taking a single ``int``. Its meaning is defined by the
            code that fires ``EventSynchronizationCallbacks`` (presumably an
            event handle — confirm against the caller).
    """
    EventSynchronizationCallbacks.add_callback(cb)
76