xref: /aosp_15_r20/external/mesa3d/bin/flamegraph_map_lp_jit.py (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1#
2# Copyright 2024 Autodesk, Inc.
3#
4# SPDX-License-Identifier: MIT
5#
6
7import argparse
8from bisect import bisect_left, bisect_right
9from dataclasses import dataclass
10from pathlib import Path
11import re
12
13
@dataclass
class Instruction:
    """One instruction line of a JIT assembly dump plus its sample count."""

    # Address printed in front of the instruction in the dump; main() compares
    # it against a sampled address minus the owning symbol's start address.
    address: int
    # Instruction text that follows the address column.
    assembly: str
    # Number of profile samples attributed to this instruction so far.
    samples: int = 0
19
20
def mapping_address_key(mapping: tuple[int, int, str]):
    """Return the first element (the start address) of a mapping tuple.

    Meant to be passed as a ``key=`` callable when sorting or bisecting a
    list of ``(start, end, name)`` mappings.
    """
    start_address = mapping[0]
    return start_address
23
24
def instruction_address_key(instruction: Instruction):
    """Return the ``address`` of *instruction*.

    Meant to be passed as a ``key=`` callable when bisecting a list of
    instructions.
    """
    instruction_address = instruction.address
    return instruction_address
27
28
def parse_mappings(map_file_path: Path):
    """Parse a JIT symbol map file into address ranges sorted by start address.

    Each non-blank line is expected to look like
    ``<start_hex> <size_hex> <name>``. Only the first two spaces separate
    fields, so a symbol name may itself contain spaces. Blank lines are
    ignored.

    Returns a list of ``(start, end, name)`` tuples where ``end`` is
    ``start + size``, ordered by ``start``.

    Raises ValueError if a non-blank line has fewer than three fields or if
    the address or size is not valid hexadecimal.
    """
    mappings: list[tuple[int, int, str]] = []
    with open(map_file_path) as map_file:
        for line in map_file:
            if not line.strip():
                # Tolerate empty lines (e.g. a stray trailing newline).
                continue
            # maxsplit=2 keeps any spaces inside the symbol name intact.
            address_hex, size_hex, name = line.split(' ', 2)
            address = int(address_hex, base=16)
            mappings.append((address, address + int(size_hex, base=16), name.strip()))

    # Order by start address only; the sort is stable, so entries sharing a
    # start address keep their order from the file.
    mappings.sort(key=lambda entry: entry[0])
    return mappings
39
40
def parse_traces(trace_file_path: Path):
    """Parse FlameGraph-style collapsed stacks into ``(frames, count)`` pairs.

    Each non-blank line is expected to look like
    ``frame1;frame2;...;frameN <count>``, where ``<count>`` is a decimal
    integer. Blank lines are ignored, and the last line does not need to end
    with a newline.

    Returns a list of ``(frames, count)`` tuples in file order, where
    ``frames`` is the list of ``;``-separated frame names.

    Raises ValueError, naming the file and line number, for a non-blank line
    that does not have this form.
    """
    pattern = re.compile(r'((?:[^;]+;)*?[^;]+) (\d+)')

    traces: list[tuple[list[str], int]] = []
    with open(trace_file_path) as trace_file:
        for line_number, trace in enumerate(trace_file, start=1):
            # Strip only the line ending so a final line without a trailing
            # newline is parsed like every other line.
            line = trace.rstrip('\n')
            if not line.strip():
                continue

            match = pattern.fullmatch(line)
            if match is None:
                raise ValueError(
                    f'{trace_file_path}:{line_number}: malformed collapsed stack line: {line!r}')
            traces.append((match.group(1).split(';'), int(match.group(2))))

    return traces
51
52
def parse_asm(asm_file_path: Path):
    """Parse a JIT assembly dump into per-symbol instruction lists.

    Two kinds of lines are recognized:

    * ``<name> <address_hex>:`` starts a new symbol;
    * ``<address_hex>:<TAB><assembly>`` (optionally preceded by spaces) adds
      an instruction to the most recently started symbol.

    Any other line is ignored, as is an instruction line that appears before
    the first symbol header. The last line does not need to end with a
    newline.

    Returns a dict mapping ``(symbol_address, symbol_name)`` to the list of
    ``Instruction`` objects in file order. A symbol that appears again
    replaces its earlier entry.
    """
    symbol_pattern = re.compile(r'(\w+) ([0-9a-fA-F]+):')
    instruction_pattern = re.compile(r' *([0-9a-fA-F]+):\t(.*)')

    asm: dict[tuple[int, str], list[Instruction]] = {}
    with open(asm_file_path) as asm_file:
        current_instructions = None
        for raw_line in asm_file:
            # Match against the line without its ending so that a final line
            # lacking a trailing newline is not silently dropped.
            line = raw_line.rstrip('\n')
            if match := symbol_pattern.fullmatch(line):
                symbol = (int(match.group(2), base=16), match.group(1))
                current_instructions = asm[symbol] = []
            elif current_instructions is None:
                # No symbol header seen yet, so there is nothing to attach an
                # instruction to; skip the line like any unrecognized one.
                continue
            elif match := instruction_pattern.fullmatch(line):
                current_instructions.append(Instruction(int(match.group(1), base=16), match.group(2)))

    return asm
68
69
def main():
    """Command-line entry point.

    Reads the JIT symbol map and the collapsed stack traces, replaces every
    ``0x``-prefixed frame whose address falls inside a mapped range with
    ``lp`<name>@<start_hex>``, merges stacks that become identical, and
    writes the result to ``--out`` (default: the traces path with ``_mapped``
    appended to its stem).

    When an assembly dump is loaded via ``-a/--asm``, the samples of each
    rewritten frame are also added to the instruction at the matching
    address, and an ``_annotated`` copy of the dump is written next to it.

    Frames that are not valid hexadecimal addresses, or that do not fall
    inside any mapping, are left unchanged.
    """
    parser = argparse.ArgumentParser(description='Map LLVMPipe JIT addresses in FlameGraph style '
                                     'collapsed stack traces to their symbol name. Also optionally '
                                     'annotate JIT assembly dumps with sample counts.')
    parser.add_argument('jit_symbol_map', type=Path, help='JIT symbol map from LLVMPipe')
    parser.add_argument('collapsed_traces', type=Path)
    parser.add_argument('-a', '--asm', type=Path, nargs='?', const='', metavar='asm_path',
                        help='JIT assembly dump from LLVMPipe. Defaults to "<jit_symbol_map>.asm"')
    parser.add_argument('-o', '--out', type=Path, metavar='out_path')
    arguments = parser.parse_args()

    mappings = parse_mappings(arguments.jit_symbol_map)
    traces = parse_traces(arguments.collapsed_traces)

    asm = {}
    asm_file_path: Path | None = arguments.asm
    if asm_file_path is not None:
        if not asm_file_path.parts:
            # "-a" was given without a value (a path with no parts): fall back
            # to the default location and skip annotation if it is missing.
            asm_file_path = Path(str(arguments.jit_symbol_map) + '.asm')
            if asm_file_path.exists():
                asm = parse_asm(asm_file_path)
        else:
            asm = parse_asm(asm_file_path)

    merged_traces: dict[str, int] = {}
    for stack, count in traces:
        for i, function in enumerate(stack):
            if not function.startswith('0x'):
                continue

            try:
                address = int(function, base=16)
            except ValueError:
                # Starts with '0x' but is not a plain hexadecimal address;
                # leave the frame untouched.
                continue

            # Last mapping whose start address is <= the sampled address.
            mapping_index = bisect_right(mappings, address, key=mapping_address_key) - 1
            if mapping_index < 0:
                # Address lies below every mapping, or there are no mappings.
                continue

            start, end, name = mappings[mapping_index]
            if not start <= address < end:
                continue

            stack[i] = f'lp`{name}@{start:x}'

            instructions = asm.get((start, name))
            if instructions is not None:
                # Instruction addresses are compared against the offset of the
                # sample from the symbol's start address.
                instruction_address = address - start
                index = bisect_left(instructions, instruction_address, key=instruction_address_key)
                if index < len(instructions) and instructions[index].address == instruction_address:
                    instructions[index].samples += count

        stack_key = ';'.join(stack)
        merged_traces[stack_key] = merged_traces.get(stack_key, 0) + count

    out_file_path: Path | None = arguments.out
    if not out_file_path:
        out_file_path = arguments.collapsed_traces.with_stem(f'{arguments.collapsed_traces.stem}_mapped')
    with open(out_file_path, 'w') as out:
        for t, c in merged_traces.items():
            print(f'{t} {c}', file=out)

    if asm:
        # asm is only non-empty when asm_file_path points at the parsed dump.
        annotated_asm_file_path = asm_file_path.with_stem(f'{asm_file_path.stem}_annotated')
        with open(annotated_asm_file_path, 'w') as out:
            for symbol, instructions in asm.items():
                print(f'{symbol[1]}: ;{symbol[0]:x}', file=out)
                for instruction in instructions:
                    print(f'\t{instruction.assembly}', end='', file=out)
                    if instruction.samples:
                        print(f' ;s {instruction.samples}', file=out)
                    else:
                        print(file=out)
                print(file=out)
140
# Run the tool only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
143