xref: /aosp_15_r20/external/pytorch/torch/mps/profiler.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport contextlib
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker__all__ = ["start", "stop", "profile"]
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker
def start(mode: str = "interval", wait_until_completed: bool = False) -> None:
    r"""Begin emitting OS Signposts from the MPS backend.

    The emitted OS Signposts can be recorded and inspected with the
    Logging tool of Xcode Instruments.

    Args:
        mode(str): OS Signpost tracing mode. One of "interval", "event",
            or both combined as "interval,event".
            Interval mode traces how long the operations take to execute,
            whereas event mode marks the completion of executions.
            The value is lower-cased and any spaces are removed before it
            is handed to the backend.
            See document `Recording Performance Data`_ for more info.
        wait_until_completed(bool): Wait until the MPS Stream completes
            executing each encoded GPU operation. This helps to generate
            single dispatches on the trace's timeline.
            Note that enabling this option affects performance negatively.

    .. _Recording Performance Data:
       https://developer.apple.com/documentation/os/logging/recording_performance_data
    """
    # Canonicalize the user-supplied mode: case-insensitive, spaces ignored.
    lowered = mode.lower()
    signpost_mode = lowered.replace(" ", "")
    start_trace = torch._C._mps_profilerStartTrace
    start_trace(signpost_mode, wait_until_completed)
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker
def stop():
    r"""Stop emitting OS Signposts from the MPS backend.

    Thin wrapper that forwards to the native
    ``torch._C._mps_profilerStopTrace`` binding; takes no arguments.
    """
    stop_trace = torch._C._mps_profilerStopTrace
    stop_trace()
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker
@contextlib.contextmanager
def profile(mode: str = "interval", wait_until_completed: bool = False):
    r"""Context manager that enables OS Signpost tracing from the MPS backend.

    Tracing is started when the ``with`` block is entered and stopped when
    it is left, including when the body of the block raises.

    Args:
        mode(str): OS Signpost tracing mode. One of "interval", "event",
            or both combined as "interval,event".
            Interval mode traces how long the operations take to execute,
            whereas event mode marks the completion of executions.
            See document `Recording Performance Data`_ for more info.
        wait_until_completed(bool): Wait until the MPS Stream completes
            executing each encoded GPU operation. This helps to generate
            single dispatches on the trace's timeline.
            Note that enabling this option affects performance negatively.

    .. _Recording Performance Data:
       https://developer.apple.com/documentation/os/logging/recording_performance_data
    """
    # Start tracing *before* entering the ``try``: if ``start`` raises, its
    # error must propagate unchanged, and ``stop`` must not run for a trace
    # that this context manager never managed to start.
    start(mode, wait_until_completed)
    try:
        yield
    finally:
        # Always stop the trace this context manager started, even if the
        # body of the ``with`` block raised.
        stop()
62