1# mypy: allow-untyped-defs 2r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter. 3 4To support these two classes, in `./_utils` we define many utility methods and 5functions to be run in multiprocessing. E.g., the data loading worker loop is 6in `./_utils/worker.py`. 7""" 8 9import functools 10import itertools 11import logging 12import multiprocessing as python_multiprocessing 13import os 14import queue 15import threading 16import warnings 17from typing import Any, Callable, Generic, Iterable, List, Optional, TypeVar, Union 18 19import torch 20import torch.distributed as dist 21import torch.utils.data.graph_settings 22from torch._utils import ExceptionWrapper 23from torch.utils.data import _utils 24from torch.utils.data.datapipes.datapipe import ( 25 _IterDataPipeSerializationWrapper, 26 _MapDataPipeSerializationWrapper, 27 IterDataPipe, 28 MapDataPipe, 29) 30from torch.utils.data.dataset import Dataset, IterableDataset 31from torch.utils.data.sampler import ( 32 BatchSampler, 33 RandomSampler, 34 Sampler, 35 SequentialSampler, 36) 37 38 39__all__ = [ 40 "DataLoader", 41 "get_worker_info", 42 "default_collate", 43 "default_convert", 44] 45 46 47_T = TypeVar("_T") 48_T_co = TypeVar("_T_co", covariant=True) 49_worker_init_fn_t = Callable[[int], None] 50 51# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that 52# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'. 53# See https://github.com/python/mypy/issues/3737. 54_collate_fn_t = Callable[[List[_T]], Any] 55 56 57# These functions used to be defined in this file. However, it was moved to 58# _utils/collate.py. Although it is rather hard to access this from user land 59# (one has to explicitly directly `import torch.utils.data.dataloader`), there 60# probably is user code out there using it. This aliasing maintains BC in this 61# aspect. 62default_collate: _collate_fn_t = _utils.collate.default_collate 63default_convert = _utils.collate.default_convert 64 65get_worker_info = _utils.worker.get_worker_info 66 67logger = logging.getLogger(__name__) 68 69 70class _DatasetKind: 71 Map = 0 72 Iterable = 1 73 74 @staticmethod 75 def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last): 76 if kind == _DatasetKind.Map: 77 return _utils.fetch._MapDatasetFetcher( 78 dataset, auto_collation, collate_fn, drop_last 79 ) 80 else: 81 return _utils.fetch._IterableDatasetFetcher( 82 dataset, auto_collation, collate_fn, drop_last 83 ) 84 85 86class _InfiniteConstantSampler(Sampler): 87 r"""Analogous to ``itertools.repeat(None, None)``. 88 89 Used as sampler for :class:`~torch.utils.data.IterableDataset`. 
90 """ 91 92 def __iter__(self): 93 while True: 94 yield None 95 96 97def _get_distributed_settings(): 98 if dist.is_available() and dist.is_initialized(): 99 return dist.get_world_size(), dist.get_rank() 100 else: 101 return 1, 0 102 103 104def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id): 105 global_worker_id = worker_id 106 info = torch.utils.data.get_worker_info() 107 assert info is not None 108 total_workers = info.num_workers 109 datapipe = info.dataset 110 assert isinstance(datapipe, (IterDataPipe, MapDataPipe)) 111 # To distribute elements across distributed process evenly, we should shard data on distributed 112 # processes first then shard on worker processes 113 total_workers *= world_size 114 global_worker_id = global_worker_id * world_size + rank_id 115 # For BC, use default SHARDING_PRIORITIES 116 torch.utils.data.graph_settings.apply_sharding( 117 datapipe, total_workers, global_worker_id 118 ) 119 if worker_init_fn is not None: 120 worker_init_fn(worker_id) 121 122 123def _share_dist_seed(generator, pg): 124 _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator) 125 if isinstance(pg, dist.ProcessGroup): 126 dist.broadcast(_shared_seed, src=0, group=pg) 127 return _shared_seed.item() 128 129 130class DataLoader(Generic[_T_co]): 131 r""" 132 Data loader combines a dataset and a sampler, and provides an iterable over the given dataset. 133 134 The :class:`~torch.utils.data.DataLoader` supports both map-style and 135 iterable-style datasets with single- or multi-process loading, customizing 136 loading order and optional automatic batching (collation) and memory pinning. 137 138 See :py:mod:`torch.utils.data` documentation page for more details. 139 140 Args: 141 dataset (Dataset): dataset from which to load the data. 142 batch_size (int, optional): how many samples per batch to load 143 (default: ``1``). 144 shuffle (bool, optional): set to ``True`` to have the data reshuffled 145 at every epoch (default: ``False``). 146 sampler (Sampler or Iterable, optional): defines the strategy to draw 147 samples from the dataset. Can be any ``Iterable`` with ``__len__`` 148 implemented. If specified, :attr:`shuffle` must not be specified. 149 batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but 150 returns a batch of indices at a time. Mutually exclusive with 151 :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`, 152 and :attr:`drop_last`. 153 num_workers (int, optional): how many subprocesses to use for data 154 loading. ``0`` means that the data will be loaded in the main process. 155 (default: ``0``) 156 collate_fn (Callable, optional): merges a list of samples to form a 157 mini-batch of Tensor(s). Used when using batched loading from a 158 map-style dataset. 159 pin_memory (bool, optional): If ``True``, the data loader will copy Tensors 160 into device/CUDA pinned memory before returning them. If your data elements 161 are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type, 162 see the example below. 163 drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, 164 if the dataset size is not divisible by the batch size. If ``False`` and 165 the size of dataset is not divisible by the batch size, then the last batch 166 will be smaller. (default: ``False``) 167 timeout (numeric, optional): if positive, the timeout value for collecting a batch 168 from workers. Should always be non-negative. 
(default: ``0``) 169 worker_init_fn (Callable, optional): If not ``None``, this will be called on each 170 worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as 171 input, after seeding and before data loading. (default: ``None``) 172 multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If 173 ``None``, the default `multiprocessing context`_ of your operating system will 174 be used. (default: ``None``) 175 generator (torch.Generator, optional): If not ``None``, this RNG will be used 176 by RandomSampler to generate random indexes and multiprocessing to generate 177 ``base_seed`` for workers. (default: ``None``) 178 prefetch_factor (int, optional, keyword-only arg): Number of batches loaded 179 in advance by each worker. ``2`` means there will be a total of 180 2 * num_workers batches prefetched across all workers. (default value depends 181 on the set value for num_workers. If value of num_workers=0 default is ``None``. 182 Otherwise, if value of ``num_workers > 0`` default is ``2``). 183 persistent_workers (bool, optional): If ``True``, the data loader will not shut down 184 the worker processes after a dataset has been consumed once. This allows to 185 maintain the workers `Dataset` instances alive. (default: ``False``) 186 pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is 187 ``True``. 188 189 190 .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn` 191 cannot be an unpicklable object, e.g., a lambda function. See 192 :ref:`multiprocessing-best-practices` on more details related 193 to multiprocessing in PyTorch. 194 195 .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used. 196 When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`, 197 it instead returns an estimate based on ``len(dataset) / batch_size``, with proper 198 rounding depending on :attr:`drop_last`, regardless of multi-process loading 199 configurations. This represents the best guess PyTorch can make because PyTorch 200 trusts user :attr:`dataset` code in correctly handling multi-process 201 loading to avoid duplicate data. 202 203 However, if sharding results in multiple workers having incomplete last batches, 204 this estimate can still be inaccurate, because (1) an otherwise complete batch can 205 be broken into multiple ones and (2) more than one batch worth of samples can be 206 dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such 207 cases in general. 208 209 See `Dataset Types`_ for more details on these two types of datasets and how 210 :class:`~torch.utils.data.IterableDataset` interacts with 211 `Multi-process data loading`_. 212 213 .. warning:: See :ref:`reproducibility`, and :ref:`dataloader-workers-random-seed`, and 214 :ref:`data-loading-randomness` notes for random seed related questions. 215 216 .. 
_multiprocessing context: 217 https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods 218 """ 219 220 dataset: Dataset[_T_co] 221 batch_size: Optional[int] 222 num_workers: int 223 pin_memory: bool 224 drop_last: bool 225 timeout: float 226 sampler: Union[Sampler, Iterable] 227 pin_memory_device: str 228 prefetch_factor: Optional[int] 229 _iterator: Optional["_BaseDataLoaderIter"] 230 __initialized = False 231 232 def __init__( 233 self, 234 dataset: Dataset[_T_co], 235 batch_size: Optional[int] = 1, 236 shuffle: Optional[bool] = None, 237 sampler: Union[Sampler, Iterable, None] = None, 238 batch_sampler: Union[Sampler[List], Iterable[List], None] = None, 239 num_workers: int = 0, 240 collate_fn: Optional[_collate_fn_t] = None, 241 pin_memory: bool = False, 242 drop_last: bool = False, 243 timeout: float = 0, 244 worker_init_fn: Optional[_worker_init_fn_t] = None, 245 multiprocessing_context=None, 246 generator=None, 247 *, 248 prefetch_factor: Optional[int] = None, 249 persistent_workers: bool = False, 250 pin_memory_device: str = "", 251 ): 252 torch._C._log_api_usage_once("python.data_loader") 253 254 if num_workers < 0: 255 raise ValueError( 256 "num_workers option should be non-negative; " 257 "use num_workers=0 to disable multiprocessing." 258 ) 259 260 if timeout < 0: 261 raise ValueError("timeout option should be non-negative") 262 263 if num_workers == 0 and prefetch_factor is not None: 264 raise ValueError( 265 "prefetch_factor option could only be specified in multiprocessing." 266 "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None." 267 ) 268 elif num_workers > 0 and prefetch_factor is None: 269 prefetch_factor = 2 270 elif prefetch_factor is not None and prefetch_factor < 0: 271 raise ValueError("prefetch_factor option should be non-negative") 272 273 if persistent_workers and num_workers == 0: 274 raise ValueError("persistent_workers option needs num_workers > 0") 275 276 self.dataset = dataset 277 self.num_workers = num_workers 278 self.prefetch_factor = prefetch_factor 279 self.pin_memory = pin_memory 280 self.pin_memory_device = pin_memory_device 281 self.timeout = timeout 282 self.worker_init_fn = worker_init_fn 283 self.multiprocessing_context = multiprocessing_context 284 285 # Adds forward compatibilities so classic DataLoader can work with DataPipes: 286 # _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler 287 if isinstance(self.dataset, IterDataPipe): 288 self.dataset = _IterDataPipeSerializationWrapper(self.dataset) 289 elif isinstance(self.dataset, MapDataPipe): 290 self.dataset = _MapDataPipeSerializationWrapper(self.dataset) 291 292 # Arg-check dataset related before checking samplers because we want to 293 # tell users that iterable-style datasets are incompatible with custom 294 # samplers first, so that they don't learn that this combo doesn't work 295 # after spending time fixing the custom sampler errors. 296 if isinstance(dataset, IterableDataset): 297 self._dataset_kind = _DatasetKind.Iterable 298 # NOTE [ Custom Samplers and IterableDataset ] 299 # 300 # `IterableDataset` does not support custom `batch_sampler` or 301 # `sampler` since the key is irrelevant (unless we support 302 # generator-style dataset one day...). 303 # 304 # For `sampler`, we always create a dummy sampler. 
This is an 305 # infinite sampler even when the dataset may have an implemented 306 # finite `__len__` because in multi-process data loading, naive 307 # settings will return duplicated data (which may be desired), and 308 # thus using a sampler with length matching that of dataset will 309 # cause data lost (you may have duplicates of the first couple 310 # batches, but never see anything afterwards). Therefore, 311 # `Iterabledataset` always uses an infinite sampler, an instance of 312 # `_InfiniteConstantSampler` defined above. 313 # 314 # A custom `batch_sampler` essentially only controls the batch size. 315 # However, it is unclear how useful it would be since an iterable-style 316 # dataset can handle that within itself. Moreover, it is pointless 317 # in multi-process data loading as the assignment order of batches 318 # to workers is an implementation detail so users can not control 319 # how to batchify each worker's iterable. Thus, we disable this 320 # option. If this turns out to be useful in future, we can re-enable 321 # this, and support custom samplers that specify the assignments to 322 # specific workers. 323 if isinstance(dataset, IterDataPipe): 324 if shuffle is not None: 325 dataset = torch.utils.data.graph_settings.apply_shuffle_settings( 326 dataset, shuffle=shuffle 327 ) 328 # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default. 329 elif shuffle not in {False, None}: 330 raise ValueError( 331 f"DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle={shuffle}" 332 ) 333 334 if sampler is not None: 335 # See NOTE [ Custom Samplers and IterableDataset ] 336 raise ValueError( 337 f"DataLoader with IterableDataset: expected unspecified sampler option, but got sampler={sampler}" 338 ) 339 elif batch_sampler is not None: 340 # See NOTE [ Custom Samplers and IterableDataset ] 341 raise ValueError( 342 "DataLoader with IterableDataset: expected unspecified " 343 f"batch_sampler option, but got batch_sampler={batch_sampler}" 344 ) 345 else: 346 shuffle = bool(shuffle) 347 self._dataset_kind = _DatasetKind.Map 348 349 if sampler is not None and shuffle: 350 raise ValueError("sampler option is mutually exclusive with " "shuffle") 351 352 if batch_sampler is not None: 353 # auto_collation with custom batch_sampler 354 if batch_size != 1 or shuffle or sampler is not None or drop_last: 355 raise ValueError( 356 "batch_sampler option is mutually exclusive " 357 "with batch_size, shuffle, sampler, and " 358 "drop_last" 359 ) 360 batch_size = None 361 drop_last = False 362 elif batch_size is None: 363 # no auto_collation 364 if drop_last: 365 raise ValueError( 366 "batch_size=None option disables auto-batching " 367 "and is mutually exclusive with drop_last" 368 ) 369 370 if sampler is None: # give default samplers 371 if self._dataset_kind == _DatasetKind.Iterable: 372 # See NOTE [ Custom Samplers and IterableDataset ] 373 sampler = _InfiniteConstantSampler() 374 else: # map-style 375 if shuffle: 376 sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type] 377 else: 378 sampler = SequentialSampler(dataset) # type: ignore[arg-type] 379 380 if batch_size is not None and batch_sampler is None: 381 # auto_collation without custom batch_sampler 382 batch_sampler = BatchSampler(sampler, batch_size, drop_last) 383 384 self.batch_size = batch_size 385 self.drop_last = drop_last 386 self.sampler = sampler 387 self.batch_sampler = batch_sampler 388 self.generator = generator 389 390 if collate_fn is 
None: 391 if self._auto_collation: 392 collate_fn = _utils.collate.default_collate 393 else: 394 collate_fn = _utils.collate.default_convert 395 396 self.collate_fn = collate_fn 397 self.persistent_workers = persistent_workers 398 399 self.__initialized = True 400 self._IterableDataset_len_called = ( 401 None # See NOTE [ IterableDataset and __len__ ] 402 ) 403 404 self._iterator = None 405 406 self.check_worker_number_rationality() 407 408 torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined] 409 410 def _get_iterator(self) -> "_BaseDataLoaderIter": 411 if self.num_workers == 0: 412 return _SingleProcessDataLoaderIter(self) 413 else: 414 self.check_worker_number_rationality() 415 return _MultiProcessingDataLoaderIter(self) 416 417 @property 418 def multiprocessing_context(self): 419 return self.__multiprocessing_context 420 421 @multiprocessing_context.setter 422 def multiprocessing_context(self, multiprocessing_context): 423 if multiprocessing_context is not None: 424 if self.num_workers > 0: 425 if isinstance(multiprocessing_context, str): 426 valid_start_methods = torch.multiprocessing.get_all_start_methods() 427 if multiprocessing_context not in valid_start_methods: 428 raise ValueError( 429 "multiprocessing_context option " 430 f"should specify a valid start method in {valid_start_methods!r}, but got " 431 f"multiprocessing_context={multiprocessing_context!r}" 432 ) 433 multiprocessing_context = torch.multiprocessing.get_context( 434 multiprocessing_context 435 ) 436 437 if not isinstance( 438 multiprocessing_context, python_multiprocessing.context.BaseContext 439 ): 440 raise TypeError( 441 "multiprocessing_context option should be a valid context " 442 "object or a string specifying the start method, but got " 443 f"multiprocessing_context={multiprocessing_context}" 444 ) 445 else: 446 raise ValueError( 447 "multiprocessing_context can only be used with " 448 "multi-process loading (num_workers > 0), but got " 449 f"num_workers={self.num_workers}" 450 ) 451 452 self.__multiprocessing_context = multiprocessing_context 453 454 def __setattr__(self, attr, val): 455 if self.__initialized and attr in ( 456 "batch_size", 457 "batch_sampler", 458 "sampler", 459 "drop_last", 460 "dataset", 461 "persistent_workers", 462 ): 463 raise ValueError( 464 f"{attr} attribute should not be set after {self.__class__.__name__} is initialized" 465 ) 466 467 super().__setattr__(attr, val) 468 469 # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up 470 # since '_BaseDataLoaderIter' references 'DataLoader'. 471 def __iter__(self) -> "_BaseDataLoaderIter": 472 # When using a single worker the returned iterator should be 473 # created everytime to avoid resetting its state 474 # However, in the case of a multiple workers iterator 475 # the iterator is only created once in the lifetime of the 476 # DataLoader object so that workers can be reused 477 if self.persistent_workers and self.num_workers > 0: 478 if self._iterator is None: 479 self._iterator = self._get_iterator() 480 else: 481 self._iterator._reset(self) 482 return self._iterator 483 else: 484 return self._get_iterator() 485 486 @property 487 def _auto_collation(self): 488 return self.batch_sampler is not None 489 490 @property 491 def _index_sampler(self): 492 # The actual sampler used for generating indices for `_DatasetFetcher` 493 # (see _utils/fetch.py) to read data at each time. This would be 494 # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise. 
495 # We can't change `.sampler` and `.batch_sampler` attributes for BC 496 # reasons. 497 if self._auto_collation: 498 return self.batch_sampler 499 else: 500 return self.sampler 501 502 def __len__(self) -> int: 503 if self._dataset_kind == _DatasetKind.Iterable: 504 # NOTE [ IterableDataset and __len__ ] 505 # 506 # For `IterableDataset`, `__len__` could be inaccurate when one naively 507 # does multi-processing data loading, since the samples will be duplicated. 508 # However, no real use case should be actually using that behavior, so 509 # it should count as a user error. We should generally trust user 510 # code to do the proper thing (e.g., configure each replica differently 511 # in `__iter__`), and give us the correct `__len__` if they choose to 512 # implement it (this will still throw if the dataset does not implement 513 # a `__len__`). 514 # 515 # To provide a further warning, we track if `__len__` was called on the 516 # `DataLoader`, save the returned value in `self._len_called`, and warn 517 # if the iterator ends up yielding more than this number of samples. 518 519 # Cannot statically verify that dataset is Sized 520 length = self._IterableDataset_len_called = len(self.dataset) # type: ignore[assignment, arg-type] 521 if ( 522 self.batch_size is not None 523 ): # IterableDataset doesn't allow custom sampler or batch_sampler 524 from math import ceil 525 526 if self.drop_last: 527 length = length // self.batch_size 528 else: 529 length = ceil(length / self.batch_size) 530 return length 531 else: 532 return len(self._index_sampler) 533 534 def check_worker_number_rationality(self): 535 # This function check whether the dataloader's worker number is rational based on 536 # current system's resource. Current rule is that if the number of workers this 537 # Dataloader will create is bigger than the number of logical cpus that is allowed to 538 # use, than we will pop up a warning to let user pay attention. 539 # 540 # eg. If current system has 2 physical CPUs with 16 cores each. And each core support 2 541 # threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say current 542 # DataLoader process can use half of them which is 32, then the rational max number of 543 # worker that initiated from this process is 32. 544 # Now, let's say the created DataLoader has num_works = 40, which is bigger than 32. 545 # So the warning message is triggered to notify the user to lower the worker number if 546 # necessary. 547 # 548 # 549 # [Note] Please note that this function repects `cpuset` only when os.sched_getaffinity is 550 # available (available in most of Linux system, but not OSX and Windows). 551 # When os.sched_getaffinity is not available, os.cpu_count() is called instead, but 552 # it doesn't repect cpuset. 553 # We don't take threading into account since each worker process is single threaded 554 # at this time. 555 # 556 # We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc) 557 # other than `torch.set_num_threads` to 1 in the worker process, if the passing 558 # in functions use 3rd party modules that rely on those threading flags to determine 559 # how many thread to create (eg. numpy, etc), then it is caller's responsibility to 560 # set those flags correctly. 561 def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked): 562 suggested_max_worker_msg = ( 563 ( 564 ( 565 "Our suggested max number of worker in current system is {}{}, which is smaller " 566 "than what this DataLoader is going to create." 
567 ).format( 568 num_worker_suggest, 569 ( 570 "" 571 if cpuset_checked 572 else " (`cpuset` is not taken into account)" 573 ), 574 ) 575 ) 576 if num_worker_suggest is not None 577 else ( 578 "DataLoader is not able to compute a suggested max number of worker in current system." 579 ) 580 ) 581 582 warn_msg = ( 583 f"This DataLoader will create {num_worker_created} worker processes in total. {suggested_max_worker_msg} " 584 "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, " 585 "lower the worker number to avoid potential slowness/freeze if necessary." 586 ) 587 return warn_msg 588 589 if not self.num_workers or self.num_workers == 0: 590 return 591 592 # try to compute a suggested max number of worker based on system's resource 593 max_num_worker_suggest = None 594 cpuset_checked = False 595 if hasattr(os, "sched_getaffinity"): 596 try: 597 max_num_worker_suggest = len(os.sched_getaffinity(0)) 598 cpuset_checked = True 599 except Exception: 600 pass 601 if max_num_worker_suggest is None: 602 # os.cpu_count() could return Optional[int] 603 # get cpu count first and check None in order to satisfy mypy check 604 cpu_count = os.cpu_count() 605 if cpu_count is not None: 606 max_num_worker_suggest = cpu_count 607 608 if max_num_worker_suggest is None: 609 warnings.warn( 610 _create_warning_msg( 611 max_num_worker_suggest, self.num_workers, cpuset_checked 612 ) 613 ) 614 return 615 616 if self.num_workers > max_num_worker_suggest: 617 warnings.warn( 618 _create_warning_msg( 619 max_num_worker_suggest, self.num_workers, cpuset_checked 620 ) 621 ) 622 623 624class _BaseDataLoaderIter: 625 def __init__(self, loader: DataLoader) -> None: 626 self._dataset = loader.dataset 627 self._shared_seed = None 628 self._pg = None 629 if isinstance(self._dataset, IterDataPipe): 630 if dist.is_available() and dist.is_initialized(): 631 self._pg = dist.new_group(backend="gloo") 632 self._shared_seed = _share_dist_seed(loader.generator, self._pg) 633 shared_rng = torch.Generator() 634 shared_rng.manual_seed(self._shared_seed) 635 self._dataset = torch.utils.data.graph_settings.apply_random_seed( 636 self._dataset, shared_rng 637 ) 638 self._dataset_kind = loader._dataset_kind 639 self._IterableDataset_len_called = loader._IterableDataset_len_called 640 self._auto_collation = loader._auto_collation 641 self._drop_last = loader.drop_last 642 self._index_sampler = loader._index_sampler 643 self._num_workers = loader.num_workers 644 ws, rank = _get_distributed_settings() 645 self._world_size = ws 646 self._rank = rank 647 # for other backends, pin_memory_device need to set. if not set 648 # default behaviour is CUDA device. if pin_memory_device is selected 649 # and pin_memory is not set, the default behaviour false. 
650 if len(loader.pin_memory_device) == 0: 651 self._pin_memory = loader.pin_memory and torch.cuda.is_available() 652 self._pin_memory_device = None 653 else: 654 if not loader.pin_memory: 655 warn_msg = ( 656 "pin memory device is set and pin_memory flag is not used then device pinned memory won't be used" 657 "please set pin_memory to true, if you need to use the device pin memory" 658 ) 659 warnings.warn(warn_msg) 660 661 self._pin_memory = loader.pin_memory 662 self._pin_memory_device = loader.pin_memory_device 663 self._timeout = loader.timeout 664 self._collate_fn = loader.collate_fn 665 self._sampler_iter = iter(self._index_sampler) 666 self._base_seed = ( 667 torch.empty((), dtype=torch.int64) 668 .random_(generator=loader.generator) 669 .item() 670 ) 671 self._persistent_workers = loader.persistent_workers 672 self._num_yielded = 0 673 self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__" 674 675 def __iter__(self) -> "_BaseDataLoaderIter": 676 return self 677 678 def _reset(self, loader, first_iter=False): 679 self._sampler_iter = iter(self._index_sampler) 680 self._num_yielded = 0 681 self._IterableDataset_len_called = loader._IterableDataset_len_called 682 if isinstance(self._dataset, IterDataPipe): 683 self._shared_seed = _share_dist_seed(loader.generator, self._pg) 684 shared_rng = torch.Generator() 685 shared_rng.manual_seed(self._shared_seed) 686 self._dataset = torch.utils.data.graph_settings.apply_random_seed( 687 self._dataset, shared_rng 688 ) 689 690 def _next_index(self): 691 return next(self._sampler_iter) # may raise StopIteration 692 693 def _next_data(self): 694 raise NotImplementedError 695 696 def __next__(self) -> Any: 697 with torch.autograd.profiler.record_function(self._profile_name): 698 if self._sampler_iter is None: 699 # TODO(https://github.com/pytorch/pytorch/issues/76750) 700 self._reset() # type: ignore[call-arg] 701 data = self._next_data() 702 self._num_yielded += 1 703 if ( 704 self._dataset_kind == _DatasetKind.Iterable 705 and self._IterableDataset_len_called is not None 706 and self._num_yielded > self._IterableDataset_len_called 707 ): 708 warn_msg = ( 709 f"Length of IterableDataset {self._dataset} was reported to be {self._IterableDataset_len_called}" 710 f"(when accessing len(dataloader)), but {self._num_yielded} samples have been fetched. " 711 ) 712 if self._num_workers > 0: 713 warn_msg += ( 714 "For multiprocessing data-loading, this could be caused by not properly configuring the " 715 "IterableDataset replica at each worker. Please see " 716 "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples." 717 ) 718 warnings.warn(warn_msg) 719 return data 720 721 def __len__(self) -> int: 722 return len(self._index_sampler) 723 724 def __getstate__(self): 725 # TODO: add limited pickling support for sharing an iterator 726 # across multiple threads for HOGWILD. 
727 # Probably the best way to do this is by moving the sample pushing 728 # to a separate thread and then just sharing the data queue 729 # but signalling the end is tricky without a non-blocking API 730 raise NotImplementedError("{} cannot be pickled", self.__class__.__name__) 731 732 733class _SingleProcessDataLoaderIter(_BaseDataLoaderIter): 734 def __init__(self, loader): 735 super().__init__(loader) 736 assert self._timeout == 0 737 assert self._num_workers == 0 738 739 # Adds forward compatibilities so classic DataLoader can work with DataPipes: 740 # Taking care of distributed sharding 741 if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): 742 # For BC, use default SHARDING_PRIORITIES 743 torch.utils.data.graph_settings.apply_sharding( 744 self._dataset, self._world_size, self._rank 745 ) 746 747 self._dataset_fetcher = _DatasetKind.create_fetcher( 748 self._dataset_kind, 749 self._dataset, 750 self._auto_collation, 751 self._collate_fn, 752 self._drop_last, 753 ) 754 755 def _next_data(self): 756 index = self._next_index() # may raise StopIteration 757 data = self._dataset_fetcher.fetch(index) # may raise StopIteration 758 if self._pin_memory: 759 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device) 760 return data 761 762 763class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): 764 r"""Iterates once over the DataLoader's dataset, as specified by the sampler.""" 765 766 # NOTE [ Data Loader Multiprocessing Shutdown Logic ] 767 # 768 # Preliminary: 769 # 770 # Our data model looks like this (queues are indicated with curly brackets): 771 # 772 # main process || 773 # | || 774 # {index_queue} || 775 # | || 776 # worker processes || DATA 777 # | || 778 # {worker_result_queue} || FLOW 779 # | || 780 # pin_memory_thread of main process || DIRECTION 781 # | || 782 # {data_queue} || 783 # | || 784 # data output \/ 785 # 786 # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if 787 # `pin_memory=False`. 788 # 789 # 790 # Terminating multiprocessing logic requires very careful design. In 791 # particular, we need to make sure that 792 # 793 # 1. The iterator gracefully exits the workers when its last reference is 794 # gone or it is depleted. 795 # 796 # In this case, the workers should be gracefully exited because the 797 # main process may still need to continue to run, and we want cleaning 798 # up code in the workers to be executed (e.g., releasing GPU memory). 799 # Naturally, we implement the shutdown logic in `__del__` of 800 # DataLoaderIterator. 801 # 802 # We delay the discussion on the logic in this case until later. 803 # 804 # 2. The iterator exits the workers when the loader process and/or worker 805 # processes exits normally or with error. 806 # 807 # We set all workers and `pin_memory_thread` to have `daemon=True`. 808 # 809 # You may ask, why can't we make the workers non-daemonic, and 810 # gracefully exit using the same logic as we have in `__del__` when the 811 # iterator gets deleted (see 1 above)? 812 # 813 # First of all, `__del__` is **not** guaranteed to be called when 814 # interpreter exits. Even if it is called, by the time it executes, 815 # many Python core library resources may already be freed, and even 816 # simple things like acquiring an internal lock of a queue may hang. 817 # Therefore, in this case, we actually need to prevent `__del__` from 818 # being executed, and rely on the automatic termination of daemonic 819 # children. 
820 # 821 # Thus, we register an `atexit` hook that sets a global flag 822 # `_utils.python_exit_status`. Since `atexit` hooks are executed in the 823 # reverse order of registration, we are guaranteed that this flag is 824 # set before library resources we use are freed (which, at least in 825 # CPython, is done via an `atexit` handler defined in 826 # `multiprocessing/util.py` 827 # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362 828 # registered when an object requiring this mechanism is first 829 # created, e.g., `mp.Queue` 830 # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103 831 # https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29 832 # ) 833 # 834 # So in `__del__`, we check if `_utils.python_exit_status` is set or 835 # `None` (freed), and perform no-op if so. 836 # 837 # However, simply letting library clean-up codes run can also be bad, 838 # because such codes (i.e., `multiprocessing.util._exit_function()`) 839 # include join putting threads for `mp.Queue`, which can be blocking. 840 # Hence, the main process putting threads are called with 841 # `cancel_join_thread` at creation. See later section 842 # [ 3b. A process won't hang when putting into a queue; ] 843 # for more details. 844 # 845 # Here are two example cases where library clean-up codes can run 846 # before `__del__` is called: 847 # 848 # 1. If we hold onto a reference to the iterator, it more often 849 # than not tries to do `multiprocessing` library cleaning before 850 # clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666) 851 # and thus prevents our cleaning-up code to run first. 852 # 853 # 2. A similar issue araises when a `DataLoader` is used in a subprocess. 854 # When a process ends, it shuts the all its daemonic children 855 # down with a SIGTERM (instead of joining them without a timeout). 856 # Simiarly for threads, but by a different mechanism. This fact, 857 # together with a few implementation details of multiprocessing, forces 858 # us to make workers daemonic. All of our problems arise when a 859 # DataLoader is used in a subprocess, and are caused by multiprocessing 860 # code which looks more or less like this: 861 # 862 # try: 863 # your_function_using_a_dataloader() 864 # finally: 865 # multiprocessing.util._exit_function() 866 # 867 # The joining/termination mentioned above happens inside 868 # `_exit_function()`. Now, if `your_function_using_a_dataloader()` 869 # throws, the stack trace stored in the exception will prevent the 870 # frame which uses `DataLoaderIter` to be freed. If the frame has any 871 # reference to the `DataLoaderIter` (e.g., in a method of the iter), 872 # its `__del__`, which starts the shutdown procedure, will not be 873 # called. That, in turn, means that workers aren't notified. Attempting 874 # to join in `_exit_function` will then result in a hang. 875 # 876 # For context, `_exit_function` is also registered as an `atexit` call. 877 # So it is unclear to me (@ssnl) why this is needed in a finally block. 878 # The code dates back to 2008 and there is no comment on the original 879 # PEP 371 or patch https://bugs.python.org/issue3050 (containing both 880 # the finally block and the `atexit` registration) that explains this. 
881 # 882 # 883 # Finally, another choice is to just shutdown workers with logic in 1 884 # above whenever we see an error in `next`. This isn't ideal because 885 # a. It prevents users from using try-catch to resume data loading. 886 # b. It doesn't prevent hanging if users have references to the 887 # iterator. 888 # 889 # 3. All processes exit if any of them die unexpectedly by fatal signals. 890 # 891 # As shown above, the workers are set as daemonic children of the main 892 # process. However, automatic cleaning-up of such child processes only 893 # happens if the parent process exits gracefully (e.g., not via fatal 894 # signals like SIGKILL). So we must ensure that each process will exit 895 # even the process that should send/receive data to/from it were 896 # killed, i.e., 897 # 898 # a. A process won't hang when getting from a queue. 899 # 900 # Even with carefully designed data dependencies (i.e., a `put()` 901 # always corresponding to a `get()`), hanging on `get()` can still 902 # happen when data in queue is corrupted (e.g., due to 903 # `cancel_join_thread` or unexpected exit). 904 # 905 # For child exit, we set a timeout whenever we try to get data 906 # from `data_queue`, and check the workers' status on each timeout 907 # and error. 908 # See `_DataLoaderiter._get_batch()` and 909 # `_DataLoaderiter._try_get_data()` for details. 910 # 911 # Additionally, for child exit on non-Windows platforms, we also 912 # register a SIGCHLD handler (which is supported on Windows) on 913 # the main process, which checks if any of the workers fail in the 914 # (Python) handler. This is more efficient and faster in detecting 915 # worker failures, compared to only using the above mechanism. 916 # See `DataLoader.cpp` and `_utils/signal_handling.py` for details. 917 # 918 # For `.get()` calls where the sender(s) is not the workers, we 919 # guard them with timeouts, and check the status of the sender 920 # when timeout happens: 921 # + in the workers, the `_utils.worker.ManagerWatchdog` class 922 # checks the status of the main process. 923 # + if `pin_memory=True`, when getting from `pin_memory_thread`, 924 # check `pin_memory_thread` status periodically until `.get()` 925 # returns or see that `pin_memory_thread` died. 926 # 927 # b. A process won't hang when putting into a queue; 928 # 929 # We use `mp.Queue` which has a separate background thread to put 930 # objects from an unbounded buffer array. The background thread is 931 # daemonic and usually automatically joined when the process 932 # *exits*. 933 # 934 # In case that the receiver has ended abruptly while 935 # reading from the pipe, the join will hang forever. The usual 936 # solution for this in Python is calling `q.cancel_join_thread`, 937 # which prevents automatically joining it when finalizing 938 # (exiting). 939 # 940 # Nonetheless, `cancel_join_thread` must only be called when the 941 # queue is **not** going to be read from or write into by another 942 # process, because it may hold onto a lock or leave corrupted data 943 # in the queue, leading other readers/writers to hang. 944 # 945 # Hence, 946 # + For worker processes, we only do so (for their output 947 # queues, i.e., `worker_result_queue`) before exiting. 948 # + For `pin_memory_thread`, its output queue `data_queue` is a 949 # `queue.Queue` that does blocking `put` if the queue is full. 
950 # So there is no above problem, but as a result, in 951 # `_pin_memory_loop`, we do need to wrap the `put` in a loop 952 # that breaks not only upon success, but also when the main 953 # process stops reading, i.e., is shutting down. 954 # + For loader process, we `cancel_join_thread()` for all 955 # `_index_queues` because the whole purpose of workers and 956 # `pin_memory_thread` is to serve the loader process. If 957 # loader process is already exiting, we don't really care if 958 # the queues are corrupted. 959 # 960 # 961 # Now let's get back to 1: 962 # how we gracefully exit the workers when the last reference to the 963 # iterator is gone. 964 # 965 # To achieve this, we implement the following logic along with the design 966 # choices mentioned above: 967 # 968 # `workers_done_event`: 969 # A `multiprocessing.Event` shared among the main process and all worker 970 # processes. This is used to signal the workers that the iterator is 971 # shutting down. After it is set, they will not send processed data to 972 # queues anymore, and only wait for the final `None` before exiting. 973 # `done_event` isn't strictly needed. I.e., we can just check for `None` 974 # from the input queue, but it allows us to skip wasting resources 975 # processing data if we are already shutting down. 976 # 977 # `pin_memory_thread_done_event`: 978 # A `threading.Event` for a similar purpose to that of 979 # `workers_done_event`, but is for the `pin_memory_thread`. The reason 980 # that separate events are needed is that `pin_memory_thread` reads from 981 # the output queue of the workers. But the workers, upon seeing that 982 # `workers_done_event` is set, only wants to see the final `None`, and is 983 # not required to flush all data in the output queue (e.g., it may call 984 # `cancel_join_thread` on that queue if its `IterableDataset` iterator 985 # happens to exhaust coincidentally, which is out of the control of the 986 # main process). Thus, since we will exit `pin_memory_thread` before the 987 # workers (see below), two separete events are used. 988 # 989 # NOTE: In short, the protocol is that the main process will set these 990 # `done_event`s and then the corresponding processes/threads a `None`, 991 # and that they may exit at any time after receiving the `None`. 992 # 993 # NOTE: Using `None` as the final signal is valid, since normal data will 994 # always be a 2-tuple with the 1st element being the index of the data 995 # transferred (different from dataset index/key), and the 2nd being 996 # either the dataset key or the data sample (depending on which part 997 # of the data model the queue is at). 998 # 999 # [ worker processes ] 1000 # While loader process is alive: 1001 # Get from `index_queue`. 1002 # If get anything else, 1003 # Check `workers_done_event`. 1004 # If set, continue to next iteration 1005 # i.e., keep getting until see the `None`, then exit. 1006 # Otherwise, process data: 1007 # If is fetching from an `IterableDataset` and the iterator 1008 # is exhausted, send an `_IterableDatasetStopIteration` 1009 # object to signal iteration end. The main process, upon 1010 # receiving such an object, will send `None` to this 1011 # worker and not use the corresponding `index_queue` 1012 # anymore. 1013 # If timed out, 1014 # No matter `workers_done_event` is set (still need to see `None`) 1015 # or not, must continue to next iteration. 1016 # (outside loop) 1017 # If `workers_done_event` is set, (this can be False with `IterableDataset`) 1018 # `data_queue.cancel_join_thread()`. 
(Everything is ending here: 1019 # main process won't read from it; 1020 # other workers will also call 1021 # `cancel_join_thread`.) 1022 # 1023 # [ pin_memory_thread ] 1024 # # No need to check main thread. If this thread is alive, the main loader 1025 # # thread must be alive, because this thread is set as daemonic. 1026 # While `pin_memory_thread_done_event` is not set: 1027 # Get from `worker_result_queue`. 1028 # If timed out, continue to get in the next iteration. 1029 # Otherwise, process data. 1030 # While `pin_memory_thread_done_event` is not set: 1031 # Put processed data to `data_queue` (a `queue.Queue` with blocking put) 1032 # If timed out, continue to put in the next iteration. 1033 # Otherwise, break, i.e., continuing to the out loop. 1034 # 1035 # NOTE: we don't check the status of the main thread because 1036 # 1. if the process is killed by fatal signal, `pin_memory_thread` 1037 # ends. 1038 # 2. in other cases, either the cleaning-up in __del__ or the 1039 # automatic exit of daemonic thread will take care of it. 1040 # This won't busy-wait either because `.get(timeout)` does not 1041 # busy-wait. 1042 # 1043 # [ main process ] 1044 # In the DataLoader Iter's `__del__` 1045 # b. Exit `pin_memory_thread` 1046 # i. Set `pin_memory_thread_done_event`. 1047 # ii Put `None` in `worker_result_queue`. 1048 # iii. Join the `pin_memory_thread`. 1049 # iv. `worker_result_queue.cancel_join_thread()`. 1050 # 1051 # c. Exit the workers. 1052 # i. Set `workers_done_event`. 1053 # ii. Put `None` in each worker's `index_queue`. 1054 # iii. Join the workers. 1055 # iv. Call `.cancel_join_thread()` on each worker's `index_queue`. 1056 # 1057 # NOTE: (c) is better placed after (b) because it may leave corrupted 1058 # data in `worker_result_queue`, which `pin_memory_thread` 1059 # reads from, in which case the `pin_memory_thread` can only 1060 # happen at timing out, which is slow. Nonetheless, same thing 1061 # happens if a worker is killed by signal at unfortunate times, 1062 # but in other cases, we are better off having a non-corrupted 1063 # `worker_result_queue` for `pin_memory_thread`. 1064 # 1065 # NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b) 1066 # can be omitted 1067 # 1068 # NB: `done_event`s isn't strictly needed. E.g., we can just check for 1069 # `None` from `index_queue`, but it allows us to skip wasting resources 1070 # processing indices already in `index_queue` if we are already shutting 1071 # down. 
1072 1073 def __init__(self, loader): 1074 super().__init__(loader) 1075 1076 self._prefetch_factor = loader.prefetch_factor 1077 1078 assert self._num_workers > 0 1079 assert self._prefetch_factor > 0 1080 1081 if loader.multiprocessing_context is None: 1082 multiprocessing_context = torch.multiprocessing 1083 else: 1084 multiprocessing_context = loader.multiprocessing_context 1085 1086 self._worker_init_fn = loader.worker_init_fn 1087 1088 # Adds forward compatibilities so classic DataLoader can work with DataPipes: 1089 # Additional worker init function will take care of sharding in MP and Distributed 1090 if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): 1091 self._worker_init_fn = functools.partial( 1092 _sharding_worker_init_fn, 1093 self._worker_init_fn, 1094 self._world_size, 1095 self._rank, 1096 ) 1097 1098 # No certainty which module multiprocessing_context is 1099 self._worker_result_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] 1100 self._worker_pids_set = False 1101 self._shutdown = False 1102 self._workers_done_event = multiprocessing_context.Event() 1103 1104 self._index_queues = [] 1105 self._workers = [] 1106 for i in range(self._num_workers): 1107 # No certainty which module multiprocessing_context is 1108 index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated] 1109 # Need to `cancel_join_thread` here! 1110 # See sections (2) and (3b) above. 1111 index_queue.cancel_join_thread() 1112 w = multiprocessing_context.Process( 1113 target=_utils.worker._worker_loop, 1114 args=( 1115 self._dataset_kind, 1116 self._dataset, 1117 index_queue, 1118 self._worker_result_queue, 1119 self._workers_done_event, 1120 self._auto_collation, 1121 self._collate_fn, 1122 self._drop_last, 1123 self._base_seed, 1124 self._worker_init_fn, 1125 i, 1126 self._num_workers, 1127 self._persistent_workers, 1128 self._shared_seed, 1129 ), 1130 ) 1131 w.daemon = True 1132 # NB: Process.start() actually take some time as it needs to 1133 # start a process and pass the arguments over via a pipe. 1134 # Therefore, we only add a worker to self._workers list after 1135 # it started, so that we do not call .join() if program dies 1136 # before it starts, and __del__ tries to join but will get: 1137 # AssertionError: can only join a started process. 1138 w.start() 1139 self._index_queues.append(index_queue) 1140 self._workers.append(w) 1141 1142 if self._pin_memory: 1143 self._pin_memory_thread_done_event = threading.Event() 1144 1145 # Queue is not type-annotated 1146 self._data_queue = queue.Queue() # type: ignore[var-annotated] 1147 if self._pin_memory_device == "xpu": 1148 current_device = torch.xpu.current_device() # type: ignore[attr-defined] 1149 elif self._pin_memory_device == torch._C._get_privateuse1_backend_name(): 1150 custom_device_mod = getattr( 1151 torch, torch._C._get_privateuse1_backend_name() 1152 ) 1153 current_device = custom_device_mod.current_device() 1154 else: 1155 current_device = torch.cuda.current_device() # choose cuda for default 1156 pin_memory_thread = threading.Thread( 1157 target=_utils.pin_memory._pin_memory_loop, 1158 args=( 1159 self._worker_result_queue, 1160 self._data_queue, 1161 current_device, 1162 self._pin_memory_thread_done_event, 1163 self._pin_memory_device, 1164 ), 1165 ) 1166 pin_memory_thread.daemon = True 1167 pin_memory_thread.start() 1168 # Similar to workers (see comment above), we only register 1169 # pin_memory_thread once it is started. 
1170 self._pin_memory_thread = pin_memory_thread 1171 else: 1172 self._data_queue = self._worker_result_queue # type: ignore[assignment] 1173 1174 # In some rare cases, persistent workers (daemonic processes) 1175 # would be terminated before `__del__` of iterator is invoked 1176 # when main process exits 1177 # It would cause failure when pin_memory_thread tries to read 1178 # corrupted data from worker_result_queue 1179 # atexit is used to shutdown thread and child processes in the 1180 # right sequence before main process exits 1181 if self._persistent_workers and self._pin_memory: 1182 import atexit 1183 1184 for w in self._workers: 1185 atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w) 1186 1187 # .pid can be None only before process is spawned (not the case, so ignore) 1188 _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc] 1189 _utils.signal_handling._set_SIGCHLD_handler() 1190 self._worker_pids_set = True 1191 self._reset(loader, first_iter=True) 1192 1193 def _reset(self, loader, first_iter=False): 1194 super()._reset(loader, first_iter) 1195 self._send_idx = 0 # idx of the next task to be sent to workers 1196 self._rcvd_idx = 0 # idx of the next task to be returned in __next__ 1197 # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx). 1198 # map: task idx => - (worker_id,) if data isn't fetched (outstanding) 1199 # \ (worker_id, data) if data is already fetched (out-of-order) 1200 self._task_info = {} 1201 self._tasks_outstanding = ( 1202 0 # always equal to count(v for v in task_info.values() if len(v) == 1) 1203 ) 1204 # A list of booleans representing whether each worker still has work to 1205 # do, i.e., not having exhausted its iterable dataset object. It always 1206 # contains all `True`s if not using an iterable-style dataset 1207 # (i.e., if kind != Iterable). 1208 # Not that this indicates that a worker still has work to do *for this epoch*. 1209 # It does not mean that a worker is dead. In case of `_persistent_workers`, 1210 # the worker will be reset to available in the next epoch. 1211 self._workers_status = [True for i in range(self._num_workers)] 1212 # Reset the worker queue cycle so it resumes next epoch at worker 0 1213 self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers)) 1214 # We resume the prefetching in case it was enabled 1215 if not first_iter: 1216 for idx in range(self._num_workers): 1217 self._index_queues[idx].put( 1218 _utils.worker._ResumeIteration(self._shared_seed) 1219 ) 1220 resume_iteration_cnt = self._num_workers 1221 while resume_iteration_cnt > 0: 1222 return_idx, return_data = self._get_data() 1223 if isinstance(return_idx, _utils.worker._ResumeIteration): 1224 assert return_data is None 1225 resume_iteration_cnt -= 1 1226 # prime the prefetch loop 1227 for _ in range(self._prefetch_factor * self._num_workers): 1228 self._try_put_index() 1229 1230 def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL): 1231 # Tries to fetch data from `self._data_queue` once for a given timeout. 1232 # This can also be used as inner loop of fetching without timeout, with 1233 # the sender status as the loop condition. 1234 # 1235 # This raises a `RuntimeError` if any worker died expectedly. This error 1236 # can come from either the SIGCHLD handler in `_utils/signal_handling.py` 1237 # (only for non-Windows platforms), or the manual check below on errors 1238 # and timeouts. 
1239 # 1240 # Returns a 2-tuple: 1241 # (bool: whether successfully get data, any: data if successful else None) 1242 try: 1243 data = self._data_queue.get(timeout=timeout) 1244 return (True, data) 1245 except Exception as e: 1246 # At timeout and error, we manually check whether any worker has 1247 # failed. Note that this is the only mechanism for Windows to detect 1248 # worker failures. 1249 failed_workers = [] 1250 for worker_id, w in enumerate(self._workers): 1251 if self._workers_status[worker_id] and not w.is_alive(): 1252 failed_workers.append(w) 1253 self._mark_worker_as_unavailable(worker_id) 1254 if len(failed_workers) > 0: 1255 pids_str = ", ".join(str(w.pid) for w in failed_workers) 1256 raise RuntimeError( 1257 f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly" 1258 ) from e 1259 if isinstance(e, queue.Empty): 1260 return (False, None) 1261 1262 import errno 1263 import tempfile 1264 1265 try: 1266 # Raise an exception if we are this close to the FDs limit. 1267 # Apparently, trying to open only one file is not a sufficient 1268 # test. 1269 # See NOTE [ DataLoader on Linux and open files limit ] 1270 fds_limit_margin = 10 1271 fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)] 1272 except OSError as e: 1273 if e.errno == errno.EMFILE: 1274 raise RuntimeError( 1275 "Too many open files. Communication with the" 1276 " workers is no longer possible. Please increase the" 1277 " limit using `ulimit -n` in the shell or change the" 1278 " sharing strategy by calling" 1279 " `torch.multiprocessing.set_sharing_strategy('file_system')`" 1280 " at the beginning of your code" 1281 ) from None 1282 raise 1283 1284 # NOTE [ DataLoader on Linux and open files limit ] 1285 # 1286 # On Linux when DataLoader is used with multiprocessing we pass the data between 1287 # the root process and the workers through SHM files. We remove those files from 1288 # the filesystem as soon as they are created and keep them alive by 1289 # passing around their file descriptors through AF_UNIX sockets. (See 1290 # docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes` in 1291 # the wiki (https://github.com/pytorch/pytorch/wiki).) 1292 # 1293 # This sometimes leads us to exceeding the open files limit. When that happens, 1294 # and the offending file descriptor is coming over a socket, the `socket` Python 1295 # package silently strips the file descriptor from the message, setting only the 1296 # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that 1297 # it _indicates that some control data were discarded due to lack of space in 1298 # the buffer for ancillary data_). This might reflect the C implementation of 1299 # AF_UNIX sockets. 1300 # 1301 # This behaviour can be reproduced with the script and instructions at the 1302 # bottom of this note. 1303 # 1304 # When that happens, the standard Python `multiprocessing` (and not 1305 # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata` 1306 # 1307 # Sometimes, instead of the FD being stripped, you may get an `OSError: 1308 # Too many open files`, both in the script below and in DataLoader. However, 1309 # this is rare and seems to be nondeterministic. 
1310 # 1311 # 1312 # #!/usr/bin/env python3 1313 # import sys 1314 # import socket 1315 # import os 1316 # import array 1317 # import shutil 1318 # import socket 1319 # 1320 # 1321 # if len(sys.argv) != 4: 1322 # print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)") 1323 # sys.exit(1) 1324 # 1325 # if __name__ == '__main__': 1326 # dirname = sys.argv[1] 1327 # sock_path = dirname + "/sock" 1328 # iterations = int(sys.argv[2]) 1329 # def dummy_path(i): 1330 # return dirname + "/" + str(i) + ".dummy" 1331 # 1332 # 1333 # if sys.argv[3] == 'send': 1334 # while not os.path.exists(sock_path): 1335 # pass 1336 # client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) 1337 # client.connect(sock_path) 1338 # for i in range(iterations): 1339 # fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT) 1340 # ancdata = array.array('i', [fd]) 1341 # msg = bytes([i % 256]) 1342 # print("Sending fd ", fd, " (iteration #", i, ")") 1343 # client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)]) 1344 # 1345 # 1346 # else: 1347 # assert sys.argv[3] == 'recv' 1348 # 1349 # if os.path.exists(dirname): 1350 # raise Exception("Directory exists") 1351 # 1352 # os.mkdir(dirname) 1353 # 1354 # print("Opening socket...") 1355 # server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) 1356 # server.bind(sock_path) 1357 # 1358 # print("Listening...") 1359 # for i in range(iterations): 1360 # a = array.array('i') 1361 # msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize)) 1362 # assert(len(ancdata) == 1) 1363 # cmsg_level, cmsg_type, cmsg_data = ancdata[0] 1364 # a.frombytes(cmsg_data) 1365 # print("Received fd ", a[0], " (iteration #", i, ")") 1366 # 1367 # shutil.rmtree(dirname) 1368 # 1369 # Steps to reproduce: 1370 # 1371 # 1. Run two shells and set lower file descriptor limit in the receiving one: 1372 # (shell1) ulimit -n 1020 1373 # (shell2) ulimit -n 1022 1374 # 1375 # 2. Run the script above with the `recv` option in the first shell 1376 # (shell1) ./test_socket.py sock_tmp 1017 recv 1377 # 1378 # 3. Run the script with the `send` option in the second shell: 1379 # (shell2) ./test_socket.py sock_tmp 1017 send 1380 1381 def _get_data(self): 1382 # Fetches data from `self._data_queue`. 1383 # 1384 # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds, 1385 # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)` 1386 # in a loop. This is the only mechanism to detect worker failures for 1387 # Windows. For other platforms, a SIGCHLD handler is also used for 1388 # worker failure detection. 1389 # 1390 # If `pin_memory=True`, we also need check if `pin_memory_thread` had 1391 # died at timeouts. 1392 if self._timeout > 0: 1393 success, data = self._try_get_data(self._timeout) 1394 if success: 1395 return data 1396 else: 1397 raise RuntimeError( 1398 f"DataLoader timed out after {self._timeout} seconds" 1399 ) 1400 elif self._pin_memory: 1401 while self._pin_memory_thread.is_alive(): 1402 success, data = self._try_get_data() 1403 if success: 1404 return data 1405 else: 1406 # while condition is false, i.e., pin_memory_thread died. 1407 raise RuntimeError("Pin memory thread exited unexpectedly") 1408 # In this case, `self._data_queue` is a `queue.Queue`,. But we don't 1409 # need to call `.task_done()` because we don't use `.join()`. 
        else:
            while True:
                success, data = self._try_get_data()
                if success:
                    return data

    def _next_data(self):
        while True:
            # If the worker responsible for `self._rcvd_idx` has already ended
            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
            # we try to advance `self._rcvd_idx` to find the next valid index.
            #
            # This part needs to run in the loop because both the `self._get_data()`
            # call and `_IterableDatasetStopIteration` check below can mark
            # extra worker(s) as dead.
            while self._rcvd_idx < self._send_idx:
                info = self._task_info[self._rcvd_idx]
                worker_id = info[0]
                if (
                    len(info) == 2 or self._workers_status[worker_id]
                ):  # has data or is still active
                    break
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                if not self._persistent_workers:
                    self._shutdown_workers()
                raise StopIteration

            # Now `self._rcvd_idx` is the batch index we want to fetch

            # Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            assert not self._shutdown and self._tasks_outstanding > 0
            idx, data = self._get_data()
            self._tasks_outstanding -= 1
            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    if self._persistent_workers:
                        self._workers_status[data.worker_id] = False
                    else:
                        self._mark_worker_as_unavailable(data.worker_id)
                    self._try_put_index()
                    continue

            if idx != self._rcvd_idx:
                # store out-of-order samples
                self._task_info[idx] += (data,)
            else:
                del self._task_info[idx]
                return self._process_data(data)

    def _try_put_index(self):
        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers

        try:
            index = self._next_index()
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break
        else:
            # not found (i.e., didn't break)
            return

        self._index_queues[worker_queue_idx].put((self._send_idx, index))  # type: ignore[possibly-undefined]
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._tasks_outstanding += 1
        self._send_idx += 1

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data

    def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
        # Mark a worker as having finished its work e.g., due to
        # exhausting an `IterableDataset`. This should be used only when this
        # `_MultiProcessingDataLoaderIter` is going to continue running.

        assert self._workers_status[worker_id] or (
            self._persistent_workers and shutdown
        )

        # Signal termination to that specific worker.
        q = self._index_queues[worker_id]
        # Indicate that no more data will be put on this queue by the current
        # process.
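        # (`None` acts as a sentinel: upon receiving it, the worker loop in
        # `_utils/worker.py` knows that no further indices are coming from us.)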
        q.put(None)

        # Note that we don't actually join the worker here, nor do we remove the
        # worker's pid from the C side struct because (1) joining may be slow, and
        # (2) since we don't join, the worker may still raise an error, and we
        # prefer capturing those, rather than ignoring them, even though they
        # are raised after the worker has finished its job.
        # Joining is deferred to `_shutdown_workers`, which is called when
        # all workers finish their jobs (e.g., `IterableDataset` replicas) or
        # when this iterator is garbage collected.

        self._workers_status[worker_id] = False

        assert self._workers_done_event.is_set() == shutdown

    def _shutdown_workers(self):
        # Called when shutting down this `_MultiProcessingDataLoaderIter`.
        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
        # the logic of this function.
        if (
            _utils is None
            or _utils.python_exit_status is True
            or _utils.python_exit_status is None
        ):
            # See (2) of the note. If Python is shutting down, do nothing.
            return
        # Normal exit when last reference is gone / iterator is depleted.
        # See (1) and the second half of the note.
        if not self._shutdown:
            self._shutdown = True
            try:
                # Exit `pin_memory_thread` first because exiting workers may leave
                # corrupted data in `worker_result_queue` which `pin_memory_thread`
                # reads from.
                if hasattr(self, "_pin_memory_thread"):
                    # Use hasattr in case an error happens before we set the attribute.
                    self._pin_memory_thread_done_event.set()
                    # Send something to pin_memory_thread in case it is waiting
                    # so that it can wake up and check `pin_memory_thread_done_event`
                    self._worker_result_queue.put((None, None))
                    self._pin_memory_thread.join()
                    self._worker_result_queue.cancel_join_thread()
                    self._worker_result_queue.close()

                # Exit workers now.
                self._workers_done_event.set()
                for worker_id in range(len(self._workers)):
                    # Get number of workers from `len(self._workers)` instead of
                    # `self._num_workers` in case we error before starting all
                    # workers.
                    # With `persistent_workers`, the worker may be paused (its entry
                    # in `self._workers_status` is False) but still alive, so we
                    # have to shut it down here as well.
                    if self._persistent_workers or self._workers_status[worker_id]:
                        self._mark_worker_as_unavailable(worker_id, shutdown=True)
                for w in self._workers:
                    # We should be able to join here, but in case anything went
                    # wrong, we set a timeout and if the workers fail to join,
                    # they are killed in the `finally` block.
                    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
                for q in self._index_queues:
                    q.cancel_join_thread()
                    q.close()
            finally:
                # Even though all this function does is putting into queues that
                # we have called `cancel_join_thread` on, weird things can
                # happen when a worker is killed by a signal, e.g., hanging in
                # `Event.set()`. So we need to guard this with a SIGCHLD handler,
                # and remove pids from the C side data structure only at the
                # end.
                #
                # FIXME: Unfortunately, for Windows, we are missing a worker
                # error detection mechanism here in this function, as it
                # doesn't provide a SIGCHLD handler.
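                # Unregister this loader's worker pids from the SIGCHLD-based
                # failure detection, then force-terminate any worker that is
                # still alive after the timed `join` above.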
                if self._worker_pids_set:
                    _utils.signal_handling._remove_worker_pids(id(self))
                    self._worker_pids_set = False
                for w in self._workers:
                    if w.is_alive():
                        # Existing mechanisms try to make the workers exit
                        # peacefully, but if we unfortunately reach this point
                        # (which we shouldn't, e.g., pytorch/pytorch#39570),
                        # we kill the worker.
                        w.terminate()

    # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
    @staticmethod
    def _clean_up_worker(w):
        try:
            w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
        finally:
            if w.is_alive():
                w.terminate()

    def __del__(self):
        self._shutdown_workers()
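
# A rough sketch (illustrative only) of how the shutdown machinery above is
# triggered from the caller's side; `dataset` is a placeholder for any Dataset:
#
#   loader = DataLoader(dataset, num_workers=4)
#   it = iter(loader)   # starts workers (and the pin_memory thread, if enabled)
#   batch = next(it)
#   del it              # __del__ -> _shutdown_workers() joins/terminates workers
#
# With `persistent_workers=True`, exhausting an epoch does not tear the workers
# down; the same worker processes are reused by the next `iter(loader)` call.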