# mypy: allow-untyped-defs
"""Provide an API for writing protocol buffers to event files to be consumed by TensorBoard for visualization."""

import os
import time
from typing import List, Optional, TYPE_CHECKING, Union

import torch

if TYPE_CHECKING:
    from matplotlib.figure import Figure
from tensorboard.compat import tf
from tensorboard.compat.proto import event_pb2
from tensorboard.compat.proto.event_pb2 import Event, SessionLog
from tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig
from tensorboard.summary.writer.event_file_writer import EventFileWriter

from ._convert_np import make_np
from ._embedding import get_embedding_info, make_mat, make_sprite, make_tsv, write_pbtxt
from ._onnx_graph import load_onnx_graph
from ._pytorch_graph import graph
from ._utils import figure_to_image
from .summary import (
    audio,
    custom_scalars,
    histogram,
    histogram_raw,
    hparams,
    image,
    image_boxes,
    mesh,
    pr_curve,
    pr_curve_raw,
    scalar,
    tensor_proto,
    text,
    video,
)

__all__ = ["FileWriter", "SummaryWriter"]


class FileWriter:
    """Writes protocol buffers to event files to be consumed by TensorBoard.

    The `FileWriter` class provides a mechanism to create an event file in a
    given directory and add summaries and events to it. The class updates the
    file contents asynchronously. This allows a training program to call methods
    to add data to the file directly from the training loop, without slowing down
    training.
    """

    def __init__(self, log_dir, max_queue=10, flush_secs=120, filename_suffix=""):
        """Create a `FileWriter` and an event file.

        On construction the writer creates a new event file in `log_dir`.
        The other arguments to the constructor control the asynchronous writes to
        the event file.

        Args:
            log_dir: A string. Directory where the event file will be written.
            max_queue: Integer. Size of the queue for pending events and
              summaries before one of the 'add' calls forces a flush to disk.
              Default is ten items.
            flush_secs: Number. How often, in seconds, to flush the
              pending events and summaries to disk. Default is every two minutes.
            filename_suffix: A string. Suffix added to all event filenames
              in the log_dir directory. More details on filename construction in
              tensorboard.summary.writer.event_file_writer.EventFileWriter.
        """
        # Sometimes PosixPath is passed in and we need to coerce it to
        # a string in all cases
        # TODO: See if we can remove this in the future if we are
        # actually the ones passing in a PosixPath
        log_dir = str(log_dir)
        self.event_writer = EventFileWriter(
            log_dir, max_queue, flush_secs, filename_suffix
        )

    def get_logdir(self):
        """Return the directory where the event file will be written."""
        return self.event_writer.get_logdir()

    def add_event(self, event, step=None, walltime=None):
        """Add an event to the event file.

        Args:
            event: An `Event` protocol buffer.
            step: Number. Optional global step value for training process
              to record with the event.
            walltime: float. Optional walltime to override the default (current)
              walltime (from time.time()), in seconds after epoch.
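
        Examples::

            from tensorboard.compat.proto import event_pb2
            from tensorboard.compat.proto.summary_pb2 import Summary

            # A minimal sketch: hand-build an Event carrying a scalar
            # Summary and log it. `writer` is assumed to be an existing
            # FileWriter instance.
            summary = Summary(value=[Summary.Value(tag="loss", simple_value=0.1)])
            writer.add_event(event_pb2.Event(summary=summary), step=0)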
        """
        event.wall_time = time.time() if walltime is None else walltime
        if step is not None:
            # Make sure step is converted from numpy or other formats
            # since protobuf might not convert depending on version
            event.step = int(step)
        self.event_writer.add_event(event)

    def add_summary(self, summary, global_step=None, walltime=None):
        """Add a `Summary` protocol buffer to the event file.

        This method wraps the provided summary in an `Event` protocol buffer
        and adds it to the event file.

        Args:
            summary: A `Summary` protocol buffer.
            global_step: Number. Optional global step value for training process
              to record with the summary.
            walltime: float. Optional walltime to override the default (current)
              walltime (from time.time()), in seconds after epoch.
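
        Examples::

            from torch.utils.tensorboard.summary import scalar
            # A minimal sketch: log a ready-made `Summary` protocol buffer
            # built with the `scalar` helper. `writer` is assumed to be an
            # existing FileWriter instance.
            writer.add_summary(scalar("loss", 0.1), global_step=0)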
        """
        event = event_pb2.Event(summary=summary)
        self.add_event(event, global_step, walltime)

    def add_graph(self, graph_profile, walltime=None):
        """Add a `Graph` and step stats protocol buffer to the event file.

        Args:
            graph_profile: A `Graph` and step stats protocol buffer.
            walltime: float. Optional walltime to override the default (current)
              walltime (from time.time()), in seconds after epoch.
        """
        graph = graph_profile[0]
        stepstats = graph_profile[1]
        event = event_pb2.Event(graph_def=graph.SerializeToString())
        self.add_event(event, None, walltime)

        trm = event_pb2.TaggedRunMetadata(
            tag="step1", run_metadata=stepstats.SerializeToString()
        )
        event = event_pb2.Event(tagged_run_metadata=trm)
        self.add_event(event, None, walltime)

    def add_onnx_graph(self, graph, walltime=None):
        """Add a `Graph` protocol buffer to the event file.

        Args:
            graph: A `Graph` protocol buffer.
            walltime: float. Optional walltime to override the default (current)
              walltime (from time.time()), in seconds after epoch.
        """
        event = event_pb2.Event(graph_def=graph.SerializeToString())
        self.add_event(event, None, walltime)

    def flush(self):
        """Flushes the event file to disk.

        Call this method to make sure that all pending events have been written to
        disk.
        """
        self.event_writer.flush()

    def close(self):
        """Flushes the event file to disk and closes the file.

        Call this method when you do not need the summary writer anymore.
        """
        self.event_writer.close()

    def reopen(self):
        """Reopens the EventFileWriter.

        Can be called after `close()` to add more events in the same directory.
        The events will go into a new events file.
        Does nothing if the EventFileWriter was not closed.
        """
        self.event_writer.reopen()


class SummaryWriter:
    """Writes entries directly to event files in the log_dir to be consumed by TensorBoard.

    The `SummaryWriter` class provides a high-level API to create an event file
    in a given directory and add summaries and events to it. The class updates the
    file contents asynchronously. This allows a training program to call methods
    to add data to the file directly from the training loop, without slowing down
    training.
    """

    def __init__(
        self,
        log_dir=None,
        comment="",
        purge_step=None,
        max_queue=10,
        flush_secs=120,
        filename_suffix="",
    ):
        """Create a `SummaryWriter` that will write out events and summaries to the event file.

        Args:
            log_dir (str): Save directory location. Default is
              runs/**CURRENT_DATETIME_HOSTNAME**, which changes after each run.
              Use a hierarchical folder structure to compare
              between runs easily, e.g. pass in 'runs/exp1', 'runs/exp2', etc.
              for each new experiment to compare across them.
            comment (str): Comment log_dir suffix appended to the default
              ``log_dir``. If ``log_dir`` is assigned, this argument has no effect.
            purge_step (int):
              When logging crashes at step :math:`T+X` and restarts at step :math:`T`,
              any events whose global_step is larger than or equal to :math:`T` will be
              purged and hidden from TensorBoard.
              Note that crashed and resumed experiments should have the same ``log_dir``.
            max_queue (int): Size of the queue for pending events and
              summaries before one of the 'add' calls forces a flush to disk.
              Default is ten items.
            flush_secs (int): How often, in seconds, to flush the
              pending events and summaries to disk. Default is every two minutes.
            filename_suffix (str): Suffix added to all event filenames in
              the log_dir directory. More details on filename construction in
              tensorboard.summary.writer.event_file_writer.EventFileWriter.

        Examples::

            from torch.utils.tensorboard import SummaryWriter

            # create a summary writer with automatically generated folder name.
            writer = SummaryWriter()
            # folder location: runs/May04_22-14-54_s-MacBook-Pro.local/

            # create a summary writer using the specified folder name.
            writer = SummaryWriter("my_experiment")
            # folder location: my_experiment

            # create a summary writer with comment appended.
            writer = SummaryWriter(comment="LR_0.1_BATCH_16")
            # folder location: runs/May04_22-14-54_s-MacBook-Pro.localLR_0.1_BATCH_16/
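
            # resume a run that crashed after step 99: on restart, events
            # with global_step >= 100 already in this log_dir are purged.
            # 'runs/exp1' stands in for the original run's directory.
            writer = SummaryWriter("runs/exp1", purge_step=100)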

        """
        torch._C._log_api_usage_once("tensorboard.create.summarywriter")
        if not log_dir:
            import socket
            from datetime import datetime

            current_time = datetime.now().strftime("%b%d_%H-%M-%S")
            log_dir = os.path.join(
                "runs", current_time + "_" + socket.gethostname() + comment
            )
        self.log_dir = log_dir
        self.purge_step = purge_step
        self.max_queue = max_queue
        self.flush_secs = flush_secs
        self.filename_suffix = filename_suffix

        # Initialize the file writers, but they can be cleared out on close
        # and recreated later as needed.
        self.file_writer = self.all_writers = None
        self._get_file_writer()

        # Create default bins for histograms, see generate_testdata.py in tensorflow/tensorboard
        v = 1e-12
        buckets = []
        neg_buckets = []
        while v < 1e20:
            buckets.append(v)
            neg_buckets.append(-v)
            v *= 1.1
        self.default_bins = neg_buckets[::-1] + [0] + buckets

    def _get_file_writer(self):
        """Return the default FileWriter instance. Recreates it if closed."""
        if self.all_writers is None or self.file_writer is None:
            self.file_writer = FileWriter(
                self.log_dir, self.max_queue, self.flush_secs, self.filename_suffix
            )
            self.all_writers = {self.file_writer.get_logdir(): self.file_writer}
            if self.purge_step is not None:
                most_recent_step = self.purge_step
                self.file_writer.add_event(
                    Event(step=most_recent_step, file_version="brain.Event:2")
                )
                self.file_writer.add_event(
                    Event(
                        step=most_recent_step,
                        session_log=SessionLog(status=SessionLog.START),
                    )
                )
                self.purge_step = None
        return self.file_writer

    def get_logdir(self):
        """Return the directory where event files will be written."""
        return self.log_dir

    def add_hparams(
        self,
        hparam_dict,
        metric_dict,
        hparam_domain_discrete=None,
        run_name=None,
        global_step=None,
    ):
        """Add a set of hyperparameters to be compared in TensorBoard.

        Args:
            hparam_dict (dict): Each key-value pair in the dictionary is the
              name of the hyper parameter and its corresponding value.
              The type of the value can be one of `bool`, `string`, `float`,
              `int`, or `None`.
            metric_dict (dict): Each key-value pair in the dictionary is the
              name of the metric and its corresponding value. Note that the key used
              here should be unique in the tensorboard record. Otherwise the value
              you added by ``add_scalar`` will be displayed in the hparam plugin. In most
              cases, this is unwanted.
            hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
              contains names of the hyperparameters and all discrete values they can hold
            run_name (str): Name of the run, to be included as part of the logdir.
              If unspecified, will use current timestamp.
            global_step (int): Global step value to record

        Examples::

            from torch.utils.tensorboard import SummaryWriter
            with SummaryWriter() as w:
                for i in range(5):
                    w.add_hparams({'lr': 0.1*i, 'bsize': i},
                                  {'hparam/accuracy': 10*i, 'hparam/loss': 10*i})

        Expected result:

        .. image:: _static/img/tensorboard/add_hparam.png
           :scale: 50 %

        """
        torch._C._log_api_usage_once("tensorboard.logging.add_hparams")
        if type(hparam_dict) is not dict or type(metric_dict) is not dict:
            raise TypeError("hparam_dict and metric_dict should be dictionaries.")
        exp, ssi, sei = hparams(hparam_dict, metric_dict, hparam_domain_discrete)

        if not run_name:
            run_name = str(time.time())
        logdir = os.path.join(self._get_file_writer().get_logdir(), run_name)
        with SummaryWriter(log_dir=logdir) as w_hp:
            w_hp.file_writer.add_summary(exp, global_step)
            w_hp.file_writer.add_summary(ssi, global_step)
            w_hp.file_writer.add_summary(sei, global_step)
            for k, v in metric_dict.items():
                w_hp.add_scalar(k, v, global_step)

    def add_scalar(
        self,
        tag,
        scalar_value,
        global_step=None,
        walltime=None,
        new_style=False,
        double_precision=False,
    ):
        """Add scalar data to summary.

        Args:
            tag (str): Data identifier
            scalar_value (float or string/blobname): Value to save
            global_step (int): Global step value to record
            walltime (float): Optional override default walltime (time.time())
              with seconds after epoch of event
            new_style (boolean): Whether to use new style (tensor field) or old
              style (simple_value field). New style could lead to faster data loading.
            double_precision (boolean): Whether to log the scalar as a
              double-precision tensor. Only used when ``new_style`` is True.

        Examples::

            from torch.utils.tensorboard import SummaryWriter
            writer = SummaryWriter()
            x = range(100)
            for i in x:
                writer.add_scalar('y=2x', i * 2, i)
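                # the tensor-based "new style" encoding of the same values;
                # it may allow faster data loading in TensorBoard
                writer.add_scalar('y=2x_new', i * 2, i, new_style=True)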
            writer.close()

        Expected result:

        .. image:: _static/img/tensorboard/add_scalar.png
           :scale: 50 %

        """
        torch._C._log_api_usage_once("tensorboard.logging.add_scalar")

        summary = scalar(
            tag, scalar_value, new_style=new_style, double_precision=double_precision
        )
        self._get_file_writer().add_summary(summary, global_step, walltime)

    def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None):
        """Add many scalar data to summary.

        Args:
            main_tag (str): The parent name for the tags
            tag_scalar_dict (dict): Key-value pair storing the tag and corresponding values
            global_step (int): Global step value to record
            walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event

        Examples::

            from torch.utils.tensorboard import SummaryWriter
            import numpy as np
            writer = SummaryWriter()
            r = 5
            for i in range(100):
                writer.add_scalars('run_14h', {'xsinx': i*np.sin(i/r),
                                               'xcosx': i*np.cos(i/r),
                                               'tanx': np.tan(i/r)}, i)
            writer.close()
            # This call adds three values to the same scalar plot with the tag
            # 'run_14h' in TensorBoard's scalar section.

        Expected result:

        .. image:: _static/img/tensorboard/add_scalars.png
           :scale: 50 %

        """
        torch._C._log_api_usage_once("tensorboard.logging.add_scalars")
        walltime = time.time() if walltime is None else walltime
        fw_logdir = self._get_file_writer().get_logdir()
        for tag, scalar_value in tag_scalar_dict.items():
            fw_tag = fw_logdir + "/" + main_tag.replace("/", "_") + "_" + tag
            assert self.all_writers is not None
            if fw_tag in self.all_writers.keys():
                fw = self.all_writers[fw_tag]
            else:
                fw = FileWriter(
                    fw_tag, self.max_queue, self.flush_secs, self.filename_suffix
                )
                self.all_writers[fw_tag] = fw
            fw.add_summary(scalar(main_tag, scalar_value), global_step, walltime)

    def add_tensor(
        self,
        tag,
        tensor,
        global_step=None,
        walltime=None,
    ):
        """Add tensor data to summary.

        Args:
            tag (str): Data identifier
            tensor (torch.Tensor): tensor to save
            global_step (int): Global step value to record
            walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event

        Examples::

            from torch.utils.tensorboard import SummaryWriter
            import torch
            writer = SummaryWriter()
            x = torch.tensor([1, 2, 3])
            writer.add_tensor('x', x)
            writer.close()

        Expected result:
            Summary::tensor::float_val [1,2,3]
                   ::tensor::shape [3]
                   ::tag 'x'

        """
        torch._C._log_api_usage_once("tensorboard.logging.add_tensor")

        summary = tensor_proto(tag, tensor)
        self._get_file_writer().add_summary(summary, global_step, walltime)

    def add_histogram(
        self,
        tag,
        values,
        global_step=None,
        bins="tensorflow",
        walltime=None,
        max_bins=None,
    ):
        """Add histogram to summary.

        Args:
            tag (str): Data identifier
            values (torch.Tensor, numpy.ndarray, or string/blobname): Values to build histogram
            global_step (int): Global step value to record
            bins (str): One of {'tensorflow', 'auto', 'fd', ...}. This determines
              how the bins are made. You can find other options in:
              https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram.html
            walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event

        Examples::

            from torch.utils.tensorboard import SummaryWriter
            import numpy as np
            writer = SummaryWriter()
            for i in range(10):
                x = np.random.random(1000)
                writer.add_histogram('distribution centers', x + i, i)
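                # let numpy choose the bin edges instead of the default
                # TensorFlow-style bins; any mode accepted by numpy.histogram
                # (e.g. 'auto', 'fd') works here
                writer.add_histogram('auto_bins', x + i, i, bins='auto')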
            writer.close()

        Expected result:

        .. image:: _static/img/tensorboard/add_histogram.png
           :scale: 50 %

        """
        torch._C._log_api_usage_once("tensorboard.logging.add_histogram")
        if isinstance(bins, str) and bins == "tensorflow":
            bins = self.default_bins
        self._get_file_writer().add_summary(
            histogram(tag, values, bins, max_bins=max_bins), global_step, walltime
        )

    def add_histogram_raw(
        self,
        tag,
        min,
        max,
        num,
        sum,
        sum_squares,
        bucket_limits,
        bucket_counts,
        global_step=None,
        walltime=None,
    ):
        """Add histogram with raw data.

        Args:
            tag (str): Data identifier
            min (float or int): Min value
            max (float or int): Max value
            num (int): Number of values
            sum (float or int): Sum of all values
            sum_squares (float or int): Sum of squares for all values
            bucket_limits (torch.Tensor, numpy.ndarray): Upper value per bucket.
              The number of elements of it should be the same as `bucket_counts`.
            bucket_counts (torch.Tensor, numpy.ndarray): Number of values per bucket
            global_step (int): Global step value to record
            walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event
            see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/histogram/README.md

        Examples::

            from torch.utils.tensorboard import SummaryWriter
            import numpy as np
            writer = SummaryWriter()
            dummy_data = []
            for idx, value in enumerate(range(50)):
                dummy_data += [idx + 0.001] * value

            bins = list(range(50 + 2))
            bins = np.array(bins)
            values = np.array(dummy_data).astype(float).reshape(-1)
            counts, limits = np.histogram(values, bins=bins)
            sum_sq = values.dot(values)
            writer.add_histogram_raw(
                tag='histogram_with_raw_data',
                min=values.min(),
                max=values.max(),
                num=len(values),
                sum=values.sum(),
                sum_squares=sum_sq,
                bucket_limits=limits[1:].tolist(),
                bucket_counts=counts.tolist(),
                global_step=0)
            writer.close()

        Expected result:

        .. image:: _static/img/tensorboard/add_histogram_raw.png
           :scale: 50 %

        """
        torch._C._log_api_usage_once("tensorboard.logging.add_histogram_raw")
        if len(bucket_limits) != len(bucket_counts):
            raise ValueError(
                "len(bucket_limits) != len(bucket_counts), see the document."
            )
        self._get_file_writer().add_summary(
            histogram_raw(
                tag, min, max, num, sum, sum_squares, bucket_limits, bucket_counts
            ),
            global_step,
            walltime,
        )

    def add_image(
        self, tag, img_tensor, global_step=None, walltime=None, dataformats="CHW"
    ):
        """Add image data to summary.

        Note that this requires the ``pillow`` package.

        Args:
            tag (str): Data identifier
            img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data
            global_step (int): Global step value to record
            walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event
            dataformats (str): Image data format specification of the form
              CHW, HWC, HW, WH, etc.

        Shape:
            img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to
            convert a batch of tensors into 3xHxW format or call ``add_images`` and let us do the job.
            Tensors with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` are also suitable as long
            as the corresponding ``dataformats`` argument is passed, e.g. ``CHW``, ``HWC``, ``HW``.

        Examples::

            from torch.utils.tensorboard import SummaryWriter
            import numpy as np
            img = np.zeros((3, 100, 100))
            img[0] = np.arange(0, 10000).reshape(100, 100) / 10000
            img[1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000

            img_HWC = np.zeros((100, 100, 3))
            img_HWC[:, :, 0] = np.arange(0, 10000).reshape(100, 100) / 10000
            img_HWC[:, :, 1] = 1 - np.arange(0, 10000).reshape(100, 100) / 10000

            writer = SummaryWriter()
            writer.add_image('my_image', img, 0)

            # If you have non-default dimension setting, set the dataformats argument.
            writer.add_image('my_image_HWC', img_HWC, 0, dataformats='HWC')
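
            # A sketch: tile a batch into one 3xHxW image with torchvision
            # (assumed to be installed) instead of calling add_images.
            import torch
            import torchvision
            img_batch = torch.rand(16, 3, 100, 100)
            grid = torchvision.utils.make_grid(img_batch)
            writer.add_image('my_image_grid', grid, 0)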
            writer.close()

        Expected result:

        .. image:: _static/img/tensorboard/add_image.png
           :scale: 50 %

        """
        torch._C._log_api_usage_once("tensorboard.logging.add_image")
        self._get_file_writer().add_summary(
            image(tag, img_tensor, dataformats=dataformats), global_step, walltime
        )

    def add_images(
        self, tag, img_tensor, global_step=None, walltime=None, dataformats="NCHW"
    ):
        """Add batched image data to summary.

        Note that this requires the ``pillow`` package.

        Args:
            tag (str): Data identifier
            img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data
            global_step (int): Global step value to record
            walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event
            dataformats (str): Image data format specification of the form
              NCHW, NHWC, CHW, HWC, HW, WH, etc.

        Shape:
            img_tensor: Default is :math:`(N, 3, H, W)`. If ``dataformats`` is specified,
            other shapes will be accepted, e.g. NCHW or NHWC.

        Examples::

            from torch.utils.tensorboard import SummaryWriter
            import numpy as np

            img_batch = np.zeros((16, 3, 100, 100))
            for i in range(16):
                img_batch[i, 0] = np.arange(0, 10000).reshape(100, 100) / 10000 / 16 * i
                img_batch[i, 1] = (1 - np.arange(0, 10000).reshape(100, 100) / 10000) / 16 * i

            writer = SummaryWriter()
            writer.add_images('my_image_batch', img_batch, 0)
            writer.close()

        Expected result:

        .. image:: _static/img/tensorboard/add_images.png
           :scale: 30 %

        """
        torch._C._log_api_usage_once("tensorboard.logging.add_images")
        self._get_file_writer().add_summary(
            image(tag, img_tensor, dataformats=dataformats), global_step, walltime
        )

    def add_image_with_boxes(
        self,
        tag,
        img_tensor,
        box_tensor,
        global_step=None,
        walltime=None,
        rescale=1,
        dataformats="CHW",
        labels=None,
    ):
        """Add image and draw bounding boxes on the image.

        Args:
            tag (str): Data identifier
            img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data
            box_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Box data
              (for detected objects); each box should be represented as
              [x1, y1, x2, y2].
            global_step (int): Global step value to record
            walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event
            rescale (float): Optional scale override
            dataformats (str): Image data format specification of the form
              NCHW, NHWC, CHW, HWC, HW, WH, etc.
            labels (list of string): The label to be shown for each bounding box.

        Shape:
            img_tensor: Default is :math:`(3, H, W)`. It can be specified with the
            ``dataformats`` argument, e.g. CHW or HWC.

            box_tensor: :math:`(N, 4)`, where N is the number of boxes and each
            row represents (xmin, ymin, xmax, ymax).
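
        Examples::

            from torch.utils.tensorboard import SummaryWriter
            import numpy as np
            # a minimal sketch: one dark 3x100x100 image with two labeled
            # boxes (names and coordinates are illustrative only)
            img = np.zeros((3, 100, 100))
            boxes = np.array([[10, 10, 40, 40], [50, 50, 90, 90]])
            writer = SummaryWriter()
            writer.add_image_with_boxes('image_with_boxes', img, boxes,
                                        labels=['cat', 'dog'])
            writer.close()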
        """
        torch._C._log_api_usage_once("tensorboard.logging.add_image_with_boxes")
        if labels is not None:
            if isinstance(labels, str):
                labels = [labels]
            if len(labels) != box_tensor.shape[0]:
                labels = None
        self._get_file_writer().add_summary(
            image_boxes(
                tag,
                img_tensor,
                box_tensor,
                rescale=rescale,
                dataformats=dataformats,
                labels=labels,
            ),
            global_step,
            walltime,
        )

    def add_figure(
        self,
        tag: str,
        figure: Union["Figure", List["Figure"]],
        global_step: Optional[int] = None,
        close: bool = True,
        walltime: Optional[float] = None,
    ) -> None:
        """Render matplotlib figure into an image and add it to summary.

        Note that this requires the ``matplotlib`` package.

        Args:
            tag: Data identifier
            figure: Figure or a list of figures
            global_step: Global step value to record
            close: Flag to automatically close the figure
            walltime: Optional override default walltime (time.time())
              seconds after epoch of event
        """
        torch._C._log_api_usage_once("tensorboard.logging.add_figure")
        if isinstance(figure, list):
            self.add_image(
                tag,
                figure_to_image(figure, close),
                global_step,
                walltime,
                dataformats="NCHW",
            )
        else:
            self.add_image(
                tag,
                figure_to_image(figure, close),
                global_step,
                walltime,
                dataformats="CHW",
            )

    def add_video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None):
        """Add video data to summary.

        Note that this requires the ``moviepy`` package.

        Args:
            tag (str): Data identifier
            vid_tensor (torch.Tensor): Video data
            global_step (int): Global step value to record
            fps (float or int): Frames per second
            walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event

        Shape:
            vid_tensor: :math:`(N, T, C, H, W)`. The values should lie in [0, 255]
            for type `uint8` or [0, 1] for type `float`.
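
        Examples::

            from torch.utils.tensorboard import SummaryWriter
            import torch
            # a sketch: a batch of one 16-frame RGB video, values in [0, 1]
            vid = torch.rand(1, 16, 3, 32, 32)
            writer = SummaryWriter()
            writer.add_video('my_video', vid, 0, fps=4)
            writer.close()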
        """
        torch._C._log_api_usage_once("tensorboard.logging.add_video")
        self._get_file_writer().add_summary(
            video(tag, vid_tensor, fps), global_step, walltime
        )

    def add_audio(
        self, tag, snd_tensor, global_step=None, sample_rate=44100, walltime=None
    ):
        """Add audio data to summary.

        Args:
            tag (str): Data identifier
            snd_tensor (torch.Tensor): Sound data
            global_step (int): Global step value to record
            sample_rate (int): sample rate in Hz
            walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event

        Shape:
            snd_tensor: :math:`(1, L)`. The values should lie between [-1, 1].
        """
        torch._C._log_api_usage_once("tensorboard.logging.add_audio")
        self._get_file_writer().add_summary(
            audio(tag, snd_tensor, sample_rate=sample_rate), global_step, walltime
        )

    def add_text(self, tag, text_string, global_step=None, walltime=None):
        """Add text data to summary.

        Args:
            tag (str): Data identifier
            text_string (str): String to save
            global_step (int): Global step value to record
            walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event

        Examples::

            writer.add_text('lstm', 'This is an lstm', 0)
            writer.add_text('rnn', 'This is an rnn', 10)
        """
        torch._C._log_api_usage_once("tensorboard.logging.add_text")
        self._get_file_writer().add_summary(
            text(tag, text_string), global_step, walltime
        )

    def add_onnx_graph(self, prototxt):
        """Add an ONNX graph, loaded from the given model file, to the event file."""
        torch._C._log_api_usage_once("tensorboard.logging.add_onnx_graph")
        self._get_file_writer().add_onnx_graph(load_onnx_graph(prototxt))

    def add_graph(
        self, model, input_to_model=None, verbose=False, use_strict_trace=True
    ):
        """Add graph data to summary.

        Args:
            model (torch.nn.Module): Model to draw.
            input_to_model (torch.Tensor or list of torch.Tensor): A variable or a tuple of
              variables to be fed.
            verbose (bool): Whether to print graph structure in console.
            use_strict_trace (bool): Whether to pass keyword argument `strict` to
              `torch.jit.trace`. Pass False when you want the tracer to
              record your mutable container types (list, dict)
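
        Examples::

            from torch.utils.tensorboard import SummaryWriter
            import torch
            import torch.nn as nn
            # a sketch with a small stand-in model; any nn.Module whose
            # forward can be traced with the given input works
            model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))
            writer = SummaryWriter()
            writer.add_graph(model, torch.rand(1, 8))
            writer.close()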
        """
        torch._C._log_api_usage_once("tensorboard.logging.add_graph")
        # A valid PyTorch model should have a 'forward' method
        self._get_file_writer().add_graph(
            graph(model, input_to_model, verbose, use_strict_trace)
        )

    @staticmethod
    def _encode(rawstr):
        # I'd use urllib, but I'm unsure about the differences from python3 to python2, etc.
        retval = rawstr
        retval = retval.replace("%", f"%{ord('%'):02x}")
        retval = retval.replace("/", f"%{ord('/'):02x}")
        retval = retval.replace("\\", "%%%02x" % (ord("\\")))  # noqa: UP031
        return retval

    def add_embedding(
        self,
        mat,
        metadata=None,
        label_img=None,
        global_step=None,
        tag="default",
        metadata_header=None,
    ):
        """Add embedding projector data to summary.

        Args:
            mat (torch.Tensor or numpy.ndarray): A matrix where each row is the feature vector of a data point
            metadata (list): A list of labels; each element will be converted to string
            label_img (torch.Tensor): Images corresponding to each data point
            global_step (int): Global step value to record
            tag (str): Name for the embedding
            metadata_header (list): A list of headers for multi-column metadata. If given, each metadata entry
              must be a list with values corresponding to the headers.

        Shape:
            mat: :math:`(N, D)`, where N is the number of data points and D is the feature dimension

            label_img: :math:`(N, C, H, W)`

        Examples::

            import keyword
            import torch
            from torch.utils.tensorboard import SummaryWriter
            meta = []
            while len(meta) < 100:
                meta = meta + keyword.kwlist  # get some strings
            meta = meta[:100]

            for i, v in enumerate(meta):
                meta[i] = v + str(i)

            label_img = torch.rand(100, 3, 10, 32)
            for i in range(100):
                label_img[i] *= i / 100.0

            writer = SummaryWriter()
            writer.add_embedding(torch.randn(100, 5), metadata=meta, label_img=label_img)
            writer.add_embedding(torch.randn(100, 5), label_img=label_img)
            writer.add_embedding(torch.randn(100, 5), metadata=meta)

        .. note::
            Categorical (i.e. non-numeric) metadata cannot have more than 50 unique values if they are to be used for
            coloring in the embedding projector.

        """
        torch._C._log_api_usage_once("tensorboard.logging.add_embedding")
        mat = make_np(mat)
        if global_step is None:
            global_step = 0
            # clear pbtxt?

        # Maybe we should encode the tag so slashes don't trip us up?
        # I don't think this will mess us up, but better safe than sorry.
        subdir = f"{str(global_step).zfill(5)}/{self._encode(tag)}"
        save_path = os.path.join(self._get_file_writer().get_logdir(), subdir)

        fs = tf.io.gfile
        if fs.exists(save_path):
            if fs.isdir(save_path):
                print(
                    "warning: Embedding dir exists, did you set global_step for add_embedding()?"
                )
            else:
                raise NotADirectoryError(
                    f"Path: `{save_path}` exists, but is a file. Cannot proceed."
                )
        else:
            fs.makedirs(save_path)

        if metadata is not None:
            assert mat.shape[0] == len(
                metadata
            ), "#labels should equal #data points"
            make_tsv(metadata, save_path, metadata_header=metadata_header)

        if label_img is not None:
            assert (
                mat.shape[0] == label_img.shape[0]
            ), "#images should equal #data points"
            make_sprite(label_img, save_path)

        assert (
            mat.ndim == 2
        ), "mat should be 2D, where mat.size(0) is the number of data points"
        make_mat(mat, save_path)

        # Filesystem doesn't necessarily have append semantics, so we store an
        # internal buffer to append to and re-write whole file after each
        # embedding is added
        if not hasattr(self, "_projector_config"):
            self._projector_config = ProjectorConfig()
        embedding_info = get_embedding_info(
            metadata, label_img, subdir, global_step, tag
        )
        self._projector_config.embeddings.extend([embedding_info])

        from google.protobuf import text_format

        config_pbtxt = text_format.MessageToString(self._projector_config)
        write_pbtxt(self._get_file_writer().get_logdir(), config_pbtxt)

    def add_pr_curve(
        self,
        tag,
        labels,
        predictions,
        global_step=None,
        num_thresholds=127,
        weights=None,
        walltime=None,
    ):
        """Add precision recall curve.

        Plotting a precision-recall curve lets you understand your model's
        performance under different threshold settings. With this function,
        you provide the ground truth labeling (T/F) and prediction confidence
        (usually the output of your model) for each target. The TensorBoard UI
        will let you choose the threshold interactively.

        Args:
            tag (str): Data identifier
            labels (torch.Tensor, numpy.ndarray, or string/blobname):
              Ground truth data. Binary label for each element.
            predictions (torch.Tensor, numpy.ndarray, or string/blobname):
              The probability that an element is classified as true.
              Values should be in [0, 1].
            global_step (int): Global step value to record
            num_thresholds (int): Number of thresholds used to draw the curve.
            walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event

        Examples::

            from torch.utils.tensorboard import SummaryWriter
            import numpy as np
            labels = np.random.randint(2, size=100)  # binary label
            predictions = np.random.rand(100)
            writer = SummaryWriter()
            writer.add_pr_curve('pr_curve', labels, predictions, 0)
            writer.close()

        """
        torch._C._log_api_usage_once("tensorboard.logging.add_pr_curve")
        labels, predictions = make_np(labels), make_np(predictions)
        self._get_file_writer().add_summary(
            pr_curve(tag, labels, predictions, num_thresholds, weights),
            global_step,
            walltime,
        )

    def add_pr_curve_raw(
        self,
        tag,
        true_positive_counts,
        false_positive_counts,
        true_negative_counts,
        false_negative_counts,
        precision,
        recall,
        global_step=None,
        num_thresholds=127,
        weights=None,
        walltime=None,
    ):
        """Add precision recall curve with raw data.

        Args:
            tag (str): Data identifier
            true_positive_counts (torch.Tensor, numpy.ndarray, or string/blobname): true positive counts
            false_positive_counts (torch.Tensor, numpy.ndarray, or string/blobname): false positive counts
            true_negative_counts (torch.Tensor, numpy.ndarray, or string/blobname): true negative counts
            false_negative_counts (torch.Tensor, numpy.ndarray, or string/blobname): false negative counts
            precision (torch.Tensor, numpy.ndarray, or string/blobname): precision
            recall (torch.Tensor, numpy.ndarray, or string/blobname): recall
            global_step (int): Global step value to record
            num_thresholds (int): Number of thresholds used to draw the curve.
            walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event
            see: https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/README.md
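
        Examples::

            from torch.utils.tensorboard import SummaryWriter
            # a sketch with made-up, precomputed counts for 5 thresholds;
            # precision and recall here are derived from the counts
            tp = [75, 64, 21, 5, 0]
            fp = [150, 105, 18, 0, 0]
            tn = [0, 45, 132, 150, 150]
            fn = [0, 11, 54, 70, 75]
            precision = [0.3333, 0.3787, 0.5384, 1.0, 0.0]
            recall = [1.0, 0.8533, 0.28, 0.0666, 0.0]
            writer = SummaryWriter()
            writer.add_pr_curve_raw('prcurve_with_raw_data', tp, fp, tn, fn,
                                    precision, recall, 0, num_thresholds=5)
            writer.close()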
        """
        torch._C._log_api_usage_once("tensorboard.logging.add_pr_curve_raw")
        self._get_file_writer().add_summary(
            pr_curve_raw(
                tag,
                true_positive_counts,
                false_positive_counts,
                true_negative_counts,
                false_negative_counts,
                precision,
                recall,
                num_thresholds,
                weights,
            ),
            global_step,
            walltime,
        )

    def add_custom_scalars_multilinechart(
        self, tags, category="default", title="untitled"
    ):
        """Shorthand for creating a multiline chart. Similar to ``add_custom_scalars()``, but the only necessary argument is *tags*.

        Args:
            tags (list): list of tags that have been used in ``add_scalar()``

        Examples::

            writer.add_custom_scalars_multilinechart(['twse/0050', 'twse/2330'])
        """
        torch._C._log_api_usage_once(
            "tensorboard.logging.add_custom_scalars_multilinechart"
        )
        layout = {category: {title: ["Multiline", tags]}}
        self._get_file_writer().add_summary(custom_scalars(layout))

    def add_custom_scalars_marginchart(
        self, tags, category="default", title="untitled"
    ):
        """Shorthand for creating a margin chart.

        Similar to ``add_custom_scalars()``, but the only necessary argument is *tags*,
        which should have exactly 3 elements.

        Args:
            tags (list): list of tags that have been used in ``add_scalar()``

        Examples::

            writer.add_custom_scalars_marginchart(['twse/0050', 'twse/2330', 'twse/2006'])
        """
        torch._C._log_api_usage_once(
            "tensorboard.logging.add_custom_scalars_marginchart"
        )
        assert len(tags) == 3
        layout = {category: {title: ["Margin", tags]}}
        self._get_file_writer().add_summary(custom_scalars(layout))

    def add_custom_scalars(self, layout):
        """Create a special chart by collecting chart tags in 'scalars'.

        NOTE: This function can only be called once for each SummaryWriter() object.

        Because it only provides metadata to tensorboard, the function can be called before or after the training loop.

        Args:
            layout (dict): {categoryName: *charts*}, where *charts* is also a dictionary
              {chartName: *ListOfProperties*}. The first element in *ListOfProperties* is the chart's type
              (one of **Multiline** or **Margin**) and the second element should be a list containing the tags
              you have used in the add_scalar function, which will be collected into the new chart.

        Examples::

            layout = {'Taiwan': {'twse': ['Multiline', ['twse/0050', 'twse/2330']]},
                      'USA': {'dow': ['Margin', ['dow/aaa', 'dow/bbb', 'dow/ccc']],
                              'nasdaq': ['Margin', ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']]}}

            writer.add_custom_scalars(layout)
        """
        torch._C._log_api_usage_once("tensorboard.logging.add_custom_scalars")
        self._get_file_writer().add_summary(custom_scalars(layout))

    def add_mesh(
        self,
        tag,
        vertices,
        colors=None,
        faces=None,
        config_dict=None,
        global_step=None,
        walltime=None,
    ):
        """Add meshes or 3D point clouds to TensorBoard.

        The visualization is based on Three.js,
        so it allows users to interact with the rendered object. Besides basic definitions
        such as vertices and faces, users can further provide camera parameters, lighting conditions, etc.
        Please check https://threejs.org/docs/index.html#manual/en/introduction/Creating-a-scene for
        advanced usage.

        Args:
            tag (str): Data identifier
            vertices (torch.Tensor): List of the 3D coordinates of vertices.
            colors (torch.Tensor): Colors for each vertex
            faces (torch.Tensor): Indices of vertices within each triangle. (Optional)
            config_dict: Dictionary with ThreeJS classes names and configuration.
            global_step (int): Global step value to record
            walltime (float): Optional override default walltime (time.time())
              seconds after epoch of event

        Shape:
            vertices: :math:`(B, N, 3)`. (batch, number_of_vertices, channels)

            colors: :math:`(B, N, 3)`. The values should lie in [0, 255] for type `uint8` or [0, 1] for type `float`.

            faces: :math:`(B, N, 3)`. The values should lie in [0, number_of_vertices] for type `uint8`.

        Examples::

            from torch.utils.tensorboard import SummaryWriter
            import torch
            vertices_tensor = torch.as_tensor([
                [1, 1, 1],
                [-1, -1, 1],
                [1, -1, -1],
                [-1, 1, -1],
            ], dtype=torch.float).unsqueeze(0)
            colors_tensor = torch.as_tensor([
                [255, 0, 0],
                [0, 255, 0],
                [0, 0, 255],
                [255, 0, 255],
            ], dtype=torch.int).unsqueeze(0)
            faces_tensor = torch.as_tensor([
                [0, 2, 3],
                [0, 3, 1],
                [0, 1, 2],
                [1, 3, 2],
            ], dtype=torch.int).unsqueeze(0)

            writer = SummaryWriter()
            writer.add_mesh('my_mesh', vertices=vertices_tensor, colors=colors_tensor, faces=faces_tensor)

            writer.close()
        """
        torch._C._log_api_usage_once("tensorboard.logging.add_mesh")
        self._get_file_writer().add_summary(
            mesh(tag, vertices, colors, faces, config_dict), global_step, walltime
        )

    def flush(self):
        """Flushes the event file to disk.

        Call this method to make sure that all pending events have been written to
        disk.
        """
        if self.all_writers is None:
            return
        for writer in self.all_writers.values():
            writer.flush()

    def close(self):
        """Flushes and closes all event file writers."""
        if self.all_writers is None:
            return  # ignore double close
        for writer in self.all_writers.values():
            writer.flush()
            writer.close()
        self.file_writer = self.all_writers = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()