1from typing import Callable 2 3from torch._utils import CallbackRegistry 4 5 6EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( 7 "XPU event creation" 8) 9EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry( 10 "XPU event deletion" 11) 12EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry( 13 "XPU event record" 14) 15EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry( 16 "XPU event wait" 17) 18MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( 19 "XPU memory allocation" 20) 21MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( 22 "XPU memory deallocation" 23) 24StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( 25 "XPU stream creation" 26) 27DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry( 28 "XPU device synchronization" 29) 30StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( 31 "XPU stream synchronization" 32) 33EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( 34 "XPU event synchronization" 35) 36 37 38def register_callback_for_event_creation(cb: Callable[[int], None]) -> None: 39 EventCreationCallbacks.add_callback(cb) 40 41 42def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None: 43 EventDeletionCallbacks.add_callback(cb) 44 45 46def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None: 47 EventRecordCallbacks.add_callback(cb) 48 49 50def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None: 51 EventWaitCallbacks.add_callback(cb) 52 53 54def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None: 55 MemoryAllocationCallbacks.add_callback(cb) 56 57 58def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None: 59 MemoryDeallocationCallbacks.add_callback(cb) 60 61 62def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None: 63 StreamCreationCallbacks.add_callback(cb) 64 65 66def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None: 67 DeviceSynchronizationCallbacks.add_callback(cb) 68 69 70def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None: 71 StreamSynchronizationCallbacks.add_callback(cb) 72 73 74def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None: 75 EventSynchronizationCallbacks.add_callback(cb) 76