xref: /aosp_15_r20/external/pytorch/functorch/dim/magic_trace.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Copyright (c) Facebook, Inc. and its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import os
7import signal
8import subprocess
9from contextlib import contextmanager
10
11
_MAGIC_TRACE_URL = (
    "https://github.com/janestreet/magic-trace/releases/download/v1.0.2/magic-trace"
)


def _download_magic_trace(dest):
    """Download the magic-trace binary to *dest* and make it executable.

    The file is fetched into a sibling ``.part`` path and only moved onto
    *dest* after wget succeeds, so a failed or interrupted download never
    leaves a truncated file at *dest* (the caller's ``os.path.exists`` check
    would otherwise treat it as a valid binary on every later run).

    Raises:
        subprocess.CalledProcessError: if wget exits non-zero.
    """
    print(f"Downloading magic_trace to: {dest}")
    partial = os.fspath(dest) + ".part"
    try:
        subprocess.run(["wget", "-O", partial, "-q", _MAGIC_TRACE_URL], check=True)
        # Add execute bits, like `chmod +x`.
        os.chmod(partial, os.stat(partial).st_mode | 0o111)
        os.replace(partial, dest)
    except BaseException:
        try:
            os.remove(partial)
        except FileNotFoundError:
            pass
        raise


def _wait_for_attach(proc):
    """Echo *proc*'s stderr until a line contains "Attached".

    Raises:
        ValueError: if stderr reaches EOF (the tracer exited) before attaching.
    """
    for line in iter(proc.stderr.readline, ""):
        print(line, end="")
        if "Attached" in line:
            return
    # readline() returned "" => EOF: the tracer closed stderr without ever
    # attaching. Reap it and report its exit status instead of spinning forever.
    code = proc.wait()
    proc.stderr.close()
    raise ValueError(f"magic_trace exited before attaching: {code}")


@contextmanager
def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"):
    """Run magic-trace attached to the current process for a ``with`` block.

    Usage::

        with magic_trace("trace.fxt"):
            run_workload()

    Args:
        output: Trace file path, passed to ``magic-trace attach`` as ``-o``.
        magic_trace_cache: Path of the magic-trace executable. If it does not
            exist, the v1.0.2 release binary is downloaded there with wget.

    Raises:
        subprocess.CalledProcessError: if downloading the binary fails.
        ValueError: if magic-trace exits before printing "Attached", or exits
            with a non-zero status after being sent SIGINT.
    """
    if not os.path.exists(magic_trace_cache):
        _download_magic_trace(magic_trace_cache)
    args = [magic_trace_cache, "attach", "-pid", str(os.getpid()), "-o", output]
    proc = subprocess.Popen(args, stderr=subprocess.PIPE, encoding="utf-8")
    try:
        _wait_for_attach(proc)
    except BaseException:
        # Don't leave a tracer running against us if we bail out while
        # waiting (e.g. Ctrl-C); harmless if it has already exited.
        proc.kill()
        proc.wait()
        proc.stderr.close()
        raise
    try:
        yield
    finally:
        # The tracer is stopped with SIGINT; presumably it finalizes `output`
        # on that signal — confirm against magic-trace docs.
        proc.send_signal(signal.SIGINT)
        # communicate() drains stderr while waiting; wait()-then-read() can
        # deadlock if the child fills the pipe buffer before exiting.
        _, remaining = proc.communicate()
        print(remaining)
        if proc.returncode != 0:
            raise ValueError(f"magic_trace exited abnormally: {proc.returncode}")
43