#!/usr/bin/env python3
# mypy: allow-untyped-defs
"""
model_dump: a one-stop shop for TorchScript model inspection.

The goal of this tool is to provide a simple way to extract lots of
useful information from a TorchScript model and make it easy for humans
to consume. It (mostly) replaces zipinfo, common uses of show_pickle,
and various ad-hoc analysis notebooks.

The tool extracts information from the model and serializes it as JSON.
That JSON can then be rendered by an HTML+JS page, either by
loading the JSON over HTTP or producing a fully self-contained page
with all of the code and data burned-in.
"""

# Maintainer notes follow.
"""
The implementation strategy has tension between 3 goals:
- Small file size.
- Fully self-contained.
- Easy, modern JS environment.
Using Preact and HTM achieves 1 and 2 with a decent result for 3.
However, the models I tested with result in ~1MB JSON output,
so even using something heavier like full React might be tolerable
if the build process can be worked out.

One principle I have followed that I think is very beneficial
is to keep the JSON data as close as possible to the model
and do most of the rendering logic on the client.
This makes for easier development (just refresh, usually),
allows for more laziness and dynamism, and lets us add more
views of the same data without bloating the HTML file.

Currently, this code doesn't actually load the model or even
depend on any part of PyTorch. I don't know if that's an important
feature to maintain, but it's probably worth preserving the ability
to run at least basic analysis on models that cannot be loaded.

I think the easiest way to develop this code is to cd into model_dump and
run "python -m http.server", then load http://localhost:8000/skeleton.html
in the browser.
In another terminal, run
"python -m torch.utils.model_dump --style=json FILE > \
    torch/utils/model_dump/model_info.json"
every time you update the Python code or model.
When you update JS, just refresh.

Possible improvements:
  - Fix various TODO comments in this file and the JS.
  - Make the HTML much less janky, especially the auxiliary data panel.
  - Make the auxiliary data panel start small, expand when
    data is available, and have a button to clear/contract.
  - Clean up the JS. There's a lot of copypasta because
    I don't really know how to use Preact.
  - Make the HTML render and work nicely inside a Jupyter notebook.
  - Add the ability for JS to choose the URL to load the JSON based
    on the page URL (query or hash). That way we could publish the
    inlined skeleton once and have it load various JSON blobs.
  - Add a button to expand all expandable sections so ctrl-F works well.
  - Add hyperlinking from data to code, and code to code.
  - Add hyperlinking from debug info to Diffusion.
  - Make small tensor contents available.
  - Do something nice for quantized models
    (they probably don't work at all right now).
"""

import argparse
import io
import json
import os
import pickle
import pprint
import re
import sys
import urllib.parse
import zipfile
from pathlib import Path
from typing import Dict

import torch.utils.show_pickle


# Extra files (JSON blobs, auxiliary pickles) whose rendered content is
# larger than this are omitted from the output to keep its size manageable.
DEFAULT_EXTRA_FILE_SIZE_LIMIT = 16 * 1024

__all__ = ['get_storage_info', 'hierarchical_pickle', 'get_model_info', 'get_inline_skeleton',
           'burn_in_info', 'get_info_and_burn_skeleton']


def get_storage_info(storage):
    """Convert a persistent-id FakeObject for a tensor storage into a JSON-friendly list.

    ``storage`` is the ``pers.obj`` placeholder that show_pickle's
    DumpUnpickler produces for a pickled storage; its single argument is the
    5-tuple ``("storage", StorageClass, key, location, numel)``.
    Returns ``[dtype_name, key, location, numel]``, where ``dtype_name`` is
    the storage class name with the "Storage" suffix stripped
    (e.g. "Float" for FloatStorage).
    """
    assert isinstance(storage, torch.utils.show_pickle.FakeObject)
    assert storage.module == "pers"
    assert storage.name == "obj"
    assert storage.state is None
    assert isinstance(storage.args, tuple)
    assert len(storage.args) == 1
    sa = storage.args[0]
    assert isinstance(sa, tuple)
    assert len(sa) == 5
    assert sa[0] == "storage"
    assert isinstance(sa[1], torch.utils.show_pickle.FakeClass)
    assert sa[1].module == "torch"
    assert sa[1].name.endswith("Storage")
    storage_info = [sa[1].name.replace("Storage", "")] + list(sa[2:])
    return storage_info


def hierarchical_pickle(data):
    """Recursively convert unpickled model data into a JSON-serializable form.

    Primitives pass through unchanged; tuples and dicts are wrapped in tagged
    objects (``__tuple_values__``, ``__is_dict__``) because JSON cannot
    distinguish them from lists/objects.  FakeObjects produced by the
    DumpUnpickler (modules, tensors, qtensors, devices, etc.) are translated
    into small tagged dicts that the JS frontend knows how to render.

    Raises:
        Exception: if the data contains a type this function doesn't handle.
    """
    if isinstance(data, (bool, int, float, str, type(None))):
        return data
    if isinstance(data, list):
        return [hierarchical_pickle(d) for d in data]
    if isinstance(data, tuple):
        return {
            "__tuple_values__": hierarchical_pickle(list(data)),
        }
    if isinstance(data, dict):
        return {
            "__is_dict__": True,
            "keys": hierarchical_pickle(list(data.keys())),
            "values": hierarchical_pickle(list(data.values())),
        }
    if isinstance(data, torch.utils.show_pickle.FakeObject):
        typename = f"{data.module}.{data.name}"
        if (
            typename.startswith(('__torch__.', 'torch.jit.LoweredWrapper.', 'torch.jit.LoweredModule.'))
        ):
            assert data.args == ()
            return {
                "__module_type__": typename,
                "state": hierarchical_pickle(data.state),
            }
        if typename == "torch._utils._rebuild_tensor_v2":
            assert data.state is None
            # Newer serialization adds a trailing metadata arg; we ignore it
            # either way and only surface the first five fields.
            if len(data.args) == 6:
                storage, offset, size, stride, requires_grad, hooks = data.args
            else:
                storage, offset, size, stride, requires_grad, hooks, metadata = data.args
            storage_info = get_storage_info(storage)
            return {"__tensor_v2__": [storage_info, offset, size, stride, requires_grad]}
        if typename == "torch._utils._rebuild_qtensor":
            assert data.state is None
            storage, offset, size, stride, quantizer, requires_grad, hooks = data.args
            storage_info = get_storage_info(storage)
            assert isinstance(quantizer, tuple)
            assert isinstance(quantizer[0], torch.utils.show_pickle.FakeClass)
            assert quantizer[0].module == "torch"
            if quantizer[0].name == "per_tensor_affine":
                # per_tensor_affine carries (scale: float, zero_point: int).
                assert len(quantizer) == 3
                assert isinstance(quantizer[1], float)
                assert isinstance(quantizer[2], int)
                quantizer_extra = list(quantizer[1:3])
            else:
                quantizer_extra = []
            quantizer_json = [quantizer[0].name] + quantizer_extra
            return {"__qtensor__": [storage_info, offset, size, stride, quantizer_json, requires_grad]}
        if typename == "torch.jit._pickle.restore_type_tag":
            assert data.state is None
            obj, typ = data.args
            assert isinstance(typ, str)
            # The type tag is only needed by the TorchScript runtime; drop it.
            return hierarchical_pickle(obj)
        if re.fullmatch(r"torch\.jit\._pickle\.build_[a-z]+list", typename):
            assert data.state is None
            ls, = data.args
            assert isinstance(ls, list)
            return hierarchical_pickle(ls)
        if typename == "torch.device":
            assert data.state is None
            name, = data.args
            assert isinstance(name, str)
            # Just forget that it was a device and return the name.
            return name
        if typename == "builtin.UnicodeDecodeError":
            assert data.state is None
            msg, = data.args
            assert isinstance(msg, str)
            # Hack: Pretend this is a module so we don't need custom serialization.
            # Hack: Wrap the message in a tuple so it looks like a nice state object.
            # TODO: Undo at least that second hack.  We should support string states.
            return {
                "__module_type__": typename,
                "state": hierarchical_pickle((msg,)),
            }
        raise Exception(f"Can't prepare fake object of type for JS: {typename}")  # noqa: TRY002
    raise Exception(f"Can't prepare data of type for JS: {type(data)}")  # noqa: TRY002


def get_model_info(
        path_or_file,
        title=None,
        extra_file_size_limit=DEFAULT_EXTRA_FILE_SIZE_LIMIT):
    """Get JSON-friendly information about a model.

    The result is suitable for being saved as model_info.json,
    or passed to burn_in_info.

    Args:
        path_or_file: a path (str or os.PathLike) to a TorchScript zip
            archive, or a seekable binary file object containing one.
        title: display title for the model; defaults to the path
            (or "buffer" for file objects).
        extra_file_size_limit: extra JSON/pickle files whose rendered size
            exceeds this are skipped.
    """

    if isinstance(path_or_file, os.PathLike):
        default_title = os.fspath(path_or_file)
        file_size = path_or_file.stat().st_size  # type: ignore[attr-defined]
    elif isinstance(path_or_file, str):
        default_title = path_or_file
        file_size = Path(path_or_file).stat().st_size
    else:
        # Assume a seekable binary buffer; measure its size directly.
        default_title = "buffer"
        path_or_file.seek(0, io.SEEK_END)
        file_size = path_or_file.tell()
        path_or_file.seek(0)

    title = title or default_title

    with zipfile.ZipFile(path_or_file) as zf:
        # TorchScript archives place everything under a single top-level
        # directory; verify that and record the per-entry zip metadata.
        path_prefix = None
        zip_files = []
        for zi in zf.infolist():
            prefix = re.sub("/.*", "", zi.filename)
            if path_prefix is None:
                path_prefix = prefix
            elif prefix != path_prefix:
                raise Exception(f"Mismatched prefixes: {path_prefix} != {prefix}")  # noqa: TRY002
            zip_files.append(dict(
                filename=zi.filename,
                compression=zi.compress_type,
                compressed_size=zi.compress_size,
                file_size=zi.file_size,
            ))

        assert path_prefix is not None
        version = zf.read(path_prefix + "/version").decode("utf-8").strip()

        def get_pickle(name):
            assert path_prefix is not None
            with zf.open(path_prefix + f"/{name}.pkl") as handle:
                raw = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load()
                return hierarchical_pickle(raw)

        model_data = get_pickle("data")
        constants = get_pickle("constants")

        # Intern strings that are likely to be re-used.
        # Pickle automatically detects shared structure,
        # so re-used strings are stored efficiently.
        # However, JSON has no way of representing this,
        # so we have to do it manually.
        interned_strings: Dict[str, int] = {}

        def ist(s):
            if s not in interned_strings:
                interned_strings[s] = len(interned_strings)
            return interned_strings[s]

        code_files = {}
        for zi in zf.infolist():
            if not zi.filename.endswith(".py"):
                continue
            with zf.open(zi) as handle:
                raw_code = handle.read()
            # NOTE(review): this assumes every .py entry has a matching
            # .debug_pkl sibling; a KeyError here would abort the whole dump.
            with zf.open(zi.filename + ".debug_pkl") as handle:
                raw_debug = handle.read()

            # Parse debug info and add begin/end markers if not present
            # to ensure that we cover the entire source code.
            debug_info_t = pickle.loads(raw_debug)
            text_table = None

            if (len(debug_info_t) == 3 and
                    isinstance(debug_info_t[0], str) and
                    debug_info_t[0] == 'FORMAT_WITH_STRING_TABLE'):
                _, text_table, content = debug_info_t

                def parse_new_format(line):
                    # (0, (('', '', 0), 0, 0))
                    num, ((text_indexes, fname_idx, offset), start, end), tag = line
                    text = ''.join(text_table[x] for x in text_indexes)  # type: ignore[index]
                    fname = text_table[fname_idx]  # type: ignore[index]
                    return num, ((text, fname, offset), start, end), tag

                debug_info_t = map(parse_new_format, content)

            debug_info = list(debug_info_t)
            if not debug_info:
                debug_info.append((0, (('', '', 0), 0, 0)))
            if debug_info[-1][0] != len(raw_code):
                debug_info.append((len(raw_code), (('', '', 0), 0, 0)))

            code_parts = []
            for di, di_next in zip(debug_info, debug_info[1:]):
                start, source_range, *_ = di
                end = di_next[0]
                assert end > start
                source, s_start, s_end = source_range
                s_text, s_file, s_line = source
                # TODO: Handle this case better.  TorchScript ranges are in bytes,
                # but JS doesn't really handle byte strings.
                # if bytes and chars are not equivalent for this string,
                # zero out the ranges so we don't highlight the wrong thing.
                if len(s_text) != len(s_text.encode("utf-8")):
                    s_start = 0
                    s_end = 0
                text = raw_code[start:end]
                code_parts.append([text.decode("utf-8"), ist(s_file), s_line, ist(s_text), s_start, s_end])
            code_files[zi.filename] = code_parts

        extra_files_json_pattern = re.compile(re.escape(path_prefix) + "/extra/.*\\.json")
        extra_files_jsons = {}
        for zi in zf.infolist():
            if not extra_files_json_pattern.fullmatch(zi.filename):
                continue
            if zi.file_size > extra_file_size_limit:
                continue
            with zf.open(zi) as handle:
                try:
                    json_content = json.load(handle)
                    extra_files_jsons[zi.filename] = json_content
                except json.JSONDecodeError:
                    extra_files_jsons[zi.filename] = "INVALID JSON"

        # These pickles are rendered regardless of their size because they
        # are central to understanding the model.
        always_render_pickles = {
            "bytecode.pkl",
        }
        extra_pickles = {}
        for zi in zf.infolist():
            if not zi.filename.endswith(".pkl"):
                continue
            with zf.open(zi) as handle:
                # TODO: handle errors here and just ignore the file?
                # NOTE: For a lot of these files (like bytecode),
                # we could get away with just unpickling, but this should be safer.
                obj = torch.utils.show_pickle.DumpUnpickler(handle, catch_invalid_utf8=True).load()
            buf = io.StringIO()
            pprint.pprint(obj, buf)
            contents = buf.getvalue()
            # Check the rendered length instead of the file size
            # because pickles with shared structure can explode in size during rendering.
            if os.path.basename(zi.filename) not in always_render_pickles and \
                    len(contents) > extra_file_size_limit:
                continue
            extra_pickles[zi.filename] = contents

    return {"model": dict(
        title=title,
        file_size=file_size,
        version=version,
        zip_files=zip_files,
        interned_strings=list(interned_strings),
        code_files=code_files,
        model_data=model_data,
        constants=constants,
        extra_files_jsons=extra_files_jsons,
        extra_pickles=extra_pickles,
    )}


def get_inline_skeleton():
    """Get a fully-inlined skeleton of the frontend.

    The returned HTML page has no external network dependencies for code.
    It can load model_info.json over HTTP, or be passed to burn_in_info.
    """

    import importlib.resources

    # importlib.resources.read_text/read_binary were deprecated in
    # Python 3.11 and removed in 3.13; use the files() traversable API
    # (available since 3.9) instead.
    package = importlib.resources.files(__package__)
    skeleton = (package / "skeleton.html").read_text(encoding="utf-8")
    js_code = (package / "code.js").read_text(encoding="utf-8")
    for js_module in ["preact", "htm"]:
        js_lib = (package / f"{js_module}.mjs").read_bytes()
        # Inline the JS library as a data URL so the page works offline.
        js_url = "data:application/javascript," + urllib.parse.quote(js_lib)
        js_code = js_code.replace(f"https://unpkg.com/{js_module}?module", js_url)
    skeleton = skeleton.replace(' src="./code.js">', ">\n" + js_code)
    return skeleton


def burn_in_info(skeleton, info):
    """Burn model info into the HTML skeleton.

    The result will render the hard-coded model info and
    have no external network dependencies for code or data.
    """

    # Note that Python's json serializer does not escape slashes in strings.
    # Since we're inlining this JSON directly into a script tag, a string
    # containing "</script>" would end the script prematurely and
    # mess up our page.  Unconditionally escaping fixes that.
    return skeleton.replace(
        "BURNED_IN_MODEL_INFO = null",
        "BURNED_IN_MODEL_INFO = " + json.dumps(info, sort_keys=True).replace("/", "\\/"))


def get_info_and_burn_skeleton(path_or_bytesio, **kwargs):
    """Extract model info and return a self-contained HTML page rendering it."""
    model_info = get_model_info(path_or_bytesio, **kwargs)
    skeleton = get_inline_skeleton()
    page = burn_in_info(skeleton, model_info)
    return page


def main(argv, *, stdout=None):
    """CLI entry point: dump model info as JSON or a self-contained HTML page.

    Args:
        argv: full argument vector (argv[0] is the program name).
        stdout: optional output stream; defaults to sys.stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--style", choices=["json", "html"])
    parser.add_argument("--title")
    parser.add_argument("model")
    args = parser.parse_args(argv[1:])

    info = get_model_info(args.model, title=args.title)

    output = stdout or sys.stdout

    if args.style == "json":
        output.write(json.dumps(info, sort_keys=True) + "\n")
    elif args.style == "html":
        skeleton = get_inline_skeleton()
        page = burn_in_info(skeleton, info)
        output.write(page)
    else:
        raise Exception("Invalid style")  # noqa: TRY002