xref: /aosp_15_r20/external/pytorch/torch/utils/model_dump/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3"""
4model_dump: a one-stop shop for TorchScript model inspection.
5
6The goal of this tool is to provide a simple way to extract lots of
7useful information from a TorchScript model and make it easy for humans
8to consume.  It (mostly) replaces zipinfo, common uses of show_pickle,
9and various ad-hoc analysis notebooks.
10
11The tool extracts information from the model and serializes it as JSON.
12That JSON can then be rendered by an HTML+JS page, either by
13loading the JSON over HTTP or producing a fully self-contained page
14with all of the code and data burned-in.
15"""
16
17# Maintainer notes follow.
18"""
19The implementation strategy has tension between 3 goals:
20- Small file size.
21- Fully self-contained.
22- Easy, modern JS environment.
23Using Preact and HTM achieves 1 and 2 with a decent result for 3.
24However, the models I tested with result in ~1MB JSON output,
25so even using something heavier like full React might be tolerable
26if the build process can be worked out.
27
28One principle I have followed that I think is very beneficial
29is to keep the JSON data as close as possible to the model
30and do most of the rendering logic on the client.
31This makes for easier development (just refresh, usually),
32allows for more laziness and dynamism, and lets us add more
33views of the same data without bloating the HTML file.
34
35Currently, this code doesn't actually load the model or even
36depend on any part of PyTorch.  I don't know if that's an important
37feature to maintain, but it's probably worth preserving the ability
38to run at least basic analysis on models that cannot be loaded.
39
40I think the easiest way to develop this code is to cd into model_dump and
41run "python -m http.server", then load http://localhost:8000/skeleton.html
42in the browser.  In another terminal, run
43"python -m torch.utils.model_dump --style=json FILE > \
44    torch/utils/model_dump/model_info.json"
45every time you update the Python code or model.
46When you update JS, just refresh.
47
48Possible improvements:
49    - Fix various TODO comments in this file and the JS.
50    - Make the HTML much less janky, especially the auxiliary data panel.
51    - Make the auxiliary data panel start small, expand when
52      data is available, and have a button to clear/contract.
53    - Clean up the JS.  There's a lot of copypasta because
54      I don't really know how to use Preact.
55    - Make the HTML render and work nicely inside a Jupyter notebook.
56    - Add the ability for JS to choose the URL to load the JSON based
57      on the page URL (query or hash).  That way we could publish the
58      inlined skeleton once and have it load various JSON blobs.
59    - Add a button to expand all expandable sections so ctrl-F works well.
60    - Add hyperlinking from data to code, and code to code.
61    - Add hyperlinking from debug info to Diffusion.
62    - Make small tensor contents available.
63    - Do something nice for quantized models
64      (they probably don't work at all right now).
65"""
66
67import argparse
68import io
69import json
70import os
71import pickle
72import pprint
73import re
74import sys
75import urllib.parse
76import zipfile
77from pathlib import Path
78from typing import Dict
79
80import torch.utils.show_pickle
81
82
83DEFAULT_EXTRA_FILE_SIZE_LIMIT = 16 * 1024
84
85__all__ = ['get_storage_info', 'hierarchical_pickle', 'get_model_info', 'get_inline_skeleton',
86           'burn_in_info', 'get_info_and_burn_skeleton']
87
def get_storage_info(storage):
    """Extract a JSON-friendly storage description from a persistent-id object.

    `storage` is the FakeObject that DumpUnpickler produces for a
    persistent-load ("pers.obj") record.  Returns a list of the form
    [dtype_name_without_"Storage", *remaining_storage_args].
    """
    assert isinstance(storage, torch.utils.show_pickle.FakeObject)
    assert storage.module == "pers"
    assert storage.name == "obj"
    assert storage.state is None
    assert isinstance(storage.args, tuple)
    assert len(storage.args) == 1
    payload = storage.args[0]
    assert isinstance(payload, tuple)
    assert len(payload) == 5
    tag, storage_cls, *rest = payload
    assert tag == "storage"
    assert isinstance(storage_cls, torch.utils.show_pickle.FakeClass)
    assert storage_cls.module == "torch"
    assert storage_cls.name.endswith("Storage")
    # Strip the "Storage" suffix so e.g. "FloatStorage" renders as "Float".
    return [storage_cls.name.replace("Storage", "")] + rest
104
105
def hierarchical_pickle(data):
    """Recursively convert unpickled show_pickle output to a JSON-friendly form.

    Scalars pass through unchanged.  Tuples and dicts are wrapped in tagged
    dicts ("__tuple_values__", "__is_dict__") so the JS frontend can tell
    them apart from ordinary JSON objects, and known FakeObject types
    (modules, tensors, qtensors, devices, ...) get dedicated encodings.
    Raises for any FakeObject type it does not recognize.
    """
    if isinstance(data, (bool, int, float, str, type(None))):
        return data
    if isinstance(data, list):
        return [hierarchical_pickle(item) for item in data]
    if isinstance(data, tuple):
        return {"__tuple_values__": hierarchical_pickle(list(data))}
    if isinstance(data, dict):
        # Keys may be arbitrary pickled values, so they can't be JSON
        # object keys directly; ship parallel key/value lists instead.
        return {
            "__is_dict__": True,
            "keys": hierarchical_pickle(list(data.keys())),
            "values": hierarchical_pickle(list(data.values())),
        }
    if isinstance(data, torch.utils.show_pickle.FakeObject):
        typename = f"{data.module}.{data.name}"
        module_prefixes = ('__torch__.', 'torch.jit.LoweredWrapper.', 'torch.jit.LoweredModule.')
        if typename.startswith(module_prefixes):
            assert data.args == ()
            return {
                "__module_type__": typename,
                "state": hierarchical_pickle(data.state),
            }
        if typename == "torch._utils._rebuild_tensor_v2":
            assert data.state is None
            if len(data.args) == 6:
                storage, offset, size, stride, requires_grad, hooks = data.args
            else:
                storage, offset, size, stride, requires_grad, hooks, metadata = data.args
            return {"__tensor_v2__": [get_storage_info(storage), offset, size, stride, requires_grad]}
        if typename == "torch._utils._rebuild_qtensor":
            assert data.state is None
            storage, offset, size, stride, quantizer, requires_grad, hooks = data.args
            storage_info = get_storage_info(storage)
            assert isinstance(quantizer, tuple)
            assert isinstance(quantizer[0], torch.utils.show_pickle.FakeClass)
            assert quantizer[0].module == "torch"
            # Only per_tensor_affine carries (scale, zero_point) extras we render.
            if quantizer[0].name == "per_tensor_affine":
                assert len(quantizer) == 3
                assert isinstance(quantizer[1], float)
                assert isinstance(quantizer[2], int)
                quantizer_extra = list(quantizer[1:3])
            else:
                quantizer_extra = []
            quantizer_json = [quantizer[0].name] + quantizer_extra
            return {"__qtensor__": [storage_info, offset, size, stride, quantizer_json, requires_grad]}
        if typename == "torch.jit._pickle.restore_type_tag":
            assert data.state is None
            obj, typ = data.args
            assert isinstance(typ, str)
            # The type tag carries no information the frontend needs; drop it.
            return hierarchical_pickle(obj)
        if re.fullmatch(r"torch\.jit\._pickle\.build_[a-z]+list", typename):
            assert data.state is None
            (ls,) = data.args
            assert isinstance(ls, list)
            return hierarchical_pickle(ls)
        if typename == "torch.device":
            assert data.state is None
            (name,) = data.args
            assert isinstance(name, str)
            # Just forget that it was a device and return the name.
            return name
        if typename == "builtin.UnicodeDecodeError":
            assert data.state is None
            (msg,) = data.args
            assert isinstance(msg, str)
            # Hack: Pretend this is a module so we don't need custom serialization.
            # Hack: Wrap the message in a tuple so it looks like a nice state object.
            # TODO: Undo at least that second hack.  We should support string states.
            return {
                "__module_type__": typename,
                "state": hierarchical_pickle((msg,)),
            }
        raise Exception(f"Can't prepare fake object of type for JS: {typename}")  # noqa: TRY002
    raise Exception(f"Can't prepare data of type for JS: {type(data)}")  # noqa: TRY002
184
185
def get_model_info(
        path_or_file,
        title=None,
        extra_file_size_limit=DEFAULT_EXTRA_FILE_SIZE_LIMIT):
    """Get JSON-friendly information about a model.

    The result is suitable for being saved as model_info.json,
    or passed to burn_in_info.

    Args:
        path_or_file: Model path (str or os.PathLike), or a seekable
            binary file-like object containing the zip archive.
        title: Display title for the report; defaults to the path,
            or "buffer" for file-like input.
        extra_file_size_limit: Extra JSON files and rendered pickles
            larger than this many bytes are skipped (entries in
            always_render_pickles are exempt).
    """

    # Compute a default title and the archive's total size.  For a
    # file-like object, measure by seeking to the end, then rewind.
    if isinstance(path_or_file, os.PathLike):
        default_title = os.fspath(path_or_file)
        file_size = path_or_file.stat().st_size  # type: ignore[attr-defined]
    elif isinstance(path_or_file, str):
        default_title = path_or_file
        file_size = Path(path_or_file).stat().st_size
    else:
        default_title = "buffer"
        path_or_file.seek(0, io.SEEK_END)
        file_size = path_or_file.tell()
        path_or_file.seek(0)

    title = title or default_title

    with zipfile.ZipFile(path_or_file) as zf:
        # The archive stores everything under a single top-level
        # directory; record per-entry sizes and verify that one prefix.
        path_prefix = None
        zip_files = []
        for zi in zf.infolist():
            prefix = re.sub("/.*", "", zi.filename)
            if path_prefix is None:
                path_prefix = prefix
            elif prefix != path_prefix:
                raise Exception(f"Mismatched prefixes: {path_prefix} != {prefix}")  # noqa: TRY002
            zip_files.append(dict(
                filename=zi.filename,
                compression=zi.compress_type,
                compressed_size=zi.compress_size,
                file_size=zi.file_size,
            ))

        assert path_prefix is not None
        version = zf.read(path_prefix + "/version").decode("utf-8").strip()

        def get_pickle(name):
            # Unpickle {path_prefix}/{name}.pkl without importing real
            # classes (DumpUnpickler produces FakeObject stand-ins), then
            # convert the result to the JSON-friendly hierarchical form.
            assert path_prefix is not None
            with zf.open(path_prefix + f"/{name}.pkl") as handle:
                raw = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load()
                return hierarchical_pickle(raw)

        model_data = get_pickle("data")
        constants = get_pickle("constants")

        # Intern strings that are likely to be re-used.
        # Pickle automatically detects shared structure,
        # so re-used strings are stored efficiently.
        # However, JSON has no way of representing this,
        # so we have to do it manually.
        interned_strings : Dict[str, int] = {}

        def ist(s):
            # Return a stable integer id for s, assigning ids in
            # first-seen order (so list(interned_strings) inverts the map).
            if s not in interned_strings:
                interned_strings[s] = len(interned_strings)
            return interned_strings[s]

        code_files = {}
        for zi in zf.infolist():
            if not zi.filename.endswith(".py"):
                continue
            with zf.open(zi) as handle:
                raw_code = handle.read()
            with zf.open(zi.filename + ".debug_pkl") as handle:
                raw_debug = handle.read()

            # Parse debug info and add begin/end markers if not present
            # to ensure that we cover the entire source code.
            debug_info_t = pickle.loads(raw_debug)
            text_table = None

            # Newer archives use a string-table format: records reference
            # indexes into text_table rather than embedding strings inline.
            # Expand each record back to the older inline form.
            if (len(debug_info_t) == 3 and
                    isinstance(debug_info_t[0], str) and
                    debug_info_t[0] == 'FORMAT_WITH_STRING_TABLE'):
                _, text_table, content = debug_info_t

                def parse_new_format(line):
                    # (0, (('', '', 0), 0, 0))
                    num, ((text_indexes, fname_idx, offset), start, end), tag = line
                    text = ''.join(text_table[x] for x in text_indexes)  # type: ignore[index]
                    fname = text_table[fname_idx]  # type: ignore[index]
                    return num, ((text, fname, offset), start, end), tag

                debug_info_t = map(parse_new_format, content)

            debug_info = list(debug_info_t)
            if not debug_info:
                debug_info.append((0, (('', '', 0), 0, 0)))
            if debug_info[-1][0] != len(raw_code):
                debug_info.append((len(raw_code), (('', '', 0), 0, 0)))

            # Split the file into consecutive parts, one per debug record,
            # each annotated with its (interned) originating source range.
            code_parts = []
            for di, di_next in zip(debug_info, debug_info[1:]):
                start, source_range, *_ = di
                end = di_next[0]
                assert end > start
                source, s_start, s_end = source_range
                s_text, s_file, s_line = source
                # TODO: Handle this case better.  TorchScript ranges are in bytes,
                # but JS doesn't really handle byte strings.
                # if bytes and chars are not equivalent for this string,
                # zero out the ranges so we don't highlight the wrong thing.
                if len(s_text) != len(s_text.encode("utf-8")):
                    s_start = 0
                    s_end = 0
                text = raw_code[start:end]
                code_parts.append([text.decode("utf-8"), ist(s_file), s_line, ist(s_text), s_start, s_end])
            code_files[zi.filename] = code_parts

        # Collect small JSON files from the "extra" directory verbatim.
        extra_files_json_pattern = re.compile(re.escape(path_prefix) + "/extra/.*\\.json")
        extra_files_jsons = {}
        for zi in zf.infolist():
            if not extra_files_json_pattern.fullmatch(zi.filename):
                continue
            if zi.file_size > extra_file_size_limit:
                continue
            with zf.open(zi) as handle:
                try:
                    json_content = json.load(handle)
                    extra_files_jsons[zi.filename] = json_content
                except json.JSONDecodeError:
                    extra_files_jsons[zi.filename] = "INVALID JSON"

        # Pickles (by basename) that are rendered even when their
        # pretty-printed form exceeds extra_file_size_limit.
        always_render_pickles = {
            "bytecode.pkl",
        }
        extra_pickles = {}
        for zi in zf.infolist():
            if not zi.filename.endswith(".pkl"):
                continue
            with zf.open(zi) as handle:
                # TODO: handle errors here and just ignore the file?
                # NOTE: For a lot of these files (like bytecode),
                # we could get away with just unpickling, but this should be safer.
                obj = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load()
            buf = io.StringIO()
            pprint.pprint(obj, buf)
            contents = buf.getvalue()
            # Checked the rendered length instead of the file size
            # because pickles with shared structure can explode in size during rendering.
            if os.path.basename(zi.filename) not in always_render_pickles and \
                    len(contents) > extra_file_size_limit:
                continue
            extra_pickles[zi.filename] = contents

    return {"model": dict(
        title=title,
        file_size=file_size,
        version=version,
        zip_files=zip_files,
        interned_strings=list(interned_strings),
        code_files=code_files,
        model_data=model_data,
        constants=constants,
        extra_files_jsons=extra_files_jsons,
        extra_pickles=extra_pickles,
    )}
350
351
def get_inline_skeleton():
    """Get a fully-inlined skeleton of the frontend.

    The returned HTML page has no external network dependencies for code.
    It can load model_info.json over HTTP, or be passed to burn_in_info.
    """

    import importlib.resources

    # importlib.resources.read_text/read_binary were deprecated in
    # Python 3.11 and removed in 3.13; use the files() traversal API
    # instead (available since 3.9).
    package = importlib.resources.files(__package__)
    skeleton = (package / "skeleton.html").read_text(encoding="utf-8")
    js_code = (package / "code.js").read_text(encoding="utf-8")
    for js_module in ["preact", "htm"]:
        js_lib = (package / f"{js_module}.mjs").read_bytes()
        # Inline each JS dependency as a data: URL so the page works offline.
        js_url = "data:application/javascript," + urllib.parse.quote(js_lib)
        js_code = js_code.replace(f"https://unpkg.com/{js_module}?module", js_url)
    skeleton = skeleton.replace(' src="./code.js">', ">\n" + js_code)
    return skeleton
369
370
def burn_in_info(skeleton, info):
    """Burn model info into the HTML skeleton.

    The result will render the hard-coded model info and
    have no external network dependencies for code or data.
    """

    # json.dumps does not escape "/", so a string containing "</script>"
    # would terminate our inline <script> tag early and mangle the page.
    # Escaping every slash as "\/" (still legal JSON) prevents that.
    serialized = json.dumps(info, sort_keys=True).replace("/", "\\/")
    return skeleton.replace(
        "BURNED_IN_MODEL_INFO = null",
        "BURNED_IN_MODEL_INFO = " + serialized)
385
386
def get_info_and_burn_skeleton(path_or_bytesio, **kwargs):
    """Produce a self-contained HTML report for a model in one call.

    Extracts the model info (kwargs are forwarded to get_model_info)
    and burns it into the inlined frontend skeleton.
    """
    info = get_model_info(path_or_bytesio, **kwargs)
    return burn_in_info(get_inline_skeleton(), info)
392
393
def main(argv, *, stdout=None):
    """CLI entry point: dump a model as JSON or a self-contained HTML page.

    Args:
        argv: Full argument vector (argv[0] is the program name).
        stdout: Optional output stream; defaults to sys.stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--style", choices=["json", "html"])
    parser.add_argument("--title")
    parser.add_argument("model")
    args = parser.parse_args(argv[1:])

    info = get_model_info(args.model, title=args.title)
    output = stdout or sys.stdout

    if args.style == "json":
        output.write(json.dumps(info, sort_keys=True) + "\n")
        return
    if args.style == "html":
        output.write(burn_in_info(get_inline_skeleton(), info))
        return
    raise Exception("Invalid style")  # noqa: TRY002
413