xref: /aosp_15_r20/external/pytorch/torch/utils/model_dump/code.js (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import { h, Component, render } from 'https://unpkg.com/preact?module';
2import htm from 'https://unpkg.com/htm?module';
3
// Tagged-template helper: components below are written as html`<div>...</div>`
// and htm turns each template into Preact `h(...)` calls.
const html = htm.bind(h);

// Model info embedded directly in the page. While it is null, App fetches
// ./model_info.json at startup instead.
// NOTE(review): this exact text ("BURNED_IN_MODEL_INFO = null") is presumably
// rewritten by an export step that inlines the JSON — do not reformat it
// without checking that step.
const BURNED_IN_MODEL_INFO = null;
7
// Format a byte count as a short human-readable string, e.g. 1536 -> "1.5 kB".
// Adapted from https://stackoverflow.com/a/20732091, but the unit is chosen by
// repeated division instead of Math.log, so that:
//   * exact powers of 1024 cannot land in the wrong unit through float rounding,
//   * values of 1024 TB and above use "PB" instead of an undefined unit,
//   * negative values (e.g. a negative difference) keep their sign instead of
//     producing "NaN undefined",
//   * values below 1 stay in bytes instead of indexing unit -1.
// At most two decimals are kept and trailing zeros are dropped ("1.50" -> "1.5").
function humanFileSize(size) {
  const units = ['B', 'kB', 'MB', 'GB', 'TB', 'PB'];
  const sign = size < 0 ? "-" : "";
  let value = Math.abs(size);
  let unitIdx = 0;
  while (value >= 1024 && unitIdx < units.length - 1) {
    value /= 1024;
    unitIdx++;
  }
  // Number(...) strips the trailing zeros that toFixed leaves behind.
  return sign + Number(value.toFixed(2)) + ' ' + units[unitIdx];
}
14
// Disclosure-triangle glyph for a collapsible section: a down-pointing
// triangle when `down` is truthy (expanded), a right-pointing one otherwise.
function caret(down) {
  if (down) {
    return "\u25BE";
  }
  return "\u25B8";
}
18
// Coordinates "Blame Code" mode: after readyBlame() arms it, the next
// maybeBlame() call is forwarded to the registered aux content pane and the
// mode is disarmed again.
class Blamer {
  constructor() {
    // True while waiting for the one click that should be blamed.
    this.blame_on_click = false;
    // Pane that displays blame results; registered via setAuxContentPane().
    this.aux_content_pane = null;
  }

  setAuxContentPane(pane) {
    this.aux_content_pane = pane;
  }

  // Arm blame mode so that the next maybeBlame() call is honored.
  readyBlame() {
    this.blame_on_click = true;
  }

  // If armed, disarm and pass `arg` to the pane's doBlame(). Disarming happens
  // even when no pane is registered; when not armed this is a no-op.
  maybeBlame(arg) {
    const armed = this.blame_on_click;
    if (armed) {
      this.blame_on_click = false;
      const pane = this.aux_content_pane;
      if (pane) {
        pane.doBlame(arg);
      }
    }
  }
}
44
// Module-wide Blamer shared by the code view (which reports span clicks) and
// the aux content pane (which shows blame results). Never reassigned.
const blame = new Blamer();
46
// Collapsible section: an <h2> title with a caret that toggles the body.
// Props:
//   name     - section title (also exposed as data-hider-title).
//   shown    - the string "true" to start expanded; anything else starts
//              collapsed (unquoted htm attributes such as shown=true arrive
//              as strings).
//   children - section body, rendered only while expanded.
class Hider extends Component {
  constructor(props) {
    super(props);
    // Derive the initial state from props immediately. Previously the state
    // started as null and was only filled in by componentDidMount, which cost
    // an extra render and briefly drew every section collapsed.
    this.state = { shown: props.shown === "true" };
  }

  render({name, children}, {shown}) {
    const my_caret = html`<span class=caret onClick=${() => this.click()} >${caret(shown)}</span>`;
    return html`<div data-hider-title=${name} data-shown=${shown}>
      <h2>${my_caret} ${name}</h2>
      <div>${shown ? children : []}</div></div>`;
  }

  click() {
    // Updater form so the toggle is based on the latest state.
    this.setState(({shown}) => ({shown: !shown}));
  }
}
68
// "Model Size" section: splits the archive's file_size into bytes held by
// stored zip entries (compression === 0), bytes held by all other
// (compressed) entries, and the remainder, labelled "Zip overhead".
function ModelSizeSection({model: {file_size, zip_files}}) {
  const totals = { stored: 0, compressed: 0 };
  zip_files.forEach((entry) => {
    // TODO: Maybe check that compressed_size === file_size for stored entries.
    const bucket = entry.compression === 0 ? "stored" : "compressed";
    totals[bucket] += entry.compressed_size;
  });
  const store_size = totals.stored;
  const compr_size = totals.compressed;
  const zip_overhead = file_size - store_size - compr_size;
  // NOTE: the "." right after <pre> looks deliberate — htm trims
  // newline-containing whitespace at the edges of text, so it presumably keeps
  // the first line indented like the others. Check the output before removing.
  // TODO: Better formatting.  Right-align this.
  return html`
    <${Hider} name="Model Size" shown=true>
    <pre>.
      Model size: ${file_size} (${humanFileSize(file_size)})
      Stored files: ${store_size} (${humanFileSize(store_size)})
      Compressed files: ${compr_size} (${humanFileSize(compr_size)})
      Zip overhead: ${zip_overhead} (${humanFileSize(zip_overhead)})
    </pre><//>`;
}
91
// Collapsible section titled `name` that renders one decoded pickle value
// (`data`) as an expandable StructuredData tree; `shown` is passed through
// to Hider as the initial expanded flag.
function StructuredDataSection(props) {
  const {name, data, shown} = props;
  return html`
    <${Hider} name=${name} shown=${shown}>
    <div style="font-family:monospace;">
      <${StructuredData} data=${data} indent="" prefix=""/>
    </div><//>`;
}
99
// Expandable, Python-repr-like view of one decoded pickle value.
//
// Props:
//   data   - null, boolean, number, string, array (Python list), or an object
//            tagged with one of __tuple_values__, __is_dict__, __module_type__,
//            __tensor_v2__ or __qtensor__.
//   indent - string of non-breaking spaces drawn before this line.
//   prefix - text shown before the headline (list index or dict key).
//
// Lists, tuples, dicts and modules get a caret; clicking it toggles a body
// with one nested StructuredData per element/entry, each on its own line.
class StructuredData extends Component {
  constructor() {
    super();
    this.state = { shown: false };

    // typeof values that are rendered entirely on the headline.
    this.INLINE_TYPES = new Set(["boolean", "number", "string"]);
    // Module state keys that are hidden from the tree view.
    this.IGNORED_STATE_KEYS = new Set(["training", "_is_full_backward_hook"]);
  }

  click() {
    // Updater form so the toggle is based on the latest state.
    this.setState(({shown}) => ({shown: !shown}));
  }

  // Whether `data` is a container (caret + body) rather than a value that is
  // fully rendered on its headline. Throws for values it cannot classify.
  expando(data) {
    if (data === null || this.INLINE_TYPES.has(typeof data)) {
      return false;
    }
    if (typeof data !== "object") {
      throw new Error("Not an object");
    }
    if (Array.isArray(data)) {
      // TODO: Maybe show simple lists and tuples on one line.
      return true;
    }
    if (data.__tuple_values__) {
      // TODO: Maybe show simple lists and tuples on one line.
      return true;
    }
    if (data.__is_dict__) {
      // TODO: Maybe show simple (empty?) dicts on one line.
      return true;
    }
    if (data.__module_type__) {
      return true;
    }
    if (data.__tensor_v2__ || data.__qtensor__) {
      return false;
    }
    // Attach the offending value as `cause` (Error's second argument is an
    // options bag, so passing the value directly silently discarded it).
    throw new Error("Can't handle data type.", { cause: data });
  }

  // One-line, Python-flavored rendering of `data`: scalars in full, opening
  // brackets for containers, "Type()" for modules, a summary for tensors.
  renderHeadline(data) {
    if (data === null) {
      return "None";
    }
    if (typeof data === "boolean") {
      // Python spelling: True / False.
      const sd = String(data);
      return sd.charAt(0).toUpperCase() + sd.slice(1);
    }
    if (typeof data === "number") {
      // JSON.stringify would render NaN/Infinity as "null", which reads like
      // None; use the Python spellings instead.
      if (Number.isNaN(data)) {
        return "nan";
      }
      if (!Number.isFinite(data)) {
        return data > 0 ? "inf" : "-inf";
      }
      return JSON.stringify(data);
    }
    if (typeof data === "string") {
      return JSON.stringify(data);
    }
    if (typeof data !== "object") {
      throw new Error("Not an object");
    }
    if (Array.isArray(data)) {
      return "list([";
    }
    if (data.__tuple_values__) {
      return "tuple((";
    }
    if (data.__is_dict__) {
      return "dict({";
    }
    if (data.__module_type__) {
      return data.__module_type__ + "()";
    }
    if (data.__tensor_v2__) {
      const [storage, offset, size, stride, grad] = data.__tensor_v2__;
      const [dtype, key, device, numel] = storage;
      return this.renderTensor(
        "tensor", dtype, key, device, numel, offset, size, stride, grad, []);
    }
    if (data.__qtensor__) {
      const [storage, offset, size, stride, quantizer, grad] = data.__qtensor__;
      const [dtype, key, device, numel] = storage;
      const extra_parts = [];
      if (quantizer[0] === "per_tensor_affine") {
        extra_parts.push(`scale=${quantizer[1]}`);
        extra_parts.push(`zero_point=${quantizer[2]}`);
      } else {
        extra_parts.push(`quantizer=${quantizer[0]}`);
      }
      return this.renderTensor(
        "qtensor", dtype, key, device, numel, offset, size, stride, grad, extra_parts);
    }
    throw new Error("Can't handle data type.", { cause: data });
  }

  // Tensor summary: prefix(<shape>, <dtype>, <extra parts...>[, device][, grad]).
  // The device is omitted when it is "cpu".
  renderTensor(
      prefix,
      dtype,
      storage_key,
      device,
      storage_numel,
      offset,
      size,
      stride,
      grad,
      extra_parts) {
    const parts = [
      "(" + size.join(",") + ")",
      dtype,
    ];
    parts.push(...extra_parts);
    if (device !== "cpu") {
      parts.push(device);
    }
    if (grad) {
      parts.push("grad");
    }
    // TODO: Check stride and indicate if the tensor is channels-last or non-contiguous
    // TODO: Check size, stride, offset, and numel and indicate if
    // the tensor doesn't use all data in storage.
    // TODO: Maybe show key?
    void(offset);
    void(stride);
    void(storage_key);
    void(storage_numel);
    return prefix + "(" + parts.join(", ") + ")";
  }

  // One child row: a line break followed by a nested StructuredData.
  renderChild(prefix, indent, data) {
    return html`<br/><${StructuredData} prefix=${prefix} indent=${indent} data=${data} />`;
  }

  // Rows for a pickled dict ({__is_dict__, keys, values}). Keys in `ignored`
  // are skipped; non-string keys get a placeholder line.
  renderDictEntries(dict, indent, ignored = new Set()) {
    const parts = [];
    for (let idx = 0; idx < dict.keys.length; idx++) {
      const key = dict.keys[idx];
      if (typeof key !== "string") {
        parts.push(html`<br/>${indent}Non-string key`);
      } else if (!ignored.has(key)) {
        parts.push(this.renderChild(key + ": ", indent, dict.values[idx]));
      }
    }
    return parts;
  }

  // Child rows shown when a container is expanded. Only valid for values
  // where expando() is true; anything else throws.
  renderBody(indent, data) {
    if (data === null || this.INLINE_TYPES.has(typeof data)) {
      throw new Error("Should not reach here.");
    }
    if (typeof data !== "object") {
      throw new Error("Not an object");
    }
    const new_indent = indent + "\u00A0\u00A0";
    if (Array.isArray(data)) {
      // Does it make sense to put explicit index numbers here?
      return Array.from(data, (item, idx) => this.renderChild(idx + ": ", new_indent, item));
    }
    if (data.__tuple_values__) {
      // Handled the same as lists.
      return this.renderBody(indent, data.__tuple_values__);
    }
    if (data.__is_dict__) {
      return this.renderDictEntries(data, new_indent);
    }
    if (data.__module_type__) {
      const mstate = data.state;
      if (mstate === null || typeof mstate !== "object") {
        throw new Error("Bad module state");
      }
      if (mstate.__is_dict__) {
        return this.renderDictEntries(mstate, new_indent, this.IGNORED_STATE_KEYS);
      }
      if (mstate.__tuple_values__ || mstate.__module_type__) {
        // A tuple state is shown as a single child. We normally wouldn't have
        // the state of a module be another module, but we use "modules" to
        // encode special values (like Unicode decode errors) that might be
        // valid states.  Just go with it.
        return [this.renderChild("", new_indent, mstate)];
      }
      throw new Error("Bad module state");
    }
    if (data.__tensor_v2__ || data.__qtensor__) {
      throw new Error("Should not reach here.");
    }
    throw new Error("Can't handle data type.", { cause: data });
  }

  render({data, indent, prefix}, {shown}) {
    const exp = this.expando(data) ? html`<span class=caret onClick=${() => this.click()} >${caret(shown)} </span>` : "";
    const headline = this.renderHeadline(data);
    const body = shown ? this.renderBody(indent, data) : "";
    return html`${indent}${exp}${prefix}${headline}${body}`;
  }
}
306
// "Zip Contents" section: a table with one row per zip entry showing its
// compression mode, file_size, compressed_size and filename.
// TODO: Add human-readable sizes?
// TODO: Add sorting options?
// TODO: Add hierarchical collapsible tree?
function ZipContentsSection({model: {zip_files}}) {
  // Names for known compression codes; other codes are shown raw.
  const modeNames = {0: "store", 8: "deflate"};
  const rows = zip_files.map((entry) => html`<tr>
          <td>${modeNames[entry.compression] || entry.compression}</td>
          <td>${entry.file_size}</td>
          <td>${entry.compressed_size}</td>
          <td>${entry.filename}</td>
        </tr>`);
  return html`
    <${Hider} name="Zip Contents" shown=false>
    <table>
      <thead>
        <tr>
          <th>Mode</th>
          <th>Size</th>
          <th>Compressed</th>
          <th>Name</th>
        </tr>
      </thead>
      <tbody style="font-family:monospace;">
        ${rows}
      </tbody>
    </table><//>`;
}
332
// "Code" section: one collapsible OneCodeSection per entry of code_files
// (an object mapping filename -> list of code blocks).
function CodeSection({model: {code_files}}) {
  const sections = Object.entries(code_files).map(
    ([fn, code]) => html`<${OneCodeSection} filename=${fn} code=${code} />`);
  return html`
    <${Hider} name="Code" shown=false>
    <div>
      ${sections}
    </div><//>`;
}
341
// Collapsible view of one code file. `code` is a list of tuples
// [text, ist_file, line, ist_s_text, s_start, s_end]; every text span is
// clickable and hands its source-location fields to the global Blamer.
class OneCodeSection extends Component {
  constructor() {
    super();
    this.state = { shown: false };
  }

  click() {
    this.setState({shown: !this.state.shown});
  }

  render({filename, code}, {shown}) {
    const header = html`
        <h3 style="font-family:monospace;">
        <span class=caret onClick=${() => this.click()} >${caret(shown)} </span>
        ${filename}</h3>
        `;
    return shown
      ? html`
      ${header}
      <pre>${code.map(c => this.renderBlock(c))}</pre>
      `
      : header;
  }

  // Render one code span; clicking it asks the Blamer to blame its location.
  renderBlock(block) {
    const [text, ist_file, line, ist_s_text, s_start, s_end] = block;
    const onClick = () => blame.maybeBlame({ist_file, line, ist_s_text, s_start, s_end});
    return html`<span onClick=${onClick}>${text}</span>`;
  }
}
374
// "Extra files (JSON)" section: one collapsible OneJsonSection per entry of
// `files` (an object mapping filename -> parsed JSON value).
function ExtraJsonSection({files}) {
  const sections = Object.entries(files).map(
    ([fn, json]) => html`<${OneJsonSection} filename=${fn} json=${json} />`);
  return html`
    <${Hider} name="Extra files (JSON)" shown=false>
    <div>
      <p>Use "Log Raw Model Info" for hierarchical view in browser console.</p>
      ${sections}
    </div><//>`;
}
384
// Collapsible view of one extra JSON file; when expanded the value is
// pretty-printed with two-space indentation.
class OneJsonSection extends Component {
  constructor() {
    super();
    this.state = { shown: false };
  }

  click() {
    this.setState({shown: !this.state.shown});
  }

  // Clickable <h3> title with the expand/collapse caret.
  renderHeader(filename, shown) {
    return html`
        <h3 style="font-family:monospace;">
        <span class=caret onClick=${() => this.click()} >${caret(shown)} </span>
        ${filename}</h3>
        `;
  }

  render({filename, json}, {shown}) {
    const header = this.renderHeader(filename, shown);
    if (shown) {
      const pretty = JSON.stringify(json, null, 2);
      return html`
      ${header}
      <pre>${pretty}</pre>
      `;
    }
    return header;
  }
}
411
// "Extra Pickles" section: one collapsible OnePickleSection per entry of
// `files` (an object mapping filename -> displayable content).
function ExtraPicklesSection({files}) {
  const sections = Object.entries(files).map(
    ([fn, content]) => html`<${OnePickleSection} filename=${fn} content=${content} />`);
  return html`
    <${Hider} name="Extra Pickles" shown=false>
    <div>
      ${sections}
    </div><//>`;
}
420
// Collapsible view of one extra pickle; when expanded, `content` is shown
// verbatim inside a <pre>.
class OnePickleSection extends Component {
  constructor() {
    super();
    this.state = { shown: false };
  }

  click() {
    this.setState({shown: !this.state.shown});
  }

  render({filename, content}, {shown}) {
    const toggle = () => this.click();
    const header = html`
        <h3 style="font-family:monospace;">
        <span class=caret onClick=${toggle} >${caret(shown)} </span>
        ${filename}</h3>
        `;
    return !shown
      ? header
      : html`
      ${header}
      <pre>${content}</pre>
      `;
  }
}
447
// Throws Error("Storage mismatch for key '<key>'") unless the two storage
// descriptors have the same length and pairwise-=== elements.
function assertStorageAreEqual(key, lhs, rhs) {
  const sameLength = lhs.length === rhs.length;
  const mismatch = !sameLength || lhs.some((val, idx) => val !== rhs[idx]);
  if (mismatch) {
    throw new Error("Storage mismatch for key '" + key + "'");
  }
}
454
// Element size in bytes for each storage dtype name (e.g. "Float"). A Map is
// used so inherited Object properties such as "toString" are not mistaken for
// dtypes, and the table is built once instead of on every call.
const DTYPE_SIZES_IN_BYTES = new Map([
  ["Byte", 1],
  ["Char", 1],
  ["Short", 2],
  ["Int", 4],
  ["Long", 8],
  ["Half", 2],
  ["Float", 4],
  ["Double", 8],
  ["ComplexHalf", 4],
  ["ComplexFloat", 8],
  ["ComplexDouble", 16],
  ["Bool", 1],
  ["QInt8", 1],
  ["QUInt8", 1],
  ["QInt32", 4],
  ["BFloat16", 2],
]);

// Bytes occupied by `numel` elements of `dtype`.
// Throws Error("Unrecognized dtype: <dtype>") for dtypes not in the table.
function computeTensorMemory(numel, dtype) {
  const dtsize = DTYPE_SIZES_IN_BYTES.get(dtype);
  if (dtsize === undefined) {
    throw new Error("Unrecognized dtype: " + dtype);
  }
  return numel * dtsize;
}
480
// Collect the tensor storages referenced anywhere inside a decoded pickle value.
//
// Returns a Map from storage key to the storage descriptor
// [dtype, key, device, numel] (the first element of each __tensor_v2__ /
// __qtensor__ tuple). A key seen more than once must carry an identical
// descriptor every time (checked with assertStorageAreEqual), so a storage
// shared by several tensors appears only once.
//
// Throws Error("Not an object") for non-JSON-like values, and
// Error("Can't handle data type.") with the value as `cause` for objects
// without a recognized tag.
// TODO: Maybe track by dtype as well.
// TODO: Maybe distinguish between visible size and storage size.
function getTensorStorages(data) {
  if (data === null ||
      typeof data === "boolean" ||
      typeof data === "number" ||
      typeof data === "string") {
    return new Map();
  }
  if (typeof data !== "object") {
    throw new Error("Not an object");
  }
  if (Array.isArray(data)) {
    const result = new Map();
    for (const item of data) {
      for (const [key, storage] of getTensorStorages(item)) {
        if (result.has(key)) {
          assertStorageAreEqual(key, result.get(key), storage);
        } else {
          result.set(key, storage);
        }
      }
    }
    return result;
  }
  if (data.__tuple_values__) {
    return getTensorStorages(data.__tuple_values__);
  }
  if (data.__is_dict__) {
    return getTensorStorages(data.values);
  }
  if (data.__module_type__) {
    return getTensorStorages(data.state);
  }
  // Both tensor encodings start with the storage descriptor, whose second
  // element is the storage key.
  const tensor = data.__tensor_v2__ || data.__qtensor__;
  if (tensor) {
    const [storage] = tensor;
    const [, key] = storage;
    return new Map([[key, storage]]);
  }
  throw new Error("Can't handle data type.", { cause: data });
}
535
// Total tensor-storage bytes per device across several decoded pickles.
// `pickles` is iterated as [name, pickle] pairs (the names are unused).
// Storages are deduplicated by key within each pickle only: the same key in
// two different pickles is counted twice.
// NOTE(review): presumably intended because storage keys are scoped per
// pickle — confirm before deduplicating across pickles.
function getTensorMemoryByDevice(pickles) {
  // Phase 1: gather every storage descriptor from every pickle.
  const storages = [];
  for (const [, pickle] of pickles) {
    for (const storage of getTensorStorages(pickle).values()) {
      storages.push(storage);
    }
  }
  // Phase 2: accumulate bytes per device.
  const totals = {};
  for (const [dtype, , device, numel] of storages) {
    const bytes = computeTensorMemory(numel, dtype);
    totals[device] = (totals[device] || 0) + bytes;
  }
  return totals;
}
550
// Table of tensor-storage bytes per device for the model's data and constants
// pickles. This is a separate component so the walk over the model only runs
// when it is actually rendered, i.e. when its enclosing Hider is expanded.
class OpenTensorMemorySection extends Component {
  render({model}) {
    const {model_data, constants} = model;
    const pickles = new Map([
      ["data", model_data],
      ["constants", constants],
    ]);
    const rows = Object.entries(getTensorMemoryByDevice(pickles)).map(
      ([dev, size]) => html`<tr>
            <td>${dev}</td>
            <td>${size}</td>
            <td>${humanFileSize(size)}</td>
          </tr>`);
    return html`
      <table>
        <thead>
          <tr>
            <th>Device</th>
            <th>Bytes</th>
            <th>Human</th>
          </tr>
        </thead>
        <tbody style="font-family:monospace;">
          ${rows}
        </tbody>
      </table>`;
  }
}
577
// Collapsed-by-default "Tensor Memory" section wrapping the per-device table.
function TensorMemorySection(props) {
  const model = props.model;
  return html`
    <${Hider} name="Tensor Memory" shown=false>
    <${OpenTensorMemorySection} model=${model} /><//>`;
}
583
// Bottom pane: a "Blame Code" button that arms the global Blamer, and the
// details of the most recently blamed code span (file:line, the character
// range, and the source text with that range in bold).
class AuxContentPane extends Component {
  constructor() {
    super();
    this.state = {
      blame_info: null,
    };
  }

  // Called by the Blamer with the clicked span's location fields.
  doBlame(arg) {
    this.setState({...this.state, blame_info: arg});
  }

  // Render one blame result. ist_file / ist_s_text index interned_strings;
  // [s_start, s_end) is highlighted unless it covers the whole text.
  renderBlame(interned_strings, {ist_file, line, ist_s_text, s_start, s_end}) {
    const full_text = interned_strings[ist_s_text];
    const covers_all = s_start == 0 && s_end == full_text.length;
    const s_text = covers_all
      ? full_text
      : html`${full_text.slice(0, s_start)}<strong>${full_text.slice(s_start, s_end)}</strong>${full_text.slice(s_end)}`;
    return html`
        <h3>${interned_strings[ist_file]}:${line}</h3>
        <pre>${s_start}:${s_end}</pre>
        <pre>${s_text}</pre><br/>
        `;
  }

  render({model: {interned_strings}}, {blame_info}) {
    const blame_content = blame_info ? this.renderBlame(interned_strings, blame_info) : "";
    return html`
      <button onClick=${() => blame.readyBlame()}>Blame Code</button>
      <br/>
      ${blame_content}
      `;
  }
}
620
// Root component. Loads the model info (from BURNED_IN_MODEL_INFO when set,
// otherwise by fetching ./model_info.json), then renders every section plus
// the aux pane. Errors show a red banner pointing at the console.
class App extends Component {
  constructor() {
    super();
    this.state = {
      err: false,
      model: null,
    };
  }

  componentDidMount() {
    if (BURNED_IN_MODEL_INFO !== null) {
      this.setState({model: BURNED_IN_MODEL_INFO});
      return;
    }
    fetch("./model_info.json")
      .then((response) => {
        if (!response.ok) {
          throw new Error("Response not ok.");
        }
        return response.json();
      })
      .then((body) => this.setState({model: body}))
      .catch((error) => {
        console.log("Top-level error: ", error);
        // Surface the failure instead of leaving "Loading..." up forever.
        this.setState({err: true});
      });
  }

  componentDidCatch(error) {
    // The banner says "Check console", so make sure the error is logged there.
    console.error(error);
    this.setState({err: true});
  }

  render(_, {err, model: info}) {
    const error_msg = err
      ? html`<h2 style="background:red">An error occurred.  Check console</h2>`
      : "";

    if (info === null) {
      // Still loading, or loading failed.
      return err ? error_msg : html`<h1>Loading...</h1>`;
    }

    const model = info.model;

    return html`
      ${error_msg}
      <div id=main_content style="position:absolute;width:99%;height:79%;overflow:scroll">
        <h1>TorchScript Model (version ${model.version}): ${model.title}</h1>
        <button onClick=${() => console.log(model)}>Log Raw Model Info</button>
        <${ModelSizeSection} model=${model}/>
        <${StructuredDataSection} name="Model Data" data=${model.model_data} shown=true/>
        <${StructuredDataSection} name="Constants" data=${model.constants} shown=false/>
        <${ZipContentsSection} model=${model}/>
        <${CodeSection} model=${model}/>
        <${ExtraJsonSection} files=${model.extra_files_jsons}/>
        <${ExtraPicklesSection} files=${model.extra_pickles}/>
        <${TensorMemorySection} model=${model}/>
      </div>
      <div id=aux_content style="position:absolute;width:99%;top:80%;height:20%;overflow:scroll">
        <${AuxContentPane}
          err=${err}
          model=${model}
          ref=${(p) => blame.setAuxContentPane(p)}/>
      </div>
      `;
  }
}
688
// Mount the root App component into the page body.
render(h(App), document.body);
690