xref: /aosp_15_r20/external/pytorch/torch/utils/viz/MemoryViz.js (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1'use strict';
2
3import * as d3 from "https://cdn.skypack.dev/d3@5";
4import {axisLeft} from "https://cdn.skypack.dev/d3-axis@1";
5import {scaleLinear} from "https://cdn.skypack.dev/d3-scale@1";
6import {zoom, zoomIdentity} from "https://cdn.skypack.dev/d3-zoom@1";
7import {brushX} from "https://cdn.skypack.dev/d3-brush@1";
8
// Ten-color categorical palette (Tableau 10). Entries are indexed modulo the
// palette length wherever a block or plot element needs a color.
const schemeTableau10 = [
  '#4e79a7', '#f28e2c', '#e15759', '#76b7b2', '#59a14f',
  '#edc949', '#af7aa1', '#ff9da7', '#9c755f', '#bab0ab',
];
21
/**
 * Creates an independent per-address version counter.
 *
 * The returned function takes `(addr, increment)` and yields the version
 * currently associated with `addr` (starting at 0). When `increment` is
 * truthy the stored version is bumped *after* the current value is read, so
 * the caller still receives the pre-increment version.
 */
function version_space() {
  const counters = {};
  const lookup = (addr, increment) => {
    const known = addr in counters;
    if (!known) {
      counters[addr] = 0;
    }
    const before = counters[addr];
    if (increment) {
      counters[addr]++;
    }
    return before;
  };
  return lookup;
}
35
/**
 * Plain record describing a segment: start address, size, stream, the
 * version assigned to this use of the address, and the associated frames.
 */
function Segment(addr, size, stream, frames, version) {
  const record = {
    addr: addr,
    size: size,
    stream: stream,
    version: version,
    frames: frames,
  };
  return record;
}
39
/**
 * Plain record describing a block: address, rounded size, the size that was
 * actually requested, frames, whether a free has been requested for it, and
 * the version assigned to this use of the address.
 */
function Block(addr, size, requested_size, frames, free_requested, version) {
  const record = {
    addr: addr,
    size: size,
    requested_size: requested_size,
    frames: frames,
    free_requested: free_requested,
    version: version,
  };
  return record;
}
43
/**
 * Renders the scrollable list of trace events and wires up selection.
 *
 * @param outer d3 selection of the grid container; the list is appended to it.
 * @param events array of trace events; each is rendered with formatEvent and
 *   its `idx` field is used to select it on click.
 * @param stack_info object with `register` / `highlight` (see StackInfo).
 * @param memory_view object with `draw(idx)` returning `[reserved, allocated]`.
 * @returns object with `select(idx)`; `select(null)` clears the selection.
 */
function EventSelector(outer, events, stack_info, memory_view) {
  const events_div = outer
    .append('div')
    .attr(
      'style',
      'grid-column: 1; grid-row: 1; overflow: auto; font-family: monospace',
    );

  const events_selection = events_div
    .selectAll('pre')
    .data(events)
    .enter()
    .append('pre')
    .text(e => formatEvent(e))
    .attr('style', '');

  let selected_event_idx = null;

  // scrollIntoViewIfNeeded is a non-standard method that not every browser
  // implements; calling it unconditionally throws where it is missing, so
  // fall back to the standard scrollIntoView with "nearest" alignment.
  function scroll_row_into_view(node) {
    if (typeof node.scrollIntoViewIfNeeded === 'function') {
      node.scrollIntoViewIfNeeded(false);
    } else {
      node.scrollIntoView({block: 'nearest', inline: 'nearest'});
    }
  }

  const es = {
    select(idx) {
      // Clear the highlight of the previously selected row, if any.
      if (selected_event_idx !== null) {
        const selected_event = d3.select(
          events_div.node().children[selected_event_idx],
        );
        selected_event.attr('style', '');
      }
      if (idx !== null) {
        const div = d3.select(events_div.node().children[idx]);
        div.attr('style', `background-color: ${schemeTableau10[5]}`);
        const [reserved, allocated] = memory_view.draw(idx);
        const enter = () => eventStack(div.datum(), allocated, reserved);
        stack_info.highlight(enter);
        scroll_row_into_view(div.node());
      } else {
        memory_view.draw(0);
      }
      selected_event_idx = idx;
    },
  };

  // Arrow keys move the selection by one event, clamped to the list bounds.
  d3.select('body').on('keydown', _e => {
    const key = d3.event.key;
    const actions = {ArrowDown: 1, ArrowUp: -1};
    if (selected_event_idx !== null && key in actions) {
      const new_idx = selected_event_idx + actions[key];
      es.select(Math.max(0, Math.min(new_idx, events.length - 1)));
      d3.event.preventDefault();
    }
  });

  stack_info.register(
    events_selection,
    t => eventStack(t.datum()),
    _t => {},
    d => es.select(d.datum().idx),
  );

  return es;
}
102
/**
 * Formats a byte count with a binary (1024-based) unit prefix and one
 * decimal, followed by the exact byte count, e.g. `1.5KiB (1536 bytes)`.
 * Values too large for the listed prefixes are reported in YiB without the
 * exact byte count.
 */
function formatSize(num) {
  // https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size
  const prefixes = ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi'];
  const exact_bytes = num;
  let scaled = num;
  let i = 0;
  while (i < prefixes.length) {
    if (Math.abs(scaled) < 1024.0) {
      return `${scaled.toFixed(1)}${prefixes[i]}B (${exact_bytes} bytes)`;
    }
    scaled /= 1024.0;
    i += 1;
  }
  return `${scaled.toFixed(1)}YiB`;
}
/**
 * Builds the short identifier shown for an event: `s` for segment actions or
 * `b` otherwise, then the address in hex, an underscore, and the version.
 */
function formatAddr(event) {
  const is_segment_event = event.action.startsWith('segment');
  const hex_addr = event.addr.toString(16);
  const tag = is_segment_event ? 's' : 'b';
  return `${tag}${hex_addr}_${event.version}`;
}
/**
 * Renders one trace event as the text shown in the event list.
 *
 * `oom` events report the requested size and free device memory, `snapshot`
 * events are shown verbatim, and every other action is shown as padded
 * action / address / size columns. A stream suffix is appended unless the
 * event's stream is null.
 */
function formatEvent(event) {
  const stream_suffix =
    event.stream === null ? '' : `\n              (stream ${event.stream})`;

  if (event.action === 'oom') {
    const requested = formatSize(event.size);
    const device_free = formatSize(event.device_free);
    return `OOM (requested ${requested}, CUDA has ${device_free} memory free)${stream_suffix}`;
  }
  if (event.action === 'snapshot') {
    return 'snapshot';
  }

  const action_col = event.action.padEnd(14);
  const addr_col = formatAddr(event).padEnd(18);
  const size_col = formatSize(event.size);
  return `${action_col} ${addr_col} ${size_col}${stream_suffix}`;
}
135
/**
 * Builds the detail text for an event: an optional "allocated / reserved"
 * header (only when `reserved` is provided), the formatted event, and the
 * event's formatted frames, separated by newlines.
 */
function eventStack(e, allocated, reserved) {
  const description = formatEvent(e);
  const totals_header =
    reserved === undefined
      ? ''
      : `(${formatSize(allocated)} allocated / ${formatSize(reserved)} reserved)\n`;
  return `${totals_header}${description}\n${format_frames(e.frames)}`;
}
145
/**
 * Hashes the string form of `num` into a signed 32-bit integer using the
 * `hash * 31 + charCode` recurrence over its UTF-16 code units.
 */
function hashCode(num) {
  return num
    .toString()
    .split('')
    .reduce((acc, unit) => {
      const mixed = (acc << 5) - acc + unit.charCodeAt(0);
      // Bitwise AND with itself truncates to a 32-bit integer.
      return mixed & mixed;
    }, 0);
}
156
/**
 * Outlines the given selection in red with a 2-unit stroke that does not
 * scale with zoom. Each attribute is set on the value returned by the
 * previous `attr` call.
 */
function addStroke(d) {
  const highlight_attrs = [
    ['stroke', 'red'],
    ['stroke-width', '2'],
    ['vector-effect', 'non-scaling-stroke'],
  ];
  let target = d;
  for (const [attr_name, attr_value] of highlight_attrs) {
    target = target.attr(attr_name, attr_value);
  }
}
162
/**
 * Clears the highlight outline by setting the selection's `stroke`
 * attribute to the empty string.
 */
function removeStroke(d) {
  const no_stroke = '';
  d.attr('stroke', no_stroke);
}
166
/**
 * Computes a fragmentation score: the sum of squared free-gap sizes across
 * all segments, divided by the squared total segment size.
 *
 * @param blocks object whose values have numeric `addr` and `size`.
 * @param sorted_segments segments with numeric `addr` and `size`, ordered by
 *   ascending address (the single forward pass over blocks relies on this).
 * @returns the score; 0 when the total segment size is 0. The value is also
 *   written to the console, as before.
 */
function calculate_fragmentation(blocks, sorted_segments) {
  const sorted_blocks = Object.values(blocks).sort((a, b) => a.addr - b.addr);
  let block_i = 0;
  let total_size = 0;
  let sum_squared_free = 0;
  for (const seg of sorted_segments) {
    const seg_end = seg.addr + seg.size;
    // `addr` tracks the end of the last block seen inside this segment.
    let addr = seg.addr;
    total_size += seg.size;
    while (
      block_i < sorted_blocks.length &&
      sorted_blocks[block_i].addr < seg_end
    ) {
      const block = sorted_blocks[block_i];
      if (block.addr > addr) {
        sum_squared_free += (block.addr - addr) ** 2;
      }
      addr = block.addr + block.size;
      block_i += 1;
    }
    // Trailing gap between the last block and the end of the segment.
    if (addr < seg_end) {
      sum_squared_free += (seg_end - addr) ** 2;
    }
  }
  // Avoid 0 / 0 = NaN when there is no reserved memory at all.
  const fragmentation =
    total_size === 0 ? 0 : sum_squared_free / total_size ** 2;
  console.log(fragmentation);
  return fragmentation;
}
192
/**
 * Builds the zoomable segment/block view for one device.
 *
 * The snapshot's segments and active blocks for `device` are captured once.
 * `draw(idx)` then replays the device trace backwards from its end down to
 * (but not including) event `idx` to reconstruct the state right after that
 * event, and renders it.
 *
 * @param outer d3 selection of the grid container; an <svg> is appended.
 * @param stack_info object with `register` (see StackInfo) used for tooltips.
 * @param snapshot annotated snapshot with `segments` and `device_traces`.
 * @param device device index used to filter segments and pick the trace.
 * @returns object with `draw(idx)` returning `[reserved, allocated]` bytes.
 */
function MemoryView(outer, stack_info, snapshot, device) {
  const svg = outer
    .append('svg')
    .attr('style', 'grid-column: 2; grid-row: 1; width: 100%; height: 100%;')
    .attr('viewBox', '0 0 200 100')
    .attr('preserveAspectRatio', 'xMinYMin meet');
  const g = svg.append('g');
  const seg_zoom = zoom();
  seg_zoom.on('zoom', () => {
    g.attr('transform', d3.event.transform);
  });
  svg.call(seg_zoom);

  // State at the time the snapshot was taken; never modified after this loop.
  const sorted_segments = [];
  const block_map = {};
  for (const seg of snapshot.segments) {
    if (seg.device !== device) {
      continue;
    }
    sorted_segments.push(
      Segment(
        seg.address,
        seg.total_size,
        seg.stream,
        seg.frames || [],
        seg.version,
      ),
    );
    for (const b of seg.blocks) {
      if (b.state !== 'active_pending_free' && b.state !== 'active_allocated') {
        continue;
      }
      block_map[b.addr] = Block(
        b.addr,
        b.size,
        b.requested_size,
        b.frames,
        b.state === 'active_pending_free',
        b.version,
      );
    }
  }
  sorted_segments.sort((x, y) => x.addr - y.addr);

  // Reconstructs the segments and blocks as they were right after event
  // `idx` by undoing every later event, newest first.
  function simulate_memory(idx) {
    // create a copy of segments because we edit size properties below
    const l_segments = sorted_segments.map(x => {
      return {...x};
    });
    // Shallow copy: the Block objects are still shared with block_map, so
    // they must be replaced rather than mutated during the replay.
    const l_block_map = {...block_map};

    // Re-inserts a segment (undoing a free/unmap), keeping address order.
    // With `merge`, adjacent same-stream neighbours are coalesced.
    function map_segment(merge, seg) {
      let pos = l_segments.findIndex(e => e.addr > seg.addr);
      if (pos === -1) {
        // No segment at a higher address: append. Passing -1 to splice
        // would insert before the last element instead.
        pos = l_segments.length;
      }
      l_segments.splice(pos, 0, seg);
      if (!merge) {
        return;
      }
      if (pos + 1 < l_segments.length) {
        const next = l_segments[pos + 1];
        if (seg.addr + seg.size === next.addr && seg.stream === next.stream) {
          seg.size += next.size;
          l_segments.splice(pos + 1, 1);
        }
      }
      if (pos > 0) {
        const prev = l_segments[pos - 1];
        if (prev.addr + prev.size === seg.addr && prev.stream === seg.stream) {
          prev.size += seg.size;
          l_segments.splice(pos, 1);
        }
      }
    }

    // Removes a segment (undoing an alloc/map). With `merge`, the range is
    // carved out of the enclosing segment, splitting it when necessary.
    function unmap_segment(merge, seg) {
      if (!merge) {
        const pos = l_segments.findIndex(x => x.addr === seg.addr);
        if (pos === -1) {
          // splice(-1, 1) would silently delete the last segment.
          console.warn('segment to remove was not found', seg);
          return;
        }
        l_segments.splice(pos, 1);
        return;
      }
      const seg_end = seg.addr + seg.size;
      const pos = l_segments.findIndex(
        e => e.addr <= seg.addr && seg_end <= e.addr + e.size,
      );
      if (pos === -1) {
        console.warn('no segment encloses the range to unmap', seg);
        return;
      }
      const existing = l_segments[pos];
      const existing_end = existing.addr + existing.size;
      if (existing.addr === seg.addr) {
        // Range is a prefix of the existing segment.
        existing.addr += seg.size;
        existing.size -= seg.size;
        if (existing.size === 0) {
          l_segments.splice(pos, 1);
        }
      } else if (existing_end === seg_end) {
        // Range is a suffix of the existing segment.
        existing.size -= seg.size;
      } else {
        // Range is in the middle: shrink the existing segment to the part
        // before it and reuse `seg` as the part after it.
        existing.size = seg.addr - existing.addr;
        seg.addr = seg_end;
        seg.size = existing_end - seg_end;
        l_segments.splice(pos + 1, 0, seg);
      }
    }

    const events = snapshot.device_traces[device];
    for (let i = events.length - 1; i > idx; i--) {
      const event = events[i];
      switch (event.action) {
        case 'free':
          // Undoing a free: the block was live and not pending.
          l_block_map[event.addr] = Block(
            event.addr,
            event.size,
            event.size,
            event.frames,
            false,
            event.version,
          );
          break;
        case 'free_requested': {
          // Undoing the request: replace the block with a non-pending copy
          // so the object shared with block_map is left untouched.
          const pending = l_block_map[event.addr];
          if (pending !== undefined) {
            l_block_map[event.addr] = {...pending, free_requested: false};
          }
          break;
        }
        case 'free_completed':
          // Undoing the completion: the block existed and was pending free.
          l_block_map[event.addr] = Block(
            event.addr,
            event.size,
            event.size,
            event.frames,
            true,
            event.version,
          );
          break;
        case 'alloc':
          delete l_block_map[event.addr];
          break;
        case 'segment_free':
        case 'segment_unmap':
          map_segment(
            event.action === 'segment_unmap',
            Segment(
              event.addr,
              event.size,
              event.stream,
              event.frames,
              event.version,
            ),
          );
          break;
        case 'segment_alloc':
        case 'segment_map':
          unmap_segment(
            event.action === 'segment_map',
            Segment(
              event.addr,
              event.size,
              event.stream,
              event.frames,
              event.version,
            ),
          );
          break;
        case 'oom':
          break;
        default:
          break;
      }
    }
    const new_blocks = Object.values(l_block_map);
    return [l_segments, new_blocks];
  }

  return {
    draw(idx) {
      const [segments_unsorted, blocks] = simulate_memory(idx);
      g.selectAll('g').remove();

      // Layer order: segment outlines, requested bytes, rounding waste.
      const segment_d = g.append('g');
      const block_g = g.append('g');
      const block_r = g.append('g');

      // Smallest first, so the last entry is the largest segment.
      const segments = [...segments_unsorted].sort((x, y) =>
        x.size === y.size ? x.addr - y.addr : x.size - y.size,
      );

      const segments_by_addr = [...segments].sort((x, y) => x.addr - y.addr);

      const max_size = segments.length === 0 ? 0 : segments.at(-1).size;

      const xScale = scaleLinear().domain([0, max_size]).range([0, 200]);
      const padding = xScale.invert(1);

      // Pack segments into rows no wider than the largest segment.
      let cur_row = 0;
      let cur_row_size = 0;
      for (const seg of segments) {
        seg.occupied = 0;
        seg.internal_free = 0;
        if (cur_row_size + seg.size > max_size) {
          cur_row_size = 0;
          cur_row += 1;
        }
        seg.offset = cur_row_size;
        seg.row = cur_row;
        cur_row_size += seg.size + padding;
      }

      const num_rows = cur_row + 1;

      const yScale = scaleLinear().domain([0, num_rows]).range([0, 100]);

      const segments_selection = segment_d
        .selectAll('rect')
        .data(segments)
        .enter()
        .append('rect')
        .attr('x', x => xScale(x.offset))
        .attr('y', x => yScale(x.row))
        .attr('width', x => xScale(x.size))
        .attr('height', yScale(4 / 5))
        .attr('stroke', 'black')
        .attr('stroke-width', '1')
        .attr('vector-effect', 'non-scaling-stroke')
        .attr('fill', 'white');

      stack_info.register(
        segments_selection,
        d => {
          addStroke(d);
          const t = d.datum();
          const free = t.size - t.occupied;
          let internal = '';
          if (t.internal_free > 0) {
            // One decimal keeps the tooltip free of long float tails.
            const percent = ((t.internal_free / free) * 100).toFixed(1);
            internal = ` (${percent}% internal)`;
          }
          return (
            `s${t.addr.toString(16)}_${t.version}: segment ${formatSize(
              t.size,
            )} allocated, ` +
            `${formatSize(free)} free${internal} (stream ${
              t.stream
            })\n${format_frames(t.frames)}`
          );
        },
        d => {
          d.attr('stroke', 'black')
            .attr('stroke-width', '1')
            .attr('vector-effect', 'non-scaling-stroke');
        },
      );

      // Binary search for the segment containing `addr`; null when none does.
      function find_segment(addr) {
        let left = 0;
        let right = segments_by_addr.length - 1;
        while (left <= right) {
          const mid = Math.floor((left + right) / 2);
          if (addr < segments_by_addr[mid].addr) {
            right = mid - 1;
          } else if (
            addr >=
            segments_by_addr[mid].addr + segments_by_addr[mid].size
          ) {
            left = mid + 1;
          } else {
            return segments_by_addr[mid];
          }
        }
        return null;
      }

      // Only blocks that fall inside a known segment can be positioned;
      // the rest are skipped instead of crashing the whole redraw.
      const placed_blocks = [];
      for (const b of blocks) {
        const owner = find_segment(b.addr);
        if (owner === null) {
          continue;
        }
        b.segment = owner;
        owner.occupied += b.requested_size;
        owner.internal_free += b.size - b.requested_size;
        placed_blocks.push(b);
      }

      const block_selection = block_g
        .selectAll('rect')
        .data(placed_blocks)
        .enter()
        .append('rect')
        .attr('x', x => xScale(x.segment.offset + (x.addr - x.segment.addr)))
        .attr('y', x => yScale(x.segment.row))
        .attr('width', x => xScale(x.requested_size))
        .attr('height', yScale(4 / 5))
        .attr('fill', (x, _i) =>
          x.free_requested
            ? 'red'
            : schemeTableau10[
                Math.abs(hashCode(x.addr)) % schemeTableau10.length
              ],
        );

      stack_info.register(
        block_selection,
        d => {
          addStroke(d);
          const t = d.datum();
          let requested = '';
          if (t.free_requested) {
            requested = ' (block freed but waiting due to record_stream)';
          }
          return (
            `b${t.addr.toString(16)}_${t.version} ` +
            `${formatSize(t.requested_size)} allocation${requested} (stream ${
              t.segment.stream
            })\n` +
            format_frames(t.frames)
          );
        },
        removeStroke,
      );

      // The tail of each block beyond its requested size is drawn in red.
      const free_selection = block_r
        .selectAll('rect')
        .data(placed_blocks)
        .enter()
        .append('rect')
        .attr('x', x =>
          xScale(
            x.segment.offset + (x.addr - x.segment.addr) + x.requested_size,
          ),
        )
        .attr('y', x => yScale(x.segment.row))
        .attr('width', x => xScale(x.size - x.requested_size))
        .attr('height', yScale(4 / 5))
        .attr('fill', (_x, _i) => 'red');

      stack_info.register(
        free_selection,
        d => {
          addStroke(d);
          const t = d.datum();
          return (
            `Free space lost due to rounding ${formatSize(
              t.size - t.requested_size,
            )}` +
            ` (stream ${t.segment.stream})\n` +
            format_frames(t.frames)
          );
        },
        removeStroke,
      );

      // Totals cover every simulated segment and block, placed or not.
      const reserved = segments.reduce((x, y) => x + y.size, 0);
      const allocated = blocks.reduce((x, y) => x + y.requested_size, 0);
      return [reserved, allocated];
    },
  };
}
544
/**
 * Creates the text panel at the bottom of the grid and the hover/selection
 * protocol around it.
 *
 * `register(dom, enter, leave, select)` attaches mouse handlers to `dom`:
 * hovering shows `enter(target)`, pressing makes the target the current
 * selection and calls `select(target)`, and leaving calls `leave(target)`
 * and restores the current selection's text. `highlight(enter, leave)`
 * replaces the current selection programmatically and shows it at once.
 */
function StackInfo(outer) {
  const panel_style = 'grid-column: 1 / 3; grid-row: 2; overflow: auto';
  const stack_trace = outer.append('pre').attr('style', panel_style);

  // What to show when nothing is hovered; starts out as an empty panel.
  let selected = {
    enter: () => {
      stack_trace.text('');
    },
    leave: () => {},
  };

  function register(dom, enter, leave = _e => {}, select = _e => {}) {
    const on_hover = _e => {
      selected.leave();
      const target = d3.select(d3.event.target);
      stack_trace.text(enter(target));
    };
    const on_press = _e => {
      const obj = d3.select(d3.event.target);
      selected = {
        enter: () => stack_trace.text(enter(obj)),
        leave: () => leave(obj),
      };
      select(obj);
    };
    const on_exit = _e => {
      const target = d3.select(d3.event.target);
      leave(target);
      selected.enter();
    };
    dom
      .on('mouseover', on_hover)
      .on('mousedown', on_press)
      .on('mouseleave', on_exit);
  }

  function highlight(enter, leave = () => {}) {
    selected = {enter: () => stack_trace.text(enter()), leave};
    selected.enter();
  }

  return {register, highlight};
}
581
/**
 * Assembles the segment view for one device inside `dst`: a two-column grid
 * holding the stack panel, the memory view and the event list. On the next
 * animation frame the last event is selected (or nothing, if the trace is
 * empty).
 */
function create_segment_view(dst, snapshot, device) {
  const grid_style =
    'display: grid; grid-template-columns: 1fr 2fr; grid-template-rows: 2fr 1fr; height: 100%; gap: 10px';
  const outer = dst.append('div').attr('style', grid_style);

  const events = snapshot.device_traces[device];
  // Construction order determines the order elements are appended to `outer`.
  const stack_info = StackInfo(outer);
  const memory_view = MemoryView(outer, stack_info, snapshot, device);
  const event_selector = EventSelector(outer, events, stack_info, memory_view);

  window.requestAnimationFrame(() => {
    const has_events = events.length > 0;
    const initial_idx = has_events ? events.length - 1 : null;
    event_selector.select(initial_idx);
  });
}
599
/**
 * Normalizes a snapshot in place so the views can consume it.
 *
 * - Adds `segment_version` / `block_version` counters and a `categories`
 *   list to the snapshot.
 * - Gives every trace event `frames`, a small sequential `stream` name, a
 *   `version` (for block and segment actions) and its `idx` in the trace.
 * - Collapses an adjacent `free_requested` + `free_completed` pair for the
 *   same address into a single `free` event.
 * - Sets every event's `stream` to null when all events used stream 0.
 * - Fills in `addr`, `frames`, `requested_size` and `version` on the blocks
 *   of each segment, and `stream` / `version` on the segment itself.
 * - Appends `'unknown'` to `categories` when any category was seen.
 *
 * @param snapshot object with `device_traces` (array of event arrays) and
 *   `segments` (each with `address`, `stream`, `blocks`).
 */
function annotate_snapshot(snapshot) {
  snapshot.segment_version = version_space();
  snapshot.block_version = version_space();
  snapshot.categories = [];
  const empty_list = [];
  let next_stream = 1;
  // Stream 0 keeps the name 0; other streams are numbered in first-seen order.
  const stream_names = {0: 0};
  function stream_name(s) {
    if (!(s in stream_names)) {
      stream_names[s] = next_stream++;
    }
    return stream_names[s];
  }
  const new_traces = [];
  for (const device_trace of snapshot.device_traces) {
    const new_trace = [];
    new_traces.push(new_trace);
    for (const t of device_trace) {
      if (!('frames' in t)) {
        t.frames = empty_list;
      }
      // set unique version for each time an address is used
      // so that ctrl-f can be used to search for the beginning
      // and end of allocations and segments
      t.stream = stream_name(t.stream);
      switch (t.action) {
        case 'free_completed':
          t.version = snapshot.block_version(t.addr, true);
          if (new_trace.length > 0) {
            // elide free_requested/free_completed into a single event
            const prev = new_trace.at(-1);
            if (prev.action === 'free_requested' && prev.addr === t.addr) {
              prev.action = 'free';
              // Drop this event: the previous one now represents both.
              continue;
            }
          }
          break;
        case 'free_requested':
        case 'alloc':
          t.version = snapshot.block_version(t.addr, false);
          break;
        case 'segment_free':
        case 'segment_unmap':
          t.version = snapshot.segment_version(t.addr, true);
          break;
        case 'segment_alloc':
        case 'segment_map':
          t.version = snapshot.segment_version(t.addr, false);
          break;
        default:
          break;
      }
      if ('category' in t && !snapshot.categories.includes(t.category)) {
        snapshot.categories.push(t.category);
      }
      t.idx = new_trace.length;
      new_trace.push(t);
    }
  }
  snapshot.device_traces = new_traces;
  // if every event was on the default stream, we elide stream printing
  if (next_stream === 1) {
    for (const device_trace of snapshot.device_traces) {
      for (const t of device_trace) {
        t.stream = null;
      }
    }
  }

  for (const seg of snapshot.segments) {
    seg.stream = stream_name(seg.stream);
    seg.version = snapshot.segment_version(seg.address, false);
    // Blocks are laid out back to back starting at the segment address.
    let addr = seg.address;
    for (const b of seg.blocks) {
      b.addr = addr;
      if (!('frames' in b)) {
        // legacy format where 'requested_size' may be missing
        // and frames might be in history rather than directly on block
        if ('history' in b) {
          b.frames = b.history[0].frames || empty_list;
          b.requested_size = b.requested_size || b.history[0].real_size;
        } else {
          b.frames = empty_list;
          b.requested_size = b.requested_size || b.size;
        }
      }
      b.version = snapshot.block_version(b.addr, false);
      addr += b.size;
    }
  }

  // Elements without a category are looked up as 'unknown', so make sure
  // that entry exists whenever categories are in use.
  if (
    snapshot.categories.length > 0 &&
    !snapshot.categories.includes('unknown')
  ) {
    snapshot.categories.push('unknown');
  }
}
698
/**
 * Collapses runs of identical consecutive entries. A run of one or two is
 * kept as is; a longer run becomes the entry followed by a single
 * `<repeats N times>` marker, where N is the run length minus one.
 */
function elideRepeats(frames) {
  const collapsed = [];
  const total = frames.length;
  let start = 0;
  while (start < total) {
    const frame = frames[start];
    // Advance `end` to one past the last entry identical to `frame`.
    let end = start + 1;
    while (end < total && frames[end] === frame) {
      end += 1;
    }
    const run_length = end - start;
    if (run_length === 1) {
      collapsed.push(frame);
    } else if (run_length === 2) {
      collapsed.push(frame, frame);
    } else {
      collapsed.push(frame, `<repeats ${run_length - 1} times>`);
    }
    start = end;
  }
  return collapsed;
}
/**
 * Decides whether a stack frame should be shown. Returns false when the
 * frame's `name` contains any of the omitted function fragments, or when its
 * `filename` contains any of the omitted path fragments; true otherwise.
 * Names are checked first, so `filename` is not inspected for a frame that
 * is already rejected by name.
 */
function frameFilter({name, filename}) {
  const omitFunctions = [
    'unwind::unwind',
    'CapturedTraceback::gather',
    'gather_with_cpp',
    '_start',
    '__libc_start_main',
    'PyEval_',
    'PyObject_',
    'PyFunction_',
  ];

  const omitFilenames = [
    'core/boxing',
    '/Register',
    '/Redispatch',
    'pythonrun.c',
    'Modules/main.c',
    'Objects/call.c',
    'Objects/methodobject.c',
    'pycore_ceval.h',
    'ceval.c',
    'cpython/abstract.h',
  ];

  const nameIsOmitted = omitFunctions.some(fragment => name.includes(fragment));
  if (nameIsOmitted) {
    return false;
  }
  const fileIsOmitted = omitFilenames.some(fragment =>
    filename.includes(fragment),
  );
  return !fileIsOmitted;
}
762
/**
 * Renders a list of frames as text, one `filename:line:name` entry per line.
 * Frames rejected by frameFilter are dropped and repeated lines are
 * collapsed with elideRepeats. An empty frame list yields an explanatory
 * message listing likely causes instead.
 */
function format_frames(frames) {
  if (frames.length === 0) {
    const explanation = [
      `This block has no frames. Potential causes:`,
      `1) This block was allocated before _record_memory_history was enabled.`,
      `2) The context or stacks passed to _record_memory_history does not include this block. Consider changing context to 'state', 'alloc', or 'all', or changing stacks to 'all'.`,
      `3) This event occurred during backward, which has no python frames, and memory history did not include C++ frames. Use stacks='all' to record both C++ and python frames.`,
    ];
    return explanation.join('\n');
  }

  const frame_strings = [];
  for (const f of frames) {
    if (frameFilter(f)) {
      frame_strings.push(`${f.filename}:${f.line}:${f.name}`);
    }
  }
  const collapsed = elideRepeats(frame_strings);
  return collapsed.join('\n');
}
777
/**
 * Converts one device's trace into the stacked "memory over time" data.
 *
 * Every alloc/free pair (or segment_alloc/segment_free when `plot_segments`
 * is set) becomes an element. The `max_entries` largest elements are drawn
 * individually; all others are accumulated into one "summarized" band.
 *
 * @param snapshot annotated snapshot (`device_traces`, `segments`,
 *   `categories`).
 * @param device device index into `device_traces` / filter for segments.
 * @param plot_segments plot segments instead of individual allocations.
 * @param max_entries how many of the largest elements to draw individually.
 * @returns object with `max_size`, `allocations_over_time`, `max_at_time`,
 *   `summarized_mem`, `elements_length` and `context_for_id(id)`, which
 *   builds the description text for element `id`.
 */
function process_alloc_data(snapshot, device, plot_segments, max_entries) {
  const elements = [];
  const initially_allocated = [];
  // Each entry is an element index; an index appears once when the element
  // is added and once more when it is removed.
  const actions = [];
  const addr_to_alloc = {};

  const alloc = plot_segments ? 'segment_alloc' : 'alloc';
  const [free, free_completed] = plot_segments
    ? ['segment_free', 'segment_free']
    : ['free', 'free_completed'];
  for (const e of snapshot.device_traces[device]) {
    switch (e.action) {
      case alloc:
        elements.push(e);
        addr_to_alloc[e.addr] = elements.length - 1;
        actions.push(elements.length - 1);
        break;
      case free:
      case free_completed:
        if (e.addr in addr_to_alloc) {
          actions.push(addr_to_alloc[e.addr]);
          delete addr_to_alloc[e.addr];
        } else {
          // Freed without a recorded alloc: it existed before the trace.
          elements.push(e);
          initially_allocated.push(elements.length - 1);
          actions.push(elements.length - 1);
        }
        break;
      default:
        break;
    }
  }
  // Anything still live in the snapshot without a recorded alloc was also
  // allocated before the trace started.
  for (const seg of snapshot.segments) {
    if (seg.device !== device) {
      continue;
    }
    if (plot_segments) {
      if (!(seg.address in addr_to_alloc)) {
        const element = {
          action: 'alloc',
          addr: seg.address,
          size: seg.total_size,
          frames: [],
          stream: seg.stream,
          version: seg.version,
        };
        elements.push(element);
        initially_allocated.push(elements.length - 1);
      }
    } else {
      for (const b of seg.blocks) {
        if (b.state === 'active_allocated' && !(b.addr in addr_to_alloc)) {
          const element = {
            action: 'alloc',
            addr: b.addr,
            size: b.requested_size,
            frames: b.frames,
            stream: seg.stream,
            version: b.version,
          };
          elements.push(element);
          initially_allocated.push(elements.length - 1);
        }
      }
    }
  }
  initially_allocated.reverse();
  // if there are no actions, the graph will be blank,
  // but if there are existing allocations we do not want to hide them
  // by having just one allocate action it will show a flat graph with all segments
  if (actions.length === 0 && initially_allocated.length > 0) {
    actions.push(initially_allocated.pop());
  }

  // `current` (element ids) and `current_data` (plot records) are parallel
  // stacks of the individually drawn elements that are currently live.
  const current = [];
  const current_data = [];
  const data = [];
  let max_size = 0;

  let total_mem = 0;
  let total_summarized_mem = 0;
  let timestep = 0;

  const max_at_time = [];

  const summarized_mem = {
    elem: 'summarized',
    timesteps: [],
    offsets: [total_mem],
    size: [],
    color: 0,
  };
  const summarized_elems = {};

  // Records the summarized band at the current timestep, then moves time
  // forward by `n` steps.
  function advance(n) {
    summarized_mem.timesteps.push(timestep);
    summarized_mem.offsets.push(total_mem);
    summarized_mem.size.push(total_summarized_mem);
    timestep += n;
    for (let i = 0; i < n; i++) {
      max_at_time.push(total_mem + total_summarized_mem);
    }
  }

  // Largest first; the first `max_entries` get their own polygon.
  const sizes = elements
    .map((x, i) => [x.size, i])
    .sort(([x, _xi], [y, _yi]) => y - x);

  const draw_elem = {};
  for (const [_s, e] of sizes.slice(0, max_entries)) {
    draw_elem[e] = true;
  }

  function add_allocation(elem) {
    const element_obj = elements[elem];
    const size = element_obj.size;
    current.push(elem);
    let color = elem;
    if (snapshot.categories.length > 0) {
      color = snapshot.categories.indexOf(element_obj.category || 'unknown');
    }
    const e = {
      elem,
      timesteps: [timestep],
      offsets: [total_mem],
      size,
      color,
    };
    current_data.push(e);
    data.push(e);
    total_mem += size;
    element_obj.max_allocated_mem = total_mem + total_summarized_mem;
  }

  for (const elem of initially_allocated) {
    if (elem in draw_elem) {
      add_allocation(elem);
    } else {
      total_summarized_mem += elements[elem].size;
      summarized_elems[elem] = true;
    }
  }

  for (const elem of actions) {
    const size = elements[elem].size;
    if (!(elem in draw_elem)) {
      if (elem in summarized_elems) {
        advance(1);
        total_summarized_mem -= size;
        summarized_elems[elem] = null;
      } else {
        total_summarized_mem += size;
        summarized_elems[elem] = true;
        advance(1);
      }
      continue;
    }
    const idx = current.lastIndexOf(elem);
    // first time we see an action we add it
    // second time we remove it
    if (idx === -1) {
      add_allocation(elem);
      advance(1);
    } else {
      advance(1);
      const removed = current_data[idx];
      removed.timesteps.push(timestep);
      removed.offsets.push(removed.offsets.at(-1));
      current.splice(idx, 1);
      current_data.splice(idx, 1);

      if (idx < current.length) {
        // Everything stacked above the removed element slides down by
        // `size` over the next three timesteps.
        for (let j = idx; j < current.length; j++) {
          const e = current_data[j];
          e.timesteps.push(timestep);
          e.offsets.push(e.offsets.at(-1));
          e.timesteps.push(timestep + 3);
          e.offsets.push(e.offsets.at(-1) - size);
        }
        advance(3);
      }
      total_mem -= size;
    }
    max_size = Math.max(total_mem + total_summarized_mem, max_size);
  }

  // Extend every still-live element to the final timestep.
  for (const elem of current_data) {
    elem.timesteps.push(timestep);
    elem.offsets.push(elem.offsets.at(-1));
  }
  data.push(summarized_mem);

  return {
    max_size,
    allocations_over_time: data,
    max_at_time,
    summarized_mem,
    elements_length: elements.length,
    context_for_id: id => {
      const elem = elements[id];
      let text = `Addr: ${formatAddr(elem)}`;
      text = `${text}, Size: ${formatSize(elem.size)} allocation`;
      text = `${text}, Total memory used after allocation: ${formatSize(
        elem.max_allocated_mem,
      )}`;
      if (elem.stream !== null) {
        text = `${text}, stream ${elem.stream}`;
      }
      // The time lives in `time_us` (microseconds). Only print it when it
      // is present; otherwise the Date would render as "Invalid Date".
      if (elem.time_us != null) {
        const d = new Date(elem.time_us / 1000);
        text = `${text}, timestamp ${d}`;
      }
      if (!elem.action.includes('alloc')) {
        text = `${text}\nalloc not recorded, stack trace for free:`;
      }
      text = `${text}\n${format_frames(elem.frames)}`;
      return text;
    },
  };
}
998
// Counter used to give every MemoryPlot its own clipPath id. Each
// snapshot/view/device selection keeps its own plot in the document, so a
// fixed id would be duplicated across plots.
let memory_plot_clip_count = 0;

/**
 * Draws the allocation polygons of `data` into `svg` and wires up zooming.
 *
 * @param svg d3 selection of the target <svg>.
 * @param data object providing `max_at_time` (its length is the x domain),
 *   `max_size` (upper bound of the y domain) and `allocations_over_time`
 *   (one entry per polygon with `timesteps`, `offsets`, `size`, `color`).
 * @param left_pad horizontal space reserved left of the plot for the y axis.
 * @param width total width of the drawing area (including `left_pad`).
 * @param height height of the drawing area.
 * @param colors palette indexed by `d.color % colors.length`.
 * @returns object with
 *   `select_window(stepbegin, stepend, max)` — show only the given timestep
 *   range, vertically zoomed so that `max` fills the plot; and
 *   `set_delegate(delegate)` — forward hover/click selection to `delegate`.
 */
function MemoryPlot(
  svg,
  data,
  left_pad,
  width,
  height,
  colors = schemeTableau10,
) {
  const max_timestep = data.max_at_time.length;
  const max_size = data.max_size;

  const plot_width = width - left_pad;
  const plot_height = height;

  const yscale = scaleLinear().domain([0, max_size]).range([plot_height, 0]);
  const yaxis = axisLeft(yscale).tickFormat(d3.format('.3s'));
  const xscale = scaleLinear().domain([0, max_timestep]).range([0, plot_width]);

  // Builds the SVG `points` string of one polygon: the bottom edge left to
  // right, then the top edge right to left. `size` may be a single number or
  // an array with one size per timestep.
  function format_points(d) {
    const size = d.size;
    const xs = d.timesteps.map(t => xscale(t));
    const bottom = d.offsets.map(t => yscale(t));
    const m = Array.isArray(size)
      ? (t, i) => yscale(t + size[i])
      : t => yscale(t + size);
    const top = d.offsets.map(m);
    const p0 = xs.map((x, i) => `${x},${bottom[i]}`);
    const p1 = xs.map((x, i) => `${x},${top[i]}`).reverse();
    return `${p0.join(' ')} ${p1.join(' ')}`;
  }

  const plot_coordinate_space = svg
    .append('g')
    .attr('transform', `translate(${left_pad}, ${0})`);
  const plot_outer = plot_coordinate_space.append('g');

  // White rectangle covering the plot area; used both as the background and
  // as the clip shape.
  function view_rect(a) {
    return a
      .append('rect')
      .attr('x', 0)
      .attr('y', 0)
      .attr('width', plot_width)
      .attr('height', plot_height)
      .attr('fill', 'white');
  }

  view_rect(plot_outer);

  const clip_id = `memory-plot-clip-${memory_plot_clip_count++}`;
  const cp = svg.append('clipPath').attr('id', clip_id);
  view_rect(cp);
  plot_outer.attr('clip-path', `url(#${clip_id})`);

  // zoom_group receives the interactive zoom transform, scrub_group the
  // window selected through select_window.
  const zoom_group = plot_outer.append('g');
  const scrub_group = zoom_group.append('g');

  const plot = scrub_group
    .selectAll('polygon')
    .data(data.allocations_over_time)
    .enter()
    .append('polygon')
    .attr('points', format_points)
    .attr('fill', d => colors[d.color % colors.length]);

  const axis = plot_coordinate_space.append('g').call(yaxis);

  function handleZoom() {
    const t = d3.event.transform;
    zoom_group.attr('transform', t);
    axis.call(yaxis.scale(d3.event.transform.rescaleY(yscale)));
  }

  const thezoom = zoom().on('zoom', handleZoom);
  plot_outer.call(thezoom);

  return {
    select_window: (stepbegin, stepend, max) => {
      const begin = xscale(stepbegin);
      const size = xscale(stepend) - begin;
      if (!(size > 0)) {
        // Empty (or inverted) window: scaling by plot_width / size would put
        // Infinity/NaN into the transform, so keep the current view instead.
        return;
      }
      const x_zoom = plot_width / size;
      // With nothing allocated in the window there is no sensible vertical
      // zoom; fall back to the unzoomed y axis instead of dividing by zero.
      const y_zoom = max > 0 ? max_size / max : 1;
      scrub_group.attr(
        'transform',
        `scale(${x_zoom / y_zoom}, 1) translate(${-begin}, 0)`,
      );
      plot_outer.call(
        thezoom.transform,
        zoomIdentity
          .scale(y_zoom)
          .translate(0, -(plot_height - plot_height / y_zoom)),
      );
    },
    set_delegate: delegate => {
      plot
        .on('mouseover', function (_e, _d) {
          delegate.set_selected(d3.select(this));
        })
        .on('mousedown', function (_e, _d) {
          delegate.default_selected = d3.select(this);
        })
        .on('mouseleave', function (_e, _d) {
          delegate.set_selected(delegate.default_selected);
        });
    },
  };
}
1104
/**
 * Creates the delegate that shows details about the hovered polygon.
 *
 * @param text selection whose content is replaced with the description.
 * @param data object providing `context_for_id(id)`, used to describe a
 *   plotted element.
 * @returns delegate with a `default_selected` slot (initially null) and
 *   `set_selected(d)`, where `d` is a selection of the polygon to highlight
 *   or null to clear the description.
 */
function ContextViewer(text, data) {
  // Selection highlighted by the previous set_selected call, if any.
  let current_selected = null;

  const SUMMARIZED_HELP =
    'Small tensors that were not plotted to cutdown on render time.\n' +
    'Use detail slider to see smaller allocations.';

  return {
    default_selected: null,
    set_selected: d => {
      // Remove the outline from whatever was highlighted before.
      if (current_selected !== null) {
        current_selected.attr('stroke', null).attr('stroke-width', null);
      }
      if (d !== null) {
        const {elem} = d.datum();
        if (elem === 'summarized') {
          text.html(SUMMARIZED_HELP);
        } else {
          text.text(`${elem} ${data.context_for_id(elem)}`);
        }
        // Outline the newly selected polygon; the stroke keeps its width
        // regardless of the zoom transform.
        d.attr('stroke', 'black')
          .attr('stroke-width', 1)
          .attr('vector-effect', 'non-scaling-stroke');
      } else {
        text.text('');
      }
      current_selected = d;
    },
  };
}
1134
/**
 * Draws the overview of peak memory per timestep and attaches a horizontal
 * brush that selects the window shown by `plot`.
 *
 * @param mini_svg d3 selection of the minimap <svg>.
 * @param plot object with `select_window(stepbegin, stepend, max)`.
 * @param data object providing `max_at_time` (one value per timestep) and
 *   `max_size` (upper bound of the y domain).
 * @param left_pad horizontal offset of the drawing area.
 * @param width total width of the drawing area (including `left_pad`).
 * @param height height used for the outline and the brush extent.
 * @returns an empty object.
 */
function MiniMap(mini_svg, plot, data, left_pad, width, height = 70) {
  const max_at_time = data.max_at_time;
  const step_count = max_at_time.length;
  const plot_width = width - left_pad;
  const yscale = scaleLinear().domain([0, data.max_size]).range([height, 0]);
  const minixscale = scaleLinear()
    .domain([0, step_count])
    .range([left_pad, width]);

  // Step-shaped outline: start on the baseline and add a vertical edge
  // whenever the value changes.
  const mini_points = [
    [step_count, 0],
    [0, 0],
  ];
  for (const [i, m] of max_at_time.entries()) {
    const lasty = mini_points[mini_points.length - 1][1];
    if (m !== lasty) {
      mini_points.push([i, lasty]);
      mini_points.push([i, m]);
    } else if (i === step_count - 1) {
      mini_points.push([i, m]);
    }
  }

  const points = mini_points
    .map(([t, o]) => `${minixscale(t)}, ${yscale(o)}`)
    .join(' ');
  mini_svg
    .append('polygon')
    .attr('points', points)
    .attr('fill', schemeTableau10[0]);

  // Maps brush x positions (relative to left_pad) back to timesteps.
  const xscale = scaleLinear().domain([0, step_count]).range([0, plot_width]);

  const brush = brushX();
  brush.extent([
    [left_pad, 0],
    [width, height],
  ]);
  brush.on('brush', function () {
    const selection = d3.event.selection;
    if (!selection || step_count === 0) {
      return;
    }
    const [begin, end] = selection.map(x => x - left_pad);

    // Clamp to valid timesteps and always cover at least one step. A brush
    // narrower than one step would otherwise produce an empty window, which
    // select_window cannot scale.
    const stepbegin = Math.min(
      Math.max(Math.floor(xscale.invert(begin)), 0),
      step_count - 1,
    );
    const stepend = Math.min(
      Math.max(Math.floor(xscale.invert(end)), stepbegin + 1),
      step_count,
    );

    let max = 0;
    for (let i = stepbegin; i < stepend; i++) {
      max = Math.max(max, max_at_time[i]);
    }
    // A window without any allocation has no peak to zoom to; use the global
    // maximum so the plot keeps its unzoomed vertical scale.
    const window_max = max > 0 ? max : data.max_size;
    if (!(window_max > 0)) {
      return;
    }
    plot.select_window(stepbegin, stepend, window_max);
  });
  mini_svg.call(brush);
  return {};
}
1188
/**
 * Adds a legend to `plot_svg`: one coloured square and one text label per
 * entry of `categories`, stacked vertically. Colours are taken from
 * schemeTableau10 by category index.
 *
 * @returns an empty object.
 */
function Legend(plot_svg, categories) {
  const swatch_x = 100;
  const top = 5;
  // Vertical position of the i-th legend row.
  const row_y = i => top + i * 15;

  const swatches = plot_svg
    .append('g')
    .selectAll('rect')
    .data(categories)
    .enter()
    .append('rect');
  swatches
    .attr('x', () => swatch_x)
    .attr('y', (_category, i) => row_y(i))
    .attr('width', 10)
    .attr('height', 10)
    .attr('fill', (_category, i) => schemeTableau10[i % schemeTableau10.length]);

  const labels = plot_svg
    .append('g')
    .selectAll('text')
    .data(categories)
    .enter()
    .append('text');
  labels
    .attr('x', () => swatch_x + 20)
    .attr('y', (_category, i) => row_y(i) + 8)
    .attr('font-family', 'helvetica')
    .attr('font-size', 10)
    .text(category => category);

  return {};
}
1216
/**
 * Builds the timeline view inside `dst`: a detail slider, the main memory
 * plot (with a legend when the snapshot has categories), the minimap and the
 * context pane. Existing <svg> and <div> children of `dst` are removed first.
 *
 * @param dst d3 selection of the container.
 * @param snapshot snapshot object; `snapshot.categories` is read here, the
 *   rest is handed to process_alloc_data.
 * @param device device index forwarded to process_alloc_data.
 * @param plot_segments forwarded to process_alloc_data.
 * @param max_entries forwarded to process_alloc_data; also the initial value
 *   of the detail slider. Moving the slider rebuilds the view with the new
 *   value.
 */
function create_trace_view(
  dst,
  snapshot,
  device,
  plot_segments = false,
  max_entries = 15000,
) {
  const left_pad = 70;
  const view_width = 1024;
  const plot_height = 576;
  // The minimap must be drawn with the same height as its viewBox, otherwise
  // the bottom of the outline (its baseline) falls outside the visible area.
  const mini_height = 60;

  const data = process_alloc_data(snapshot, device, plot_segments, max_entries);
  dst.selectAll('svg').remove();
  dst.selectAll('div').remove();

  const d = dst.append('div');
  d.append('input')
    .attr('type', 'range')
    .attr('min', 0)
    .attr('max', data.elements_length)
    .attr('value', max_entries)
    .on('change', function () {
      // An <input>'s value is a string; pass a number like the default does.
      create_trace_view(
        dst,
        snapshot,
        device,
        plot_segments,
        Number(this.value),
      );
    });
  d.append('label').text('Detail');

  const grid_container = dst
    .append('div')
    .attr(
      'style',
      'display: grid; grid-template-columns: 1fr; grid-template-rows: 10fr 1fr 8fr; height: 100%; gap: 10px',
    );

  const plot_svg = grid_container
    .append('svg')
    .attr('display', 'block')
    .attr('viewBox', `0 0 ${view_width} ${plot_height}`)
    .attr('preserveAspectRatio', 'none')
    .attr('style', 'grid-column: 1; grid-row: 1; width: 100%; height: 100%;');

  const plot = MemoryPlot(plot_svg, data, left_pad, view_width, plot_height);

  if (snapshot.categories.length !== 0) {
    Legend(plot_svg.append('g'), snapshot.categories);
  }

  const mini_svg = grid_container
    .append('svg')
    .attr('display', 'block')
    .attr('viewBox', `0 0 ${view_width} ${mini_height}`)
    .attr('preserveAspectRatio', 'none')
    .attr('style', 'grid-column: 1; grid-row: 2; width: 100%; height: 100%;');

  MiniMap(mini_svg, plot, data, left_pad, view_width, mini_height);

  const context_div = grid_container
    .append('div')
    .attr(
      'style',
      'grid-column: 1; grid-row: 3; width: 100%; height: 100%; overflow: auto;',
    );
  const delegate = ContextViewer(context_div.append('pre').text('none'), data);
  plot.set_delegate(delegate);
}
1277
/**
 * Fills `dst` with the allocator settings stored in the snapshot, rendered
 * as indented JSON, or with a short notice when the snapshot has no
 * `allocator_settings` key. Existing <svg> and <div> children of `dst` are
 * removed first. `device` is not used.
 */
function create_settings_view(dst, snapshot, device) {
  for (const stale of ['svg', 'div']) {
    dst.selectAll(stale).remove();
  }

  const settings_div = dst.append('div');
  settings_div.append('p').text('CUDA Caching Allocator Settings:');

  // Snapshots without the key get a plain notice instead of the dump.
  if (!('allocator_settings' in snapshot)) {
    settings_div.append('p').text('No allocator settings found.');
    return;
  }

  const dump = settings_div.append('pre');
  dump.text(JSON.stringify(snapshot.allocator_settings, null, 2));
}
1293
/**
 * Decodes a Python pickle (protocol 2 to 4) restricted to the opcodes listed
 * below into plain JavaScript values.
 *
 * Lists and tuples become arrays, dicts become plain objects, None becomes
 * null. Integers stored with LONG1 become Numbers when they fit in 8 bytes
 * and BigInts otherwise.
 *
 * @param buffer ArrayBuffer holding the pickle, or any typed-array/DataView
 *   view onto one (only the bytes covered by the view are read).
 * @returns the value on top of the stack when STOP is reached.
 * @throws Error on an unsupported protocol version, an unknown opcode, or
 *   when the data ends before a STOP opcode.
 */
function unpickle(buffer) {
  // Always work on a byte view of exactly the caller's bytes. Strings are
  // decoded from this view, so views with a non-zero byteOffset are handled
  // correctly.
  const bytebuffer = ArrayBuffer.isView(buffer)
    ? new Uint8Array(buffer.buffer, buffer.byteOffset, buffer.byteLength)
    : new Uint8Array(buffer);
  const decoder = new TextDecoder();

  const stack = [];
  const marks = [];
  const memo = [];
  let offset = 0;
  let memo_id = 0;

  const APPENDS = 'e'.charCodeAt(0);
  const BINGET = 'h'.charCodeAt(0);
  const BININT = 'J'.charCodeAt(0);
  const BININT1 = 'K'.charCodeAt(0);
  const BININT2 = 'M'.charCodeAt(0);
  const EMPTY_DICT = '}'.charCodeAt(0);
  const EMPTY_LIST = ']'.charCodeAt(0);
  const FRAME = 0x95;
  const LONG1 = 0x8a;
  const LONG_BINGET = 'j'.charCodeAt(0);
  const MARK = '('.charCodeAt(0);
  const MEMOIZE = 0x94;
  const PROTO = 0x80;
  const SETITEMS = 'u'.charCodeAt(0);
  const SHORT_BINUNICODE = 0x8c;
  const STOP = '.'.charCodeAt(0);
  const TUPLE2 = 0x86;
  const APPEND = 'a'.charCodeAt(0);
  const NEWFALSE = 0x89;
  const BINPUT = 'q'.charCodeAt(0);
  const BINUNICODE = 'X'.charCodeAt(0);
  const EMPTY_TUPLE = ')'.charCodeAt(0);
  const NEWTRUE = 0x88;
  const NONE = 'N'.charCodeAt(0);
  const BINFLOAT = 'G'.charCodeAt(0);
  const TUPLE = 't'.charCodeAt(0);
  const TUPLE1 = 0x85;
  const TUPLE3 = 0x87;
  // untested
  const LONG_BINPUT = 'r'.charCodeAt(0);
  const LIST = 'l'.charCodeAt(0);
  const DICT = 'd'.charCodeAt(0);
  const SETITEM = 's'.charCodeAt(0);

  // 8 scratch bytes reinterpreted as a signed 64-bit integer or a double.
  // Typed arrays use the platform byte order; the byte placement below is
  // correct on little-endian hosts.
  const scratch_buffer = new ArrayBuffer(8);
  const scratch_bytes = new Uint8Array(scratch_buffer);
  const big = new BigInt64Array(scratch_buffer);
  const float64 = new Float64Array(scratch_buffer);

  // Reads a little-endian unsigned 32-bit integer and advances the cursor.
  function read_uint4() {
    const n =
      bytebuffer[offset] +
      bytebuffer[offset + 1] * 256 +
      bytebuffer[offset + 2] * 65536 +
      bytebuffer[offset + 3] * 16777216;
    offset += 4;
    return n;
  }

  // Decodes the next `n` bytes as UTF-8 and advances the cursor.
  function read_string(n) {
    const s = decoder.decode(bytebuffer.subarray(offset, offset + n));
    offset += n;
    return s;
  }

  // Reads an `s`-byte little-endian two's-complement integer.
  function read_long(s) {
    if (s <= 8) {
      for (let i = 0; i < s; i++) {
        scratch_bytes[i] = bytebuffer[offset++];
      }
      // Sign-extend to 8 bytes; a zero-length integer is 0.
      const fill = s > 0 && scratch_bytes[s - 1] >= 128 ? 0xff : 0x0;
      for (let i = s; i < 8; i++) {
        scratch_bytes[i] = fill;
      }
      return Number(big[0]);
    }

    // BigInt
    const scratch_bytes_unbounded = [];
    for (let i = 0; i < s; i++) {
      scratch_bytes_unbounded.push(bytebuffer[offset++]);
    }

    // BigInt can only convert from unsigned hex, thus we need to
    // convert from twos-complement if negative
    const negative = scratch_bytes_unbounded[s - 1] >= 128;
    if (negative) {
      // implements scratch_bytes_unbounded = ~scratch_bytes_unbounded + 1
      // byte-by-byte.
      let carry = 1;
      for (let i = 0; i < s; i++) {
        const twos_complement = (0xff ^ scratch_bytes_unbounded[i]) + carry;
        carry = twos_complement > 0xff ? 1 : 0;
        scratch_bytes_unbounded[i] = 0xff & twos_complement;
      }
    }

    const hex_str = Array.from(scratch_bytes_unbounded.reverse(), byte => {
      return byte.toString(16).padStart(2, '0');
    }).join('');

    return negative ? -BigInt(`0x${hex_str}`) : BigInt(`0x${hex_str}`);
  }

  // Stores the key/value pairs found on the stack from `mark` upwards into
  // `d` and removes them from the stack.
  function setitems(d, mark) {
    for (let i = mark; i < stack.length; i += 2) {
      d[stack[i]] = stack[i + 1];
    }
    stack.splice(mark, Infinity);
  }

  while (true) {
    if (offset >= bytebuffer.length) {
      throw new Error('Unexpected end of pickle data: no STOP opcode found');
    }
    const opcode = bytebuffer[offset++];
    switch (opcode) {
      case PROTO:
        {
          const version = bytebuffer[offset++];
          if (version < 2 || version > 4) {
            throw new Error(`Unhandled version ${version}`);
          }
        }
        break;
      case APPEND:
        {
          const v = stack.pop();
          stack.at(-1).push(v);
        }
        break;
      case APPENDS:
        {
          const mark = marks.pop();
          const arr = stack[mark - 1];
          arr.push(...stack.splice(mark, Infinity));
        }
        break;
      case LIST:
      case TUPLE:
        {
          const mark = marks.pop();
          stack.push([...stack.splice(mark, Infinity)]);
        }
        break;
      case NEWFALSE:
        stack.push(false);
        break;
      case NEWTRUE:
        stack.push(true);
        break;
      case NONE:
        stack.push(null);
        break;
      case BINGET:
        stack.push(memo[bytebuffer[offset++]]);
        break;
      case BININT:
        {
          // Signed 32-bit value: fold the upper half into negatives.
          let i32 = read_uint4();
          if (i32 > 0x7fffffff) {
            i32 -= 0x100000000;
          }
          stack.push(i32);
        }
        break;
      case BININT1:
        stack.push(bytebuffer[offset++]);
        break;
      case BININT2:
        {
          const v = bytebuffer[offset] + bytebuffer[offset + 1] * 256;
          stack.push(v);
          offset += 2;
        }
        break;
      case EMPTY_DICT:
        stack.push({});
        break;
      case EMPTY_LIST:
        stack.push([]);
        break;
      case FRAME:
        // The 8-byte frame length is not needed; skip it.
        offset += 8;
        break;
      case LONG1:
        {
          const s = bytebuffer[offset++];
          stack.push(read_long(s));
        }
        break;
      case LONG_BINGET:
        {
          const idx = read_uint4();
          stack.push(memo[idx]);
        }
        break;
      case MARK:
        marks.push(stack.length);
        break;
      case MEMOIZE:
        memo[memo_id++] = stack.at(-1);
        break;
      case BINPUT:
        memo[bytebuffer[offset++]] = stack.at(-1);
        break;
      case LONG_BINPUT:
        memo[read_uint4()] = stack.at(-1);
        break;
      case SETITEMS:
        {
          const mark = marks.pop();
          const d = stack[mark - 1];
          setitems(d, mark);
        }
        break;
      case SETITEM: {
        const v = stack.pop();
        const k = stack.pop();
        stack.at(-1)[k] = v;
        break;
      }
      case DICT:
        {
          const mark = marks.pop();
          const d = {};
          setitems(d, mark);
          stack.push(d);
        }
        break;
      case SHORT_BINUNICODE:
        {
          const n = bytebuffer[offset++];
          stack.push(read_string(n));
        }
        break;
      case BINUNICODE:
        {
          const n = read_uint4();
          stack.push(read_string(n));
        }
        break;
      case STOP:
        return stack.pop();
      case EMPTY_TUPLE:
        stack.push([]);
        break;
      case TUPLE1:
        stack.push([stack.pop()]);
        break;
      case TUPLE2:
        stack.push(stack.splice(-2, Infinity));
        break;
      case TUPLE3:
        stack.push(stack.splice(-3, Infinity));
        break;
      case BINFLOAT:
        for (let i = 7; i >= 0; i--) {
          // stored in big-endian order
          scratch_bytes[i] = bytebuffer[offset++];
        }
        stack.push(float64[0]);
        break;
      default:
        throw new Error(`UNKNOWN OPCODE: ${opcode}`);
    }
  }
}
1550
/**
 * Decodes a standard-alphabet base64 string into an ArrayBuffer.
 *
 * Trailing `=` padding is honoured: the returned buffer contains exactly the
 * encoded bytes, without the zero bytes the padding characters would
 * otherwise decode to. Characters outside the base64 alphabet decode as 0.
 *
 * @param input base64 text.
 * @returns ArrayBuffer with the decoded bytes.
 */
function decode_base64(input) {
  // Returns the 6-bit value of the character at `i`, shifted into position.
  // Indices past the end of the input yield 0.
  function decode_char(i, shift = 0) {
    const nChr = input.charCodeAt(i);
    let r = 0;
    if (nChr > 64 && nChr < 91) {
      r = nChr - 65; // 'A'-'Z' -> 0-25
    } else if (nChr > 96 && nChr < 123) {
      r = nChr - 71; // 'a'-'z' -> 26-51
    } else if (nChr > 47 && nChr < 58) {
      r = nChr + 4; // '0'-'9' -> 52-61
    } else if (nChr === 43) {
      r = 62; // '+'
    } else if (nChr === 47) {
      r = 63; // '/'
    }
    return r << shift;
  }

  let padding = 0;
  if (input.endsWith('==')) {
    padding = 2;
  } else if (input.endsWith('=')) {
    padding = 1;
  }
  const byte_count = Math.max(
    0,
    Math.floor((input.length * 3) / 4) - padding,
  );
  const output = new Uint8Array(byte_count);

  // Each group of 4 characters carries 24 bits = 3 bytes. Writes past the end
  // of `output` (the bytes that correspond to padding) are ignored by the
  // typed array.
  for (let i = 0, j = 0; i < input.length; i += 4, j += 3) {
    const u24 =
      decode_char(i, 18) +
      decode_char(i + 1, 12) +
      decode_char(i + 2, 6) +
      decode_char(i + 3);
    output[j] = u24 >> 16;
    output[j + 1] = (u24 >> 8) & 0xff;
    output[j + 2] = u24 & 0xff;
  }
  return output.buffer;
}
1581
// Maps the label shown in the view selector to the function that builds that
// view. Every builder is invoked as (dst, snapshot, device). The key order is
// the order of the options in the selector.
const kinds = {
  'Active Memory Timeline': create_trace_view,
  'Allocator State History': create_segment_view,
  // Same timeline view, with the plot_segments flag switched on.
  'Active Cached Segment Timeline'(dst, snapshot, device) {
    return create_trace_view(dst, snapshot, device, true);
  },
  'Allocator Settings': create_settings_view,
};
1589
// Module-level state, keyed by the unique snapshot name shown in the
// snapshot selector unless noted otherwise.
// name -> unpickled and annotated snapshot (filled by finished_loading)
const snapshot_cache = {};
// name -> function that starts loading the snapshot (filled by add_snapshot)
const snapshot_to_loader = {};
// not referenced by the code in this part of the file
const snapshot_to_url = {};
// "name,view,device" -> container div of an already built view
// (filled by snapshot_change); '' holds the initial hint text
const selection_to_div = {};

// Page-wide CSS: no margins around <pre>, and a full-height, clipped page.
const style = `
pre {
  margin: 0px;
}
html, body {
  height: 100%;
  overflow: clip;
}`;

const head = d3.select('head');
head.append('style').text(style);

// The three selectors at the top of the page: snapshot, view kind, device.
const body = d3.select('body');
const snapshot_select = body.append('select');
const view = body.append('select');
Object.keys(kinds).forEach(label => {
  view.append('option').text(label);
});
const gpu = body.append('select');
1613
/**
 * Decodes a pickled snapshot, logs the decoded value to the console, runs
 * annotate_snapshot on it and returns it.
 */
function unpickle_and_annotate(data) {
  const snapshot = unpickle(data);
  console.log(snapshot);
  annotate_snapshot(snapshot);
  return snapshot;
}
1620
/**
 * Shows the view for snapshot `f` using the currently selected view kind and
 * device.
 *
 * The device selector is rebuilt to list the devices that have a non-empty
 * trace or at least one segment in this snapshot. The previously selected
 * device is kept when it is still listed; otherwise the selector's current
 * value is used. The view for (snapshot, kind, device) is built once and
 * cached in selection_to_div; later calls only make it visible again.
 *
 * @param f unique snapshot name; must already be present in snapshot_cache.
 */
function snapshot_change(f) {
  const view_value = view.node().value;
  let device = Number(gpu.node().value);
  const snapshot = snapshot_cache[f];

  gpu.selectAll('option').remove();
  const has_segments = {};
  for (const s of snapshot.segments) {
    has_segments[s.device] = true;
  }

  let device_valid = false;
  for (const [i, trace] of snapshot.device_traces.entries()) {
    if (trace.length > 0 || i in has_segments) {
      gpu.append('option').text(i);
      if (i === device) {
        device_valid = true;
        gpu.node().selectedIndex = gpu.node().children.length - 1;
      }
    }
  }
  if (!device_valid) {
    device = Number(gpu.node().value);
  }

  // Same string the array would be coerced to when used as a property key.
  const key = [f, view_value, device].join(',');
  if (!(key in selection_to_div)) {
    selection_to_div[key] = d3.select('body').append('div');
    kinds[view_value](selection_to_div[key], snapshot, device);
  }
  const selected_div = selection_to_div[key];

  // 'display: float' is not a valid value and was dropped by the browser,
  // leaving the div at its default block display; say so explicitly.
  selected_div.attr('style', 'display: block; height: 100%');
}
1652
/**
 * Reacts to a change of any of the selectors: hides every view built so far,
 * then shows the selected snapshot, starting its loader first when it has
 * not been loaded yet. Does nothing more when no snapshot is selected.
 */
function selected_change() {
  Object.values(selection_to_div).forEach(div => {
    div.attr('style', 'display: none; height: 100%');
  });

  const f = snapshot_select.node().value;
  if (f === '') {
    return;
  }

  if (f in snapshot_cache) {
    snapshot_change(f);
  } else {
    // Not loaded yet: run the loader registered for this snapshot.
    snapshot_to_loader[f](f);
  }
}
1667
// Any change of snapshot, view kind or device re-evaluates what is shown.
snapshot_select.on('change', selected_change);
view.on('change', selected_change);
gpu.on('change', selected_change);

// Dropping is only possible when dragover cancels the default action. The
// current event is read from d3.event, like the other d3 listeners in this
// file, rather than from the deprecated global window.event.
body.on('dragover', () => {
  d3.event.preventDefault();
});

body.on('drop', () => {
  const drop_event = d3.event;
  // Cancel the default action first, so it stays cancelled even if
  // registering one of the files throws.
  drop_event.preventDefault();

  console.log(drop_event.dataTransfer.files);
  const files = Array.from(drop_event.dataTransfer.files);
  if (files.length === 0) {
    // Nothing loadable was dropped (e.g. dragged text); keep the selection.
    return;
  }
  files.forEach(file => {
    add_snapshot(file.name, unique_name => {
      const reader = new FileReader();
      reader.onload = e => {
        finished_loading(unique_name, e.target.result);
      };
      reader.onerror = () => {
        console.error(`Could not read '${unique_name}'`, reader.error);
      };
      reader.readAsArrayBuffer(file);
    });
  });

  // Show the file that was added last.
  snapshot_select.node().selectedIndex =
    snapshot_select.node().options.length - 1;
  selected_change();
});

// Hint shown until a snapshot is selected; stored under the empty key so
// selected_change hides it together with the views.
selection_to_div[''] = body
  .append('div')
  .text(
    'Drag and drop a file to load a local snapshot. No data from the snapshot is uploaded.',
  );
1698
// Suffix counter used to make duplicate snapshot names unique.
let next_unique_n = 1;

/**
 * Registers a snapshot under a unique name: adds an option to the snapshot
 * selector and remembers `loader` for it.
 *
 * When `name` is already registered, ` (n)` is appended, trying successive
 * values of n until the result is unused, so an existing entry is never
 * overwritten.
 *
 * @param name display name, e.g. the file name.
 * @param loader function called with the unique name to start loading.
 */
function add_snapshot(name, loader) {
  let unique_name = name;
  while (unique_name in snapshot_to_loader) {
    unique_name = `${name} (${next_unique_n++})`;
  }
  snapshot_select.append('option').text(unique_name);
  snapshot_to_loader[unique_name] = loader;
}
1707
/**
 * Called with the raw bytes of a snapshot once they are available: decodes
 * them, caches the result under `name` and shows the snapshot.
 */
function finished_loading(name, data) {
  const snapshot = unpickle_and_annotate(data);
  snapshot_cache[name] = snapshot;
  snapshot_change(name);
}
1712
/**
 * Registers snapshots that are fetched over the network on first selection.
 *
 * @param files array of `{name, url}` objects.
 *
 * Each loader fetches the URL and hands the bytes to finished_loading. A
 * non-2xx response or any failure while fetching or decoding is reported
 * with console.error instead of surfacing as an unhandled promise rejection.
 * The loader returns the promise of that work. If at least one file was
 * given, the current selection is re-evaluated afterwards.
 */
export function add_remote_files(files) {
  files.forEach(f =>
    add_snapshot(f.name, unique_name => {
      console.log('fetching', f.url);
      return fetch(f.url)
        .then(response => {
          // Do not try to unpickle an error page.
          if (!response.ok) {
            throw new Error(
              `Failed to fetch ${f.url}: HTTP ${response.status}`,
            );
          }
          return response.arrayBuffer();
        })
        .then(data => finished_loading(unique_name, data))
        .catch(err => {
          console.error(`Could not load snapshot '${unique_name}'`, err);
        });
    }),
  );
  if (files.length > 0) {
    selected_change();
  }
}
1726
/**
 * Registers snapshots whose bytes are embedded as base64 text and selects
 * the given view kind.
 *
 * @param files array of `{name, base64}` objects.
 * @param view_value label of the view kind to select in the view selector.
 *
 * Each snapshot is decoded only when its loader runs. If at least one file
 * was given, the current selection is re-evaluated afterwards.
 */
export function add_local_files(files, view_value) {
  view.node().value = view_value;

  files.forEach(f => {
    const load_embedded = unique_name => {
      finished_loading(unique_name, decode_base64(f.base64));
    };
    add_snapshot(f.name, load_embedded);
  });

  if (files.length > 0) {
    selected_change();
  }
}
1738