# mypy: allow-untyped-defs
"""Utilities for extracting TorchScript graph IR from log files (sections wrapped in
<GRAPH_EXPORT> ... </GRAPH_EXPORT> markers) and benchmarking the extracted graphs
under different fusion backends."""
from contextlib import contextmanager
from typing import Any, List, Tuple, cast
import random
import time

import torch
from torch.utils.benchmark import Timer

def extract_ir(filename: str) -> List[str]:
    """Return the IR of every graph dumped between <GRAPH_EXPORT> markers in ``filename``."""
    BEGIN = "<GRAPH_EXPORT>"
    END = "</GRAPH_EXPORT>"
    graphs = []
    with open(filename) as f:
        split_strs = f.read().split(BEGIN)
        for i, split_str in enumerate(split_strs):
            # Everything before the first BEGIN marker is not a graph.
            if i == 0:
                continue
            end_loc = split_str.find(END)
            if end_loc == -1:
                continue
            s = split_str[:end_loc]
            # The logging prefix is whatever preceded BEGIN on its own line;
            # strip it from every line of the exported graph body.
            pfx = split_strs[i - 1].splitlines()[-1]
            lines = [x[len(pfx):] for x in s.splitlines(keepends=True)]
            graphs.append(''.join(lines))

    return graphs


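# Example (illustrative sketch; the "[DUMP]" logging prefix below is made up, but it
# mirrors how extract_ir strips whatever text precedes <GRAPH_EXPORT> on each line):
#
#     log = (
#         "[DUMP] <GRAPH_EXPORT>\n"
#         "[DUMP] graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):\n"
#         "[DUMP]   return (%x)\n"
#         "[DUMP] </GRAPH_EXPORT>\n"
#     )
#     with open("example.log", "w") as f:
#         f.write(log)
#     extract_ir("example.log")
#     # -> ["graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):\n  return (%x)\n"]
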
def make_tensor_from_type(inp_type: torch._C.TensorType) -> torch.Tensor:
    """Allocate an uninitialized tensor matching a fully-specified TensorType."""
    size = inp_type.sizes()
    stride = inp_type.strides()
    device = inp_type.device()
    dtype = inp_type.dtype()
    # The IR must carry complete shape/stride/device/dtype information.
    assert size is not None
    assert stride is not None
    assert device is not None
    assert dtype is not None
    return torch.empty_strided(size=size, stride=stride, device=device, dtype=dtype)

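# Example (sketch; the IR string is hypothetical but uses the standard TorchScript
# tensor type annotation, which carries sizes, strides, device and dtype):
#
#     g = torch._C.parse_ir(
#         "graph(%x : Float(2, 3, strides=[3, 1], requires_grad=0, device=cpu)):\n"
#         "  return (%x)"
#     )
#     t = make_tensor_from_type(list(g.inputs())[0].type())
#     # t.shape == (2, 3), t.dtype == torch.float32, values are uninitialized
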
def load_graph_and_inputs(ir: str) -> Tuple[Any, List[Any]]:
    """Parse an IR string and return a callable plus randomly generated inputs matching its signature."""
    graph = torch._C.parse_ir(ir, parse_tensor_constants=True)
    graph.makeMultiOutputIntoTuple()
    inputs: List[Any] = []
    for inp in graph.inputs():
        if isinstance(inp.type(), torch._C.FloatType):
            inputs.append(random.uniform(.1, 100))
        elif isinstance(inp.type(), torch._C.IntType):
            inputs.append(random.randint(1, 100))
        elif isinstance(inp.type(), torch._C.TensorType):
            tensorType = cast(torch._C.TensorType, inp.type())
            inputs.append(make_tensor_from_type(tensorType))
        elif isinstance(inp.type(), torch._C.BoolType):
            inputs.append(random.randint(0, 1) == 1)
        else:
            raise NotImplementedError(f"A default value is not implemented for type {inp.type()}")

    func = torch._C._create_function_from_graph("forward", graph)
    # Strip the shape information recorded in the log from the parsed graph.
    torch._C._jit_pass_erase_shape_information(func.graph)
    return (func, inputs)

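# Example (sketch; "example.log" is a hypothetical log file): replay each graph
# recovered by extract_ir with randomly generated inputs:
#
#     for ir in extract_ir("example.log"):
#         func, inputs = load_graph_and_inputs(ir)
#         out = func(*inputs)
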
def time_cuda(fn, inputs, test_runs):
    # torch.utils.benchmark.Timer picks its own number of runs, so test_runs is
    # unused on this path; the median per-run time is reported.
    t = Timer(stmt="fn(*inputs)", globals={"fn": fn, "inputs": inputs})
    times = t.blocked_autorange()
    return times.median * 1000  # time in ms

def time_cpu(fn, inputs, test_runs):
    s = time.perf_counter()
    for _ in range(test_runs):
        fn(*inputs)
    e = time.perf_counter()
    return (e - s) / test_runs * 1000  # average time per run in ms

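# Example (sketch): both helpers report milliseconds per run, but the CPU path
# averages test_runs explicit calls while the CUDA path lets Timer choose the count:
#
#     fn, inputs = load_graph_and_inputs(ir)           # ir from extract_ir(...)
#     ms = time_cpu(fn, inputs, test_runs=20)          # CPU tensors
#     # ms = time_cuda(fn, inputs, test_runs=20)       # CUDA tensors
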
def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float:
    """Benchmark an IR graph on the given inputs; returns per-run time in ms."""
    graph, _ = load_graph_and_inputs(ir)
    for _ in range(warmup_runs):
        graph(*inputs)

    # Pick the timing path based on where the first tensor input lives.
    is_cpu = None
    for inp in inputs:
        if isinstance(inp, torch.Tensor):
            is_cpu = inp.device.type == "cpu"
            break
    assert is_cpu is not None

    out = time_cpu(graph, inputs, test_runs) if is_cpu else time_cuda(graph, inputs, test_runs)
    return out

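# Example (sketch): time a graph recovered from a log; reusing the inputs returned
# by load_graph_and_inputs lets every configuration be measured on the same data:
#
#     func, inputs = load_graph_and_inputs(ir)
#     ms = run_test(ir, inputs, warmup_runs=5, test_runs=50)
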
@contextmanager
def no_fuser(*args, **kwargs):
    """Temporarily disable graph-executor optimization (and hence fusion)."""
    old_optimize = torch._C._get_graph_executor_optimize(False)
    try:
        yield
    finally:
        torch._C._get_graph_executor_optimize(old_optimize)

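# Example (sketch): measure the same graph with optimization/fusion disabled
# (this is exactly what run_baseline_no_fusion below does):
#
#     with no_fuser():
#         unfused_ms = run_test(ir, inputs)
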
def run_baseline_no_fusion(ir, inputs) -> float:
    with no_fuser():
        return run_test(ir, inputs)


def run_nnc(ir, inputs, dynamic) -> float:
    # "fuser1" selects the NNC (TensorExpr) fuser. Set the fusion strategy before
    # entering the try block so old_strat is always defined when we restore it.
    strat = [("DYNAMIC", 10)] if dynamic else [("STATIC", 10)]
    old_strat = torch.jit.set_fusion_strategy(strat)
    try:
        with torch.jit.fuser("fuser1"):
            return run_test(ir, inputs)
    finally:
        torch.jit.set_fusion_strategy(old_strat)

def run_nvfuser(ir, inputs) -> float:
    # "fuser2" selects nvFuser (CUDA only).
    with torch.jit.fuser("fuser2"):
        return run_test(ir, inputs)

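# Example (sketch): compare the fusion backends on every graph extracted from a log.
# The log name is illustrative, and run_nvfuser requires a CUDA build with nvFuser:
#
#     for ir in extract_ir("model_run.log"):
#         _, inputs = load_graph_and_inputs(ir)
#         baseline = run_baseline_no_fusion(ir, inputs)
#         nnc = run_nnc(ir, inputs, dynamic=False)
#         nvf = run_nvfuser(ir, inputs)
#         print(f"baseline={baseline:.3f}ms nnc={nnc:.3f}ms nvfuser={nvf:.3f}ms")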