# mypy: allow-untyped-defs
import contextlib
import tempfile

import torch

from . import check_error, cudart


__all__ = ["init", "start", "stop", "profile"]

# Default profiler options written to the temporary config file that init()
# passes to cudaProfilerInitialize.
DEFAULT_FLAGS = [
    "gpustarttimestamp",
    "gpuendtimestamp",
    "gridsize3d",
    "threadblocksize",
    "streamid",
    "enableonstart 0",
    "conckerneltrace",
]


def init(output_file, flags=None, output_mode="key_value"):
    r"""Initializes the legacy CUDA profiler via ``cudaProfilerInitialize``.

    The given ``flags`` (``DEFAULT_FLAGS`` if None) are written to a temporary
    config file, and profiler output is written to ``output_file`` in the given
    ``output_mode`` ("key_value" or "csv").

    .. warning::
        Not supported on HIP, and not needed on CUDA 12 or newer.
    """
    rt = cudart()
    if not hasattr(rt, "cudaOutputMode"):
        raise AssertionError("HIP does not support profiler initialization!")
    if (
        hasattr(torch.version, "cuda")
        and torch.version.cuda is not None
        and int(torch.version.cuda.split(".")[0]) >= 12
    ):
        # Check https://github.com/pytorch/pytorch/pull/91118
        # cudaProfilerInitialize is no longer needed after CUDA 12
        raise AssertionError("CUDA 12+ does not need profiler initialization!")
    flags = DEFAULT_FLAGS if flags is None else flags
    if output_mode == "key_value":
        output_mode_enum = rt.cudaOutputMode.KeyValuePair
    elif output_mode == "csv":
        output_mode_enum = rt.cudaOutputMode.CSV
    else:
        raise RuntimeError(
            "supported CUDA profiler output modes are: key_value and csv"
        )
    with tempfile.NamedTemporaryFile(delete=True) as f:
        f.write(b"\n".join(flag.encode("ascii") for flag in flags))
        f.flush()
        check_error(rt.cudaProfilerInitialize(f.name, output_file, output_mode_enum))
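# A minimal usage sketch of the legacy flow above (CUDA < 12 only); the output
# path "cuda_prof.csv" and run_workload() are placeholders:
#
#   torch.cuda.profiler.init("cuda_prof.csv", output_mode="csv")
#   torch.cuda.profiler.start()
#   run_workload()
#   torch.cuda.profiler.stop()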


def start():
    r"""Starts CUDA profiler data collection.

    .. warning::
        Raises CudaError if it is unable to start the profiler.
    """
    check_error(cudart().cudaProfilerStart())


def stop():
    r"""Stops CUDA profiler data collection.

    .. warning::
        Raises CudaError if it is unable to stop the profiler.
    """
    check_error(cudart().cudaProfilerStop())
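# start() and stop() only toggle collection; a profiler has to be attached to
# the process already. A sketch, assuming the script was launched under an
# external profiler with capture deferred to the profiler API (for example
# `nsys profile --capture-range=cudaProfilerApi python train.py`), so that only
# the region of interest is recorded; train_step() is a placeholder:
#
#   torch.cuda.profiler.start()
#   train_step()
#   torch.cuda.profiler.stop()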


@contextlib.contextmanager
def profile():
    """
    Enable profiling.

    Context manager that enables profile collection by the active profiling
    tool from the CUDA backend.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> model = torch.nn.Linear(20, 30).cuda()
        >>> inputs = torch.randn(128, 20).cuda()
        >>> with torch.cuda.profiler.profile() as prof:
        ...     model(inputs)
    """
    try:
        start()
        yield
    finally:
        stop()
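# A common pattern (a sketch; `model` and `inputs` are placeholders) combines
# this context manager with torch.autograd.profiler.emit_nvtx() so that
# operator-level NVTX ranges appear alongside the kernels in the captured
# timeline:
#
#   with torch.cuda.profiler.profile():
#       with torch.autograd.profiler.emit_nvtx():
#           model(inputs)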