xref: /aosp_15_r20/external/executorch/util/python_profiler.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import cProfile
8import io
9import logging
10import os
11import pstats
12import re
13from pstats import Stats
14
15from snakeviz.stats import json_stats, table_rows
16from tornado import template
17
# Whether the optional snakeviz dependency could be imported.
# CProfilerFlameGraph checks this flag before doing any work.
module_found = True
try:
    import snakeviz
except ImportError:
    module_found = False

# Directory holding snakeviz's bundled templates (viz.html is loaded from here).
# It can only be computed when the import succeeded; evaluating
# ``snakeviz.__file__`` unconditionally would raise NameError at import time and
# defeat the ImportError guard above.
# NOTE(review): the unconditional ``from snakeviz.stats import ...`` at the top of
# the file still raises ImportError before this guard runs when snakeviz is
# missing. It would need to move inside the ``try`` for the guard to be effective.
if module_found:
    snakeviz_dir = os.path.dirname(os.path.abspath(snakeviz.__file__))
    snakeviz_templates_dir = os.path.join(snakeviz_dir, "templates")
else:
    snakeviz_dir = None
    snakeviz_templates_dir = None
26
27
28def _from_pstat_to_static_html(stats: Stats, html_filename: str):
29    """
30    Parses pstats data and populates viz.html template stored under templates dir.
31    This utility allows to export html file without kicking off webserver.
32
33    Note that it relies js scripts stored at rawgit cdn. This is not super
34    reliable, however it does allow one to not have to rely on webserver and
35    local rendering. On the other hand, for local rendering please follow
36    the main snakeviz tutorial
37
38    Inspiration for this util is from https://gist.github.com/jiffyclub/6b5e0f0f05ab487ff607.
39
40    Args:
41        stats: Stats generated from cProfile data
42        html_filename: Output filename in which populated template is rendered
43    """
44    RESTR = r'(?<!] \+ ")/static/'
45    REPLACE_WITH = "https://cdn.rawgit.com/jiffyclub/snakeviz/v0.4.2/snakeviz/static/"
46
47    if not isinstance(html_filename, str):
48        raise ValueError("A valid file name must be provided.")
49
50    viz_html_loader = template.Loader(snakeviz_templates_dir)
51    html_bytes_renderer = viz_html_loader.load("viz.html")
52    file_split = html_filename.split(".")
53    if len(file_split) < 2:
54        raise ValueError(
55            f"\033[0;32;40m Provided filename \033[0;31;47m {html_filename} \033[0;32;40m does not contain . separator."
56        )
57    profile_name = file_split[0]
58    html_bytes = html_bytes_renderer.generate(
59        profile_name=profile_name,
60        table_rows=table_rows(stats),
61        callees=json_stats(stats),
62    )
63    html_string = html_bytes.decode("utf-8")
64    html_string = re.sub(RESTR, REPLACE_WITH, html_string)
65    with open(html_filename, "w") as f:
66        f.write(html_string)
67
68
class CProfilerFlameGraph:
    """
    Context manager that profiles its body with cProfile and writes the result
    as a static snakeviz HTML page.

    Example:
        with CProfilerFlameGraph("profile.html"):
            run_workload()

    Args:
        filename: Output path for the HTML report (e.g. ``profile.html``).

    Raises:
        ImportError: If snakeviz is not available.
    """

    def __init__(self, filename: str):
        if not module_found:
            # ImportError subclasses Exception, so callers that caught the
            # previous generic Exception still work.
            raise ImportError(
                "Please install snakeviz to use CProfilerFlameGraph. Follow cprofiler_flamegraph.md for more information."
            )
        self.filename = filename

    def __enter__(self):
        """Start a fresh profiler and return ``self`` for ``with ... as`` use."""
        self.pr = cProfile.Profile()
        self.pr.enable()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Stop profiling and write the HTML report to ``self.filename``.

        The report is written even if the body raised. The exception is
        logged and then propagates, because this method returns None.
        """
        # Stop profiling first so the logging and report generation below are
        # not recorded in the profile.
        self.pr.disable()
        if exc_type is not None:
            logging.error("Exception occurred", exc_info=(exc_type, exc_val, exc_tb))

        # A private stream keeps pstats from using its default (stdout) stream.
        stats = pstats.Stats(self.pr, stream=io.StringIO())
        _from_pstat_to_static_html(stats, self.filename)
89