xref: /aosp_15_r20/external/pytorch/torch/utils/tensorboard/writer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""Provide an API for writing protocol buffers to event files to be consumed by TensorBoard for visualization."""
3
4import os
5import time
6from typing import List, Optional, TYPE_CHECKING, Union
7
8import torch
9
10if TYPE_CHECKING:
11    from matplotlib.figure import Figure
12from tensorboard.compat import tf
13from tensorboard.compat.proto import event_pb2
14from tensorboard.compat.proto.event_pb2 import Event, SessionLog
15from tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig
16from tensorboard.summary.writer.event_file_writer import EventFileWriter
17
18from ._convert_np import make_np
19from ._embedding import get_embedding_info, make_mat, make_sprite, make_tsv, write_pbtxt
20from ._onnx_graph import load_onnx_graph
21from ._pytorch_graph import graph
22from ._utils import figure_to_image
23from .summary import (
24    audio,
25    custom_scalars,
26    histogram,
27    histogram_raw,
28    hparams,
29    image,
30    image_boxes,
31    mesh,
32    pr_curve,
33    pr_curve_raw,
34    scalar,
35    tensor_proto,
36    text,
37    video,
38)
39
# Public names exported by ``from torch.utils.tensorboard.writer import *``.
__all__ = [
    "FileWriter",
    "SummaryWriter",
]
41
42
class FileWriter:
    """Writes protocol buffers to event files to be consumed by TensorBoard.

    The `FileWriter` class provides a mechanism to create an event file in a
    given directory and add summaries and events to it. The class updates the
    file contents asynchronously. This allows a training program to call methods
    to add data to the file directly from the training loop, without slowing down
    training.
    """

    def __init__(self, log_dir, max_queue=10, flush_secs=120, filename_suffix=""):
        """Create a `FileWriter` and an event file.

        On construction the writer creates a new event file in `log_dir`.
        The other arguments to the constructor control the asynchronous writes to
        the event file.

        Args:
          log_dir: A string (or any object, e.g. a ``pathlib.Path``, whose
            ``str()`` is the path). Directory where event file will be written.
          max_queue: Integer. Size of the queue for pending events and
            summaries before one of the 'add' calls forces a flush to disk.
            Default is ten items.
          flush_secs: Number. How often, in seconds, to flush the
            pending events and summaries to disk. Default is every two minutes.
          filename_suffix: A string. Suffix added to all event filenames
            in the log_dir directory. More details on filename construction in
            tensorboard.summary.writer.event_file_writer.EventFileWriter.
        """
        # Sometimes PosixPath is passed in and we need to coerce it to
        # a string in all cases
        # TODO: See if we can remove this in the future if we are
        # actually the ones passing in a PosixPath
        log_dir = str(log_dir)
        self.event_writer = EventFileWriter(
            log_dir, max_queue, flush_secs, filename_suffix
        )

    def get_logdir(self):
        """Return the directory where event file will be written."""
        return self.event_writer.get_logdir()

    def add_event(self, event, step=None, walltime=None):
        """Add an event to the event file.

        The event's ``wall_time`` is always overwritten (with ``walltime`` or
        the current time) and, when ``step`` is given, its ``step`` is set.

        Args:
          event: An `Event` protocol buffer.
          step: Number. Optional global step value for training process
            to record with the event. It is coerced with ``int()``.
          walltime: float. Optional walltime to override the default (current)
            walltime (from time.time()) seconds after epoch
        """
        event.wall_time = time.time() if walltime is None else walltime
        if step is not None:
            # Make sure step is converted from numpy or other formats
            # since protobuf might not convert depending on version
            event.step = int(step)
        self.event_writer.add_event(event)

    def add_summary(self, summary, global_step=None, walltime=None):
        """Add a `Summary` protocol buffer to the event file.

        This method wraps the provided summary in an `Event` protocol buffer
        and adds it to the event file.

        Args:
          summary: A `Summary` protocol buffer.
          global_step: Number. Optional global step value for training process
            to record with the summary.
          walltime: float. Optional walltime to override the default (current)
            walltime (from time.time()) seconds after epoch
        """
        event = event_pb2.Event(summary=summary)
        self.add_event(event, global_step, walltime)

    def add_graph(self, graph_profile, walltime=None):
        """Add a `Graph` and step stats protocol buffer to the event file.

        Two events are written: one carrying the serialized graph as
        ``graph_def``, and one carrying the serialized step stats as the
        ``run_metadata`` of a ``TaggedRunMetadata`` tagged ``"step1"``.

        Args:
          graph_profile: An indexable pair whose element 0 is a `Graph`
            protocol buffer and whose element 1 is the step stats protocol
            buffer; both must provide ``SerializeToString()``.
          walltime: float. Optional walltime to override the default (current)
            walltime (from time.time()) seconds after epoch
        """
        # Local names chosen so they do not shadow the module-level ``graph``
        # function imported from ._pytorch_graph.
        graph_def = graph_profile[0]
        step_stats = graph_profile[1]
        event = event_pb2.Event(graph_def=graph_def.SerializeToString())
        self.add_event(event, None, walltime)

        trm = event_pb2.TaggedRunMetadata(
            tag="step1", run_metadata=step_stats.SerializeToString()
        )
        event = event_pb2.Event(tagged_run_metadata=trm)
        self.add_event(event, None, walltime)

    def add_onnx_graph(self, graph, walltime=None):
        """Add a `Graph` protocol buffer to the event file.

        Args:
          graph: A `Graph` protocol buffer.
          walltime: float. Optional walltime to override the default (current)
            walltime (from time.time()) seconds after epoch
        """
        event = event_pb2.Event(graph_def=graph.SerializeToString())
        self.add_event(event, None, walltime)

    def flush(self):
        """Flush the event file to disk.

        Call this method to make sure that all pending events have been written to
        disk.
        """
        self.event_writer.flush()

    def close(self):
        """Flush the event file to disk and close the file.

        Call this method when you do not need the summary writer anymore.
        """
        self.event_writer.close()

    def reopen(self):
        """Reopen the EventFileWriter.

        Can be called after `close()` to add more events in the same directory.
        The events will go into a new events file.
        Does nothing if the EventFileWriter was not closed.
        """
        self.event_writer.reopen()
170
171
172class SummaryWriter:
173    """Writes entries directly to event files in the log_dir to be consumed by TensorBoard.
174
175    The `SummaryWriter` class provides a high-level API to create an event file
176    in a given directory and add summaries and events to it. The class updates the
177    file contents asynchronously. This allows a training program to call methods
178    to add data to the file directly from the training loop, without slowing down
179    training.
180    """
181
182    def __init__(
183        self,
184        log_dir=None,
185        comment="",
186        purge_step=None,
187        max_queue=10,
188        flush_secs=120,
189        filename_suffix="",
190    ):
191        """Create a `SummaryWriter` that will write out events and summaries to the event file.
192
193        Args:
194            log_dir (str): Save directory location. Default is
195              runs/**CURRENT_DATETIME_HOSTNAME**, which changes after each run.
196              Use hierarchical folder structure to compare
197              between runs easily. e.g. pass in 'runs/exp1', 'runs/exp2', etc.
198              for each new experiment to compare across them.
199            comment (str): Comment log_dir suffix appended to the default
200              ``log_dir``. If ``log_dir`` is assigned, this argument has no effect.
201            purge_step (int):
202              When logging crashes at step :math:`T+X` and restarts at step :math:`T`,
203              any events whose global_step larger or equal to :math:`T` will be
204              purged and hidden from TensorBoard.
205              Note that crashed and resumed experiments should have the same ``log_dir``.
206            max_queue (int): Size of the queue for pending events and
207              summaries before one of the 'add' calls forces a flush to disk.
208              Default is ten items.
209            flush_secs (int): How often, in seconds, to flush the
210              pending events and summaries to disk. Default is every two minutes.
211            filename_suffix (str): Suffix added to all event filenames in
212              the log_dir directory. More details on filename construction in
213              tensorboard.summary.writer.event_file_writer.EventFileWriter.
214
215        Examples::
216
217            from torch.utils.tensorboard import SummaryWriter
218
219            # create a summary writer with automatically generated folder name.
220            writer = SummaryWriter()
221            # folder location: runs/May04_22-14-54_s-MacBook-Pro.local/
222
223            # create a summary writer using the specified folder name.
224            writer = SummaryWriter("my_experiment")
225            # folder location: my_experiment
226
227            # create a summary writer with comment appended.
228            writer = SummaryWriter(comment="LR_0.1_BATCH_16")
229            # folder location: runs/May04_22-14-54_s-MacBook-Pro.localLR_0.1_BATCH_16/
230
231        """
232        torch._C._log_api_usage_once("tensorboard.create.summarywriter")
233        if not log_dir:
234            import socket
235            from datetime import datetime
236
237            current_time = datetime.now().strftime("%b%d_%H-%M-%S")
238            log_dir = os.path.join(
239                "runs", current_time + "_" + socket.gethostname() + comment
240            )
241        self.log_dir = log_dir
242        self.purge_step = purge_step
243        self.max_queue = max_queue
244        self.flush_secs = flush_secs
245        self.filename_suffix = filename_suffix
246
247        # Initialize the file writers, but they can be cleared out on close
248        # and recreated later as needed.
249        self.file_writer = self.all_writers = None
250        self._get_file_writer()
251
252        # Create default bins for histograms, see generate_testdata.py in tensorflow/tensorboard
253        v = 1e-12
254        buckets = []
255        neg_buckets = []
256        while v < 1e20:
257            buckets.append(v)
258            neg_buckets.append(-v)
259            v *= 1.1
260        self.default_bins = neg_buckets[::-1] + [0] + buckets
261
262    def _get_file_writer(self):
263        """Return the default FileWriter instance. Recreates it if closed."""
264        if self.all_writers is None or self.file_writer is None:
265            self.file_writer = FileWriter(
266                self.log_dir, self.max_queue, self.flush_secs, self.filename_suffix
267            )
268            self.all_writers = {self.file_writer.get_logdir(): self.file_writer}
269            if self.purge_step is not None:
270                most_recent_step = self.purge_step
271                self.file_writer.add_event(
272                    Event(step=most_recent_step, file_version="brain.Event:2")
273                )
274                self.file_writer.add_event(
275                    Event(
276                        step=most_recent_step,
277                        session_log=SessionLog(status=SessionLog.START),
278                    )
279                )
280                self.purge_step = None
281        return self.file_writer
282
283    def get_logdir(self):
284        """Return the directory where event files will be written."""
285        return self.log_dir
286
287    def add_hparams(
288        self,
289        hparam_dict,
290        metric_dict,
291        hparam_domain_discrete=None,
292        run_name=None,
293        global_step=None,
294    ):
295        """Add a set of hyperparameters to be compared in TensorBoard.
296
297        Args:
298            hparam_dict (dict): Each key-value pair in the dictionary is the
299              name of the hyper parameter and it's corresponding value.
300              The type of the value can be one of `bool`, `string`, `float`,
301              `int`, or `None`.
302            metric_dict (dict): Each key-value pair in the dictionary is the
303              name of the metric and it's corresponding value. Note that the key used
304              here should be unique in the tensorboard record. Otherwise the value
305              you added by ``add_scalar`` will be displayed in hparam plugin. In most
306              cases, this is unwanted.
307            hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
308              contains names of the hyperparameters and all discrete values they can hold
309            run_name (str): Name of the run, to be included as part of the logdir.
310              If unspecified, will use current timestamp.
311            global_step (int): Global step value to record
312
313        Examples::
314
315            from torch.utils.tensorboard import SummaryWriter
316            with SummaryWriter() as w:
317                for i in range(5):
318                    w.add_hparams({'lr': 0.1*i, 'bsize': i},
319                                  {'hparam/accuracy': 10*i, 'hparam/loss': 10*i})
320
321        Expected result:
322
323        .. image:: _static/img/tensorboard/add_hparam.png
324           :scale: 50 %
325
326        """
327        torch._C._log_api_usage_once("tensorboard.logging.add_hparams")
328        if type(hparam_dict) is not dict or type(metric_dict) is not dict:
329            raise TypeError("hparam_dict and metric_dict should be dictionary.")
330        exp, ssi, sei = hparams(hparam_dict, metric_dict, hparam_domain_discrete)
331
332        if not run_name:
333            run_name = str(time.time())
334        logdir = os.path.join(self._get_file_writer().get_logdir(), run_name)
335        with SummaryWriter(log_dir=logdir) as w_hp:
336            w_hp.file_writer.add_summary(exp, global_step)
337            w_hp.file_writer.add_summary(ssi, global_step)
338            w_hp.file_writer.add_summary(sei, global_step)
339            for k, v in metric_dict.items():
340                w_hp.add_scalar(k, v, global_step)
341
342    def add_scalar(
343        self,
344        tag,
345        scalar_value,
346        global_step=None,
347        walltime=None,
348        new_style=False,
349        double_precision=False,
350    ):
351        """Add scalar data to summary.
352
353        Args:
354            tag (str): Data identifier
355            scalar_value (float or string/blobname): Value to save
356            global_step (int): Global step value to record
357            walltime (float): Optional override default walltime (time.time())
358              with seconds after epoch of event
359            new_style (boolean): Whether to use new style (tensor field) or old
360              style (simple_value field). New style could lead to faster data loading.
361        Examples::
362
363            from torch.utils.tensorboard import SummaryWriter
364            writer = SummaryWriter()
365            x = range(100)
366            for i in x:
367                writer.add_scalar('y=2x', i * 2, i)
368            writer.close()
369
370        Expected result:
371
372        .. image:: _static/img/tensorboard/add_scalar.png
373           :scale: 50 %
374
375        """
376        torch._C._log_api_usage_once("tensorboard.logging.add_scalar")
377
378        summary = scalar(
379            tag, scalar_value, new_style=new_style, double_precision=double_precision
380        )
381        self._get_file_writer().add_summary(summary, global_step, walltime)
382
383    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
384        """Add many scalar data to summary.
385
386        Args:
387            main_tag (str): The parent name for the tags
388            tag_scalar_dict (dict): Key-value pair storing the tag and corresponding values
389            global_step (int): Global step value to record
390            walltime (float): Optional override default walltime (time.time())
391              seconds after epoch of event
392
393        Examples::
394
395            from torch.utils.tensorboard import SummaryWriter
396            writer = SummaryWriter()
397            r = 5
398            for i in range(100):
399                writer.add_scalars('run_14h', {'xsinx':i*np.sin(i/r),
400                                                'xcosx':i*np.cos(i/r),
401                                                'tanx': np.tan(i/r)}, i)
402            writer.close()
403            # This call adds three values to the same scalar plot with the tag
404            # 'run_14h' in TensorBoard's scalar section.
405
406        Expected result:
407
408        .. image:: _static/img/tensorboard/add_scalars.png
409           :scale: 50 %
410
411        """
412        torch._C._log_api_usage_once("tensorboard.logging.add_scalars")
413        walltime = time.time() if walltime is None else walltime
414        fw_logdir = self._get_file_writer().get_logdir()
415        for tag, scalar_value in tag_scalar_dict.items():
416            fw_tag = fw_logdir + "/" + main_tag.replace("/", "_") + "_" + tag
417            assert self.all_writers is not None
418            if fw_tag in self.all_writers.keys():
419                fw = self.all_writers[fw_tag]
420            else:
421                fw = FileWriter(
422                    fw_tag, self.max_queue, self.flush_secs, self.filename_suffix
423                )
424                self.all_writers[fw_tag] = fw
425            fw.add_summary(scalar(main_tag, scalar_value), global_step, walltime)
426
427    def add_tensor(
428        self,
429        tag,
430        tensor,
431        global_step=None,
432        walltime=None,
433    ):
434        """Add tensor data to summary.
435
436        Args:
437            tag (str): Data identifier
438            tensor (torch.Tensor): tensor to save
439            global_step (int): Global step value to record
440        Examples::
441
442            from torch.utils.tensorboard import SummaryWriter
443            writer = SummaryWriter()
444            x = torch.tensor([1,2,3])
445            writer.add_scalar('x', x)
446            writer.close()
447
448        Expected result:
449            Summary::tensor::float_val [1,2,3]
450                   ::tensor::shape [3]
451                   ::tag 'x'
452
453        """
454        torch._C._log_api_usage_once("tensorboard.logging.add_tensor")
455
456        summary = tensor_proto(tag, tensor)
457        self._get_file_writer().add_summary(summary, global_step, walltime)
458
459    def add_histogram(
460        self,
461        tag,
462        values,
463        global_step=None,
464        bins="tensorflow",
465        walltime=None,
466        max_bins=None,
467    ):
468        """Add histogram to summary.
469
470        Args:
471            tag (str): Data identifier
472            values (torch.Tensor, numpy.ndarray, or string/blobname): Values to build histogram
473            global_step (int): Global step value to record
474            bins (str): One of {'tensorflow','auto', 'fd', ...}. This determines how the bins are made. You can find
475              other options in: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html
476            walltime (float): Optional override default walltime (time.time())
477              seconds after epoch of event
478
479        Examples::
480
481            from torch.utils.tensorboard import SummaryWriter
482            import numpy as np
483            writer = SummaryWriter()
484            for i in range(10):
485                x = np.random.random(1000)
486                writer.add_histogram('distribution centers', x + i, i)
487            writer.close()
488
489        Expected result:
490
491        .. image:: _static/img/tensorboard/add_histogram.png
492           :scale: 50 %
493
494        """
495        torch._C._log_api_usage_once("tensorboard.logging.add_histogram")
496        if isinstance(bins, str) and bins == "tensorflow":
497            bins = self.default_bins
498        self._get_file_writer().add_summary(
499            histogram(tag, values, bins, max_bins=max_bins), global_step, walltime
500        )
501
502    def add_histogram_raw(
503        self,
504        tag,
505        min,
506        max,
507        num,
508        sum,
509        sum_squares,
510        bucket_limits,
511        bucket_counts,
512        global_step=None,
513        walltime=None,
514    ):
515        """Add histogram with raw data.
516
517        Args:
518            tag (str): Data identifier
519            min (float or int): Min value
520            max (float or int): Max value
521            num (int): Number of values
522            sum (float or int): Sum of all values
523            sum_squares (float or int): Sum of squares for all values
524            bucket_limits (torch.Tensor, numpy.ndarray): Upper value per bucket.
525              The number of elements of it should be the same as `bucket_counts`.
526            bucket_counts (torch.Tensor, numpy.ndarray): Number of values per bucket
527            global_step (int): Global step value to record
528            walltime (float): Optional override default walltime (time.time())
529              seconds after epoch of event
530            see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/histogram/README.md
531
532        Examples::
533
534            from torch.utils.tensorboard import SummaryWriter
535            import numpy as np
536            writer = SummaryWriter()
537            dummy_data = []
538            for idx, value in enumerate(range(50)):
539                dummy_data += [idx + 0.001] * value
540
541            bins = list(range(50+2))
542            bins = np.array(bins)
543            values = np.array(dummy_data).astype(float).reshape(-1)
544            counts, limits = np.histogram(values, bins=bins)
545            sum_sq = values.dot(values)
546            writer.add_histogram_raw(
547                tag='histogram_with_raw_data',
548                min=values.min(),
549                max=values.max(),
550                num=len(values),
551                sum=values.sum(),
552                sum_squares=sum_sq,
553                bucket_limits=limits[1:].tolist(),
554                bucket_counts=counts.tolist(),
555                global_step=0)
556            writer.close()
557
558        Expected result:
559
560        .. image:: _static/img/tensorboard/add_histogram_raw.png
561           :scale: 50 %
562
563        """
564        torch._C._log_api_usage_once("tensorboard.logging.add_histogram_raw")
565        if len(bucket_limits) != len(bucket_counts):
566            raise ValueError(
567                "len(bucket_limits) != len(bucket_counts), see the document."
568            )
569        self._get_file_writer().add_summary(
570            histogram_raw(
571                tag, min, max, num, sum, sum_squares, bucket_limits, bucket_counts
572            ),
573            global_step,
574            walltime,
575        )
576
577    def add_image(
578        self, tag, img_tensor, global_step=None, walltime=None, dataformats="CHW"
579    ):
580        """Add image data to summary.
581
582        Note that this requires the ``pillow`` package.
583
584        Args:
585            tag (str): Data identifier
586            img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data
587            global_step (int): Global step value to record
588            walltime (float): Optional override default walltime (time.time())
589              seconds after epoch of event
590            dataformats (str): Image data format specification of the form
591              CHW, HWC, HW, WH, etc.
592        Shape:
593            img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to
594            convert a batch of tensor into 3xHxW format or call ``add_images`` and let us do the job.
595            Tensor with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` is also suitable as long as
596            corresponding ``dataformats`` argument is passed, e.g. ``CHW``, ``HWC``, ``HW``.
597
598        Examples::
599
600            from torch.utils.tensorboard import SummaryWriter
601            import numpy as np
602            img = np.zeros((3, 100, 100))
603            img[0] = np.arange(0, 10000).reshape(100, 100) / 10000
604            img[1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000
605
606            img_HWC = np.zeros((100, 100, 3))
607            img_HWC[:, :, 0] = np.arange(0, 10000).reshape(100, 100) / 10000
608            img_HWC[:, :, 1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000
609
610            writer = SummaryWriter()
611            writer.add_image('my_image', img, 0)
612
613            # If you have non-default dimension setting, set the dataformats argument.
614            writer.add_image('my_image_HWC', img_HWC, 0, dataformats='HWC')
615            writer.close()
616
617        Expected result:
618
619        .. image:: _static/img/tensorboard/add_image.png
620           :scale: 50 %
621
622        """
623        torch._C._log_api_usage_once("tensorboard.logging.add_image")
624        self._get_file_writer().add_summary(
625            image(tag, img_tensor, dataformats=dataformats), global_step, walltime
626        )
627
628    def add_images(
629        self, tag, img_tensor, global_step=None, walltime=None, dataformats="NCHW"
630    ):
631        """Add batched image data to summary.
632
633        Note that this requires the ``pillow`` package.
634
635        Args:
636            tag (str): Data identifier
637            img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data
638            global_step (int): Global step value to record
639            walltime (float): Optional override default walltime (time.time())
640              seconds after epoch of event
641            dataformats (str): Image data format specification of the form
642              NCHW, NHWC, CHW, HWC, HW, WH, etc.
643        Shape:
644            img_tensor: Default is :math:`(N, 3, H, W)`. If ``dataformats`` is specified, other shape will be
645            accepted. e.g. NCHW or NHWC.
646
647        Examples::
648
649            from torch.utils.tensorboard import SummaryWriter
650            import numpy as np
651
652            img_batch = np.zeros((16, 3, 100, 100))
653            for i in range(16):
654                img_batch[i, 0] = np.arange(0, 10000).reshape(100, 100) / 10000 / 16 * i
655                img_batch[i, 1] = (1 - np.arange(0, 10000).reshape(100, 100) / 10000) / 16 * i
656
657            writer = SummaryWriter()
658            writer.add_images('my_image_batch', img_batch, 0)
659            writer.close()
660
661        Expected result:
662
663        .. image:: _static/img/tensorboard/add_images.png
664           :scale: 30 %
665
666        """
667        torch._C._log_api_usage_once("tensorboard.logging.add_images")
668        self._get_file_writer().add_summary(
669            image(tag, img_tensor, dataformats=dataformats), global_step, walltime
670        )
671
672    def add_image_with_boxes(
673        self,
674        tag,
675        img_tensor,
676        box_tensor,
677        global_step=None,
678        walltime=None,
679        rescale=1,
680        dataformats="CHW",
681        labels=None,
682    ):
683        """Add image and draw bounding boxes on the image.
684
685        Args:
686            tag (str): Data identifier
687            img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data
688            box_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Box data (for detected objects)
689              box should be represented as [x1, y1, x2, y2].
690            global_step (int): Global step value to record
691            walltime (float): Optional override default walltime (time.time())
692              seconds after epoch of event
693            rescale (float): Optional scale override
694            dataformats (str): Image data format specification of the form
695              NCHW, NHWC, CHW, HWC, HW, WH, etc.
696            labels (list of string): The label to be shown for each bounding box.
697        Shape:
698            img_tensor: Default is :math:`(3, H, W)`. It can be specified with ``dataformats`` argument.
699            e.g. CHW or HWC
700
701            box_tensor: (torch.Tensor, numpy.ndarray, or string/blobname): NX4,  where N is the number of
702            boxes and each 4 elements in a row represents (xmin, ymin, xmax, ymax).
703        """
704        torch._C._log_api_usage_once("tensorboard.logging.add_image_with_boxes")
705        if labels is not None:
706            if isinstance(labels, str):
707                labels = [labels]
708            if len(labels) != box_tensor.shape[0]:
709                labels = None
710        self._get_file_writer().add_summary(
711            image_boxes(
712                tag,
713                img_tensor,
714                box_tensor,
715                rescale=rescale,
716                dataformats=dataformats,
717                labels=labels,
718            ),
719            global_step,
720            walltime,
721        )
722
723    def add_figure(
724        self,
725        tag: str,
726        figure: Union["Figure", List["Figure"]],
727        global_step: Optional[int] = None,
728        close: bool = True,
729        walltime: Optional[float] = None,
730    ) -> None:
731        """Render matplotlib figure into an image and add it to summary.
732
733        Note that this requires the ``matplotlib`` package.
734
735        Args:
736            tag: Data identifier
737            figure: Figure or a list of figures
738            global_step: Global step value to record
739            close: Flag to automatically close the figure
740            walltime: Optional override default walltime (time.time())
741              seconds after epoch of event
742        """
743        torch._C._log_api_usage_once("tensorboard.logging.add_figure")
744        if isinstance(figure, list):
745            self.add_image(
746                tag,
747                figure_to_image(figure, close),
748                global_step,
749                walltime,
750                dataformats="NCHW",
751            )
752        else:
753            self.add_image(
754                tag,
755                figure_to_image(figure, close),
756                global_step,
757                walltime,
758                dataformats="CHW",
759            )
760
761    def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None):
762        """Add video data to summary.
763
764        Note that this requires the ``moviepy`` package.
765
766        Args:
767            tag (str): Data identifier
768            vid_tensor (torch.Tensor): Video data
769            global_step (int): Global step value to record
770            fps (float or int): Frames per second
771            walltime (float): Optional override default walltime (time.time())
772              seconds after epoch of event
773        Shape:
774            vid_tensor: :math:`(N, T, C, H, W)`. The values should lie in [0, 255] for type `uint8` or [0, 1] for type `float`.
775        """
776        torch._C._log_api_usage_once("tensorboard.logging.add_video")
777        self._get_file_writer().add_summary(
778            video(tag, vid_tensor, fps), global_step, walltime
779        )
780
781    def add_audio(
782        self, tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None
783    ):
784        """Add audio data to summary.
785
786        Args:
787            tag (str): Data identifier
788            snd_tensor (torch.Tensor): Sound data
789            global_step (int): Global step value to record
790            sample_rate (int): sample rate in Hz
791            walltime (float): Optional override default walltime (time.time())
792              seconds after epoch of event
793        Shape:
794            snd_tensor: :math:`(1, L)`. The values should lie between [-1, 1].
795        """
796        torch._C._log_api_usage_once("tensorboard.logging.add_audio")
797        self._get_file_writer().add_summary(
798            audio(tag, snd_tensor, sample_rate=sample_rate), global_step, walltime
799        )
800
801    def add_text(self, tag, text_string, global_step=None, walltime=None):
802        """Add text data to summary.
803
804        Args:
805            tag (str): Data identifier
806            text_string (str): String to save
807            global_step (int): Global step value to record
808            walltime (float): Optional override default walltime (time.time())
809              seconds after epoch of event
810        Examples::
811
812            writer.add_text('lstm', 'This is an lstm', 0)
813            writer.add_text('rnn', 'This is an rnn', 10)
814        """
815        torch._C._log_api_usage_once("tensorboard.logging.add_text")
816        self._get_file_writer().add_summary(
817            text(tag, text_string), global_step, walltime
818        )
819
820    def add_onnx_graph(self, prototxt):
821        torch._C._log_api_usage_once("tensorboard.logging.add_onnx_graph")
822        self._get_file_writer().add_onnx_graph(load_onnx_graph(prototxt))
823
824    def add_graph(
825        self, model, input_to_model=None, verbose=False, use_strict_trace=True
826    ):
827        """Add graph data to summary.
828
829        Args:
830            model (torch.nn.Module): Model to draw.
831            input_to_model (torch.Tensor or list of torch.Tensor): A variable or a tuple of
832                variables to be fed.
833            verbose (bool): Whether to print graph structure in console.
834            use_strict_trace (bool): Whether to pass keyword argument `strict` to
835                `torch.jit.trace`. Pass False when you want the tracer to
836                record your mutable container types (list, dict)
837        """
838        torch._C._log_api_usage_once("tensorboard.logging.add_graph")
839        # A valid PyTorch model should have a 'forward' method
840        self._get_file_writer().add_graph(
841            graph(model, input_to_model, verbose, use_strict_trace)
842        )
843
844    @staticmethod
845    def _encode(rawstr):
846        # I'd use urllib but, I'm unsure about the differences from python3 to python2, etc.
847        retval = rawstr
848        retval = retval.replace("%", f"%{ord('%'):02x}")
849        retval = retval.replace("/", f"%{ord('/'):02x}")
850        retval = retval.replace("\\", "%%%02x" % (ord("\\")))  # noqa: UP031
851        return retval
852
853    def add_embedding(
854        self,
855        mat,
856        metadata=None,
857        label_img=None,
858        global_step=None,
859        tag="default",
860        metadata_header=None,
861    ):
862        """Add embedding projector data to summary.
863
864        Args:
865            mat (torch.Tensor or numpy.ndarray): A matrix which each row is the feature vector of the data point
866            metadata (list): A list of labels, each element will be converted to string
867            label_img (torch.Tensor): Images correspond to each data point
868            global_step (int): Global step value to record
869            tag (str): Name for the embedding
870            metadata_header (list): A list of headers for multi-column metadata. If given, each metadata must be
871                a list with values corresponding to headers.
872        Shape:
873            mat: :math:`(N, D)`, where N is number of data and D is feature dimension
874
875            label_img: :math:`(N, C, H, W)`
876
877        Examples::
878
879            import keyword
880            import torch
881            meta = []
882            while len(meta)<100:
883                meta = meta+keyword.kwlist # get some strings
884            meta = meta[:100]
885
886            for i, v in enumerate(meta):
887                meta[i] = v+str(i)
888
889            label_img = torch.rand(100, 3, 10, 32)
890            for i in range(100):
891                label_img[i]*=i/100.0
892
893            writer.add_embedding(torch.randn(100, 5), metadata=meta, label_img=label_img)
894            writer.add_embedding(torch.randn(100, 5), label_img=label_img)
895            writer.add_embedding(torch.randn(100, 5), metadata=meta)
896
897        .. note::
898            Categorical (i.e. non-numeric) metadata cannot have more than 50 unique values if they are to be used for
899            coloring in the embedding projector.
900
901        """
902        torch._C._log_api_usage_once("tensorboard.logging.add_embedding")
903        mat = make_np(mat)
904        if global_step is None:
905            global_step = 0
906            # clear pbtxt?
907
908        # Maybe we should encode the tag so slashes don't trip us up?
909        # I don't think this will mess us up, but better safe than sorry.
910        subdir = f"{str(global_step).zfill(5)}/{self._encode(tag)}"
911        save_path = os.path.join(self._get_file_writer().get_logdir(), subdir)
912
913        fs = tf.io.gfile
914        if fs.exists(save_path):
915            if fs.isdir(save_path):
916                print(
917                    "warning: Embedding dir exists, did you set global_step for add_embedding()?"
918                )
919            else:
920                raise NotADirectoryError(
921                    f"Path: `{save_path}` exists, but is a file. Cannot proceed."
922                )
923        else:
924            fs.makedirs(save_path)
925
926        if metadata is not None:
927            assert mat.shape[0] == len(
928                metadata
929            ), "#labels should equal with #data points"
930            make_tsv(metadata, save_path, metadata_header=metadata_header)
931
932        if label_img is not None:
933            assert (
934                mat.shape[0] == label_img.shape[0]
935            ), "#images should equal with #data points"
936            make_sprite(label_img, save_path)
937
938        assert (
939            mat.ndim == 2
940        ), "mat should be 2D, where mat.size(0) is the number of data points"
941        make_mat(mat, save_path)
942
943        # Filesystem doesn't necessarily have append semantics, so we store an
944        # internal buffer to append to and re-write whole file after each
945        # embedding is added
946        if not hasattr(self, "_projector_config"):
947            self._projector_config = ProjectorConfig()
948        embedding_info = get_embedding_info(
949            metadata, label_img, subdir, global_step, tag
950        )
951        self._projector_config.embeddings.extend([embedding_info])
952
953        from google.protobuf import text_format
954
955        config_pbtxt = text_format.MessageToString(self._projector_config)
956        write_pbtxt(self._get_file_writer().get_logdir(), config_pbtxt)
957
958    def add_pr_curve(
959        self,
960        tag,
961        labels,
962        predictions,
963        global_step=None,
964        num_thresholds=127,
965        weights=None,
966        walltime=None,
967    ):
968        """Add precision recall curve.
969
970        Plotting a precision-recall curve lets you understand your model's
971        performance under different threshold settings. With this function,
972        you provide the ground truth labeling (T/F) and prediction confidence
973        (usually the output of your model) for each target. The TensorBoard UI
974        will let you choose the threshold interactively.
975
976        Args:
977            tag (str): Data identifier
978            labels (torch.Tensor, numpy.ndarray, or string/blobname):
979              Ground truth data. Binary label for each element.
980            predictions (torch.Tensor, numpy.ndarray, or string/blobname):
981              The probability that an element be classified as true.
982              Value should be in [0, 1]
983            global_step (int): Global step value to record
984            num_thresholds (int): Number of thresholds used to draw the curve.
985            walltime (float): Optional override default walltime (time.time())
986              seconds after epoch of event
987
988        Examples::
989
990            from torch.utils.tensorboard import SummaryWriter
991            import numpy as np
992            labels = np.random.randint(2, size=100)  # binary label
993            predictions = np.random.rand(100)
994            writer = SummaryWriter()
995            writer.add_pr_curve('pr_curve', labels, predictions, 0)
996            writer.close()
997
998        """
999        torch._C._log_api_usage_once("tensorboard.logging.add_pr_curve")
1000        labels, predictions = make_np(labels), make_np(predictions)
1001        self._get_file_writer().add_summary(
1002            pr_curve(tag, labels, predictions, num_thresholds, weights),
1003            global_step,
1004            walltime,
1005        )
1006
1007    def add_pr_curve_raw(
1008        self,
1009        tag,
1010        true_positive_counts,
1011        false_positive_counts,
1012        true_negative_counts,
1013        false_negative_counts,
1014        precision,
1015        recall,
1016        global_step=None,
1017        num_thresholds=127,
1018        weights=None,
1019        walltime=None,
1020    ):
1021        """Add precision recall curve with raw data.
1022
1023        Args:
1024            tag (str): Data identifier
1025            true_positive_counts (torch.Tensor, numpy.ndarray, or string/blobname): true positive counts
1026            false_positive_counts (torch.Tensor, numpy.ndarray, or string/blobname): false positive counts
1027            true_negative_counts (torch.Tensor, numpy.ndarray, or string/blobname): true negative counts
1028            false_negative_counts (torch.Tensor, numpy.ndarray, or string/blobname): false negative counts
1029            precision (torch.Tensor, numpy.ndarray, or string/blobname): precision
1030            recall (torch.Tensor, numpy.ndarray, or string/blobname): recall
1031            global_step (int): Global step value to record
1032            num_thresholds (int): Number of thresholds used to draw the curve.
1033            walltime (float): Optional override default walltime (time.time())
1034              seconds after epoch of event
1035            see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/README.md
1036        """
1037        torch._C._log_api_usage_once("tensorboard.logging.add_pr_curve_raw")
1038        self._get_file_writer().add_summary(
1039            pr_curve_raw(
1040                tag,
1041                true_positive_counts,
1042                false_positive_counts,
1043                true_negative_counts,
1044                false_negative_counts,
1045                precision,
1046                recall,
1047                num_thresholds,
1048                weights,
1049            ),
1050            global_step,
1051            walltime,
1052        )
1053
1054    def add_custom_scalars_multilinechart(
1055        self, tags, category="default", title="untitled"
1056    ):
1057        """Shorthand for creating multilinechart. Similar to ``add_custom_scalars()``, but the only necessary argument is *tags*.
1058
1059        Args:
1060            tags (list): list of tags that have been used in ``add_scalar()``
1061
1062        Examples::
1063
1064            writer.add_custom_scalars_multilinechart(['twse/0050', 'twse/2330'])
1065        """
1066        torch._C._log_api_usage_once(
1067            "tensorboard.logging.add_custom_scalars_multilinechart"
1068        )
1069        layout = {category: {title: ["Multiline", tags]}}
1070        self._get_file_writer().add_summary(custom_scalars(layout))
1071
1072    def add_custom_scalars_marginchart(
1073        self, tags, category="default", title="untitled"
1074    ):
1075        """Shorthand for creating marginchart.
1076
1077        Similar to ``add_custom_scalars()``, but the only necessary argument is *tags*,
1078        which should have exactly 3 elements.
1079
1080        Args:
1081            tags (list): list of tags that have been used in ``add_scalar()``
1082
1083        Examples::
1084
1085            writer.add_custom_scalars_marginchart(['twse/0050', 'twse/2330', 'twse/2006'])
1086        """
1087        torch._C._log_api_usage_once(
1088            "tensorboard.logging.add_custom_scalars_marginchart"
1089        )
1090        assert len(tags) == 3
1091        layout = {category: {title: ["Margin", tags]}}
1092        self._get_file_writer().add_summary(custom_scalars(layout))
1093
1094    def add_custom_scalars(self, layout):
1095        """Create special chart by collecting charts tags in 'scalars'.
1096
1097        NOTE: This function can only be called once for each SummaryWriter() object.
1098
1099        Because it only provides metadata to tensorboard, the function can be called before or after the training loop.
1100
1101        Args:
1102            layout (dict): {categoryName: *charts*}, where *charts* is also a dictionary
1103              {chartName: *ListOfProperties*}. The first element in *ListOfProperties* is the chart's type
1104              (one of **Multiline** or **Margin**) and the second element should be a list containing the tags
1105              you have used in add_scalar function, which will be collected into the new chart.
1106
1107        Examples::
1108
1109            layout = {'Taiwan':{'twse':['Multiline',['twse/0050', 'twse/2330']]},
1110                         'USA':{ 'dow':['Margin',   ['dow/aaa', 'dow/bbb', 'dow/ccc']],
1111                              'nasdaq':['Margin',   ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']]}}
1112
1113            writer.add_custom_scalars(layout)
1114        """
1115        torch._C._log_api_usage_once("tensorboard.logging.add_custom_scalars")
1116        self._get_file_writer().add_summary(custom_scalars(layout))
1117
1118    def add_mesh(
1119        self,
1120        tag,
1121        vertices,
1122        colors=None,
1123        faces=None,
1124        config_dict=None,
1125        global_step=None,
1126        walltime=None,
1127    ):
1128        """Add meshes or 3D point clouds to TensorBoard.
1129
1130        The visualization is based on Three.js,
1131        so it allows users to interact with the rendered object. Besides the basic definitions
1132        such as vertices, faces, users can further provide camera parameter, lighting condition, etc.
1133        Please check https://threejs.org/docs/index.html#manual/en/introduction/Creating-a-scene for
1134        advanced usage.
1135
1136        Args:
1137            tag (str): Data identifier
1138            vertices (torch.Tensor): List of the 3D coordinates of vertices.
1139            colors (torch.Tensor): Colors for each vertex
1140            faces (torch.Tensor): Indices of vertices within each triangle. (Optional)
1141            config_dict: Dictionary with ThreeJS classes names and configuration.
1142            global_step (int): Global step value to record
1143            walltime (float): Optional override default walltime (time.time())
1144              seconds after epoch of event
1145
1146        Shape:
1147            vertices: :math:`(B, N, 3)`. (batch, number_of_vertices, channels)
1148
1149            colors: :math:`(B, N, 3)`. The values should lie in [0, 255] for type `uint8` or [0, 1] for type `float`.
1150
1151            faces: :math:`(B, N, 3)`. The values should lie in [0, number_of_vertices] for type `uint8`.
1152
1153        Examples::
1154
1155            from torch.utils.tensorboard import SummaryWriter
1156            vertices_tensor = torch.as_tensor([
1157                [1, 1, 1],
1158                [-1, -1, 1],
1159                [1, -1, -1],
1160                [-1, 1, -1],
1161            ], dtype=torch.float).unsqueeze(0)
1162            colors_tensor = torch.as_tensor([
1163                [255, 0, 0],
1164                [0, 255, 0],
1165                [0, 0, 255],
1166                [255, 0, 255],
1167            ], dtype=torch.int).unsqueeze(0)
1168            faces_tensor = torch.as_tensor([
1169                [0, 2, 3],
1170                [0, 3, 1],
1171                [0, 1, 2],
1172                [1, 3, 2],
1173            ], dtype=torch.int).unsqueeze(0)
1174
1175            writer = SummaryWriter()
1176            writer.add_mesh('my_mesh', vertices=vertices_tensor, colors=colors_tensor, faces=faces_tensor)
1177
1178            writer.close()
1179        """
1180        torch._C._log_api_usage_once("tensorboard.logging.add_mesh")
1181        self._get_file_writer().add_summary(
1182            mesh(tag, vertices, colors, faces, config_dict), global_step, walltime
1183        )
1184
1185    def flush(self):
1186        """Flushes the event file to disk.
1187
1188        Call this method to make sure that all pending events have been written to
1189        disk.
1190        """
1191        if self.all_writers is None:
1192            return
1193        for writer in self.all_writers.values():
1194            writer.flush()
1195
1196    def close(self):
1197        if self.all_writers is None:
1198            return  # ignore double close
1199        for writer in self.all_writers.values():
1200            writer.flush()
1201            writer.close()
1202        self.file_writer = self.all_writers = None
1203
1204    def __enter__(self):
1205        return self
1206
1207    def __exit__(self, exc_type, exc_val, exc_tb):
1208        self.close()
1209