xref: /aosp_15_r20/external/pytorch/torch/utils/data/dataloader.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter.
3
4To support these two classes, in `./_utils` we define many utility methods and
5functions to be run in multiprocessing. E.g., the data loading worker loop is
6in `./_utils/worker.py`.
7"""
8
9import functools
10import itertools
11import logging
12import multiprocessing as python_multiprocessing
13import os
14import queue
15import threading
16import warnings
17from typing import Any, Callable, Generic, Iterable, List, Optional, TypeVar, Union
18
19import torch
20import torch.distributed as dist
21import torch.utils.data.graph_settings
22from torch._utils import ExceptionWrapper
23from torch.utils.data import _utils
24from torch.utils.data.datapipes.datapipe import (
25    _IterDataPipeSerializationWrapper,
26    _MapDataPipeSerializationWrapper,
27    IterDataPipe,
28    MapDataPipe,
29)
30from torch.utils.data.dataset import Dataset, IterableDataset
31from torch.utils.data.sampler import (
32    BatchSampler,
33    RandomSampler,
34    Sampler,
35    SequentialSampler,
36)
37
38
39__all__ = [
40    "DataLoader",
41    "get_worker_info",
42    "default_collate",
43    "default_convert",
44]
45
46
47_T = TypeVar("_T")
48_T_co = TypeVar("_T_co", covariant=True)
49_worker_init_fn_t = Callable[[int], None]
50
51# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
52# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
53# See https://github.com/python/mypy/issues/3737.
54_collate_fn_t = Callable[[List[_T]], Any]
55
56
57# These functions used to be defined in this file. However, they were moved to
58# _utils/collate.py. Although it is rather hard to access them from user land
59# (one has to explicitly and directly `import torch.utils.data.dataloader`), there
60# probably is user code out there using them. This aliasing maintains BC in this
61# respect.
62default_collate: _collate_fn_t = _utils.collate.default_collate
63default_convert = _utils.collate.default_convert
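
# A minimal sketch (illustrative only) of what `default_collate` does with a
# list of (tensor, label) samples:
#
#     >>> import torch
#     >>> samples = [(torch.tensor([1.0, 2.0]), 0), (torch.tensor([3.0, 4.0]), 1)]
#     >>> features, labels = default_collate(samples)
#     >>> features.shape, labels
#     (torch.Size([2, 2]), tensor([0, 1]))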
64
65get_worker_info = _utils.worker.get_worker_info
66
67logger = logging.getLogger(__name__)
68
69
70class _DatasetKind:
71    Map = 0
72    Iterable = 1
73
74    @staticmethod
75    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
76        if kind == _DatasetKind.Map:
77            return _utils.fetch._MapDatasetFetcher(
78                dataset, auto_collation, collate_fn, drop_last
79            )
80        else:
81            return _utils.fetch._IterableDatasetFetcher(
82                dataset, auto_collation, collate_fn, drop_last
83            )
84
85
86class _InfiniteConstantSampler(Sampler):
87    r"""Analogous to ``itertools.repeat(None, None)``.
88
89    Used as sampler for :class:`~torch.utils.data.IterableDataset`.
90    """
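
    # For example (illustrative): `list(itertools.islice(_InfiniteConstantSampler(), 3))`
    # evaluates to `[None, None, None]`, just like
    # `list(itertools.islice(itertools.repeat(None), 3))`.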
91
92    def __iter__(self):
93        while True:
94            yield None
95
96
97def _get_distributed_settings():
98    if dist.is_available() and dist.is_initialized():
99        return dist.get_world_size(), dist.get_rank()
100    else:
101        return 1, 0
102
103
104def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
105    global_worker_id = worker_id
106    info = torch.utils.data.get_worker_info()
107    assert info is not None
108    total_workers = info.num_workers
109    datapipe = info.dataset
110    assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
111    # To distribute elements evenly across distributed processes, we should shard data across
112    # distributed processes first, then across worker processes (see the worked example below).
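    # Worked example (illustrative numbers): with world_size=2, rank_id=1 and
    # 3 workers per rank, there are 3 * 2 = 6 shards in total; worker_id=0 on
    # this rank becomes global worker 0 * 2 + 1 = 1, worker_id=1 becomes
    # 1 * 2 + 1 = 3, and worker_id=2 becomes 2 * 2 + 1 = 5.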
113    total_workers *= world_size
114    global_worker_id = global_worker_id * world_size + rank_id
115    # For BC, use default SHARDING_PRIORITIES
116    torch.utils.data.graph_settings.apply_sharding(
117        datapipe, total_workers, global_worker_id
118    )
119    if worker_init_fn is not None:
120        worker_init_fn(worker_id)
121
122
123def _share_dist_seed(generator, pg):
124    _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator)
125    if isinstance(pg, dist.ProcessGroup):
126        dist.broadcast(_shared_seed, src=0, group=pg)
127    return _shared_seed.item()
128
129
130class DataLoader(Generic[_T_co]):
131    r"""
132    Data loader combines a dataset and a sampler, and provides an iterable over the given dataset.
133
134    The :class:`~torch.utils.data.DataLoader` supports both map-style and
135    iterable-style datasets with single- or multi-process loading, customizing
136    loading order and optional automatic batching (collation) and memory pinning.
137
138    See :py:mod:`torch.utils.data` documentation page for more details.
139
140    Args:
141        dataset (Dataset): dataset from which to load the data.
142        batch_size (int, optional): how many samples per batch to load
143            (default: ``1``).
144        shuffle (bool, optional): set to ``True`` to have the data reshuffled
145            at every epoch (default: ``False``).
146        sampler (Sampler or Iterable, optional): defines the strategy to draw
147            samples from the dataset. Can be any ``Iterable`` with ``__len__``
148            implemented. If specified, :attr:`shuffle` must not be specified.
149        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
150            returns a batch of indices at a time. Mutually exclusive with
151            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
152            and :attr:`drop_last`.
153        num_workers (int, optional): how many subprocesses to use for data
154            loading. ``0`` means that the data will be loaded in the main process.
155            (default: ``0``)
156        collate_fn (Callable, optional): merges a list of samples to form a
157            mini-batch of Tensor(s).  Used when using batched loading from a
158            map-style dataset.
159        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
160            into device/CUDA pinned memory before returning them.  If your data elements
161            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
162            see the example below.
163        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
164            if the dataset size is not divisible by the batch size. If ``False`` and
165            the size of dataset is not divisible by the batch size, then the last batch
166            will be smaller. (default: ``False``)
167        timeout (numeric, optional): if positive, the timeout value for collecting a batch
168            from workers. Should always be non-negative. (default: ``0``)
169        worker_init_fn (Callable, optional): If not ``None``, this will be called on each
170            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
171            input, after seeding and before data loading. (default: ``None``)
172        multiprocessing_context (str or multiprocessing.context.BaseContext, optional): If
173            ``None``, the default `multiprocessing context`_ of your operating system will
174            be used. (default: ``None``)
175        generator (torch.Generator, optional): If not ``None``, this RNG will be used
176            by RandomSampler to generate random indices and by multiprocessing to generate
177            ``base_seed`` for workers. (default: ``None``)
178        prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
179            in advance by each worker. ``2`` means there will be a total of
180            2 * num_workers batches prefetched across all workers. (The default
181            depends on ``num_workers``: if ``num_workers=0`` the default is ``None``;
182            otherwise the default is ``2``.)
183        persistent_workers (bool, optional): If ``True``, the data loader will not shut down
184            the worker processes after a dataset has been consumed once. This keeps the
185            workers' `Dataset` instances alive. (default: ``False``)
186        pin_memory_device (str, optional): the device to :attr:`pin_memory` to if ``pin_memory`` is
187            ``True``.
188
189
190    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
191                 cannot be an unpicklable object, e.g., a lambda function. See
192                 :ref:`multiprocessing-best-practices` for more details related
193                 to multiprocessing in PyTorch.
194
195    .. warning:: The ``len(dataloader)`` heuristic is based on the length of the sampler used.
196                 When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
197                 it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
198                 rounding depending on :attr:`drop_last`, regardless of multi-process loading
199                 configurations. This represents the best guess PyTorch can make because PyTorch
200                 trusts user :attr:`dataset` code in correctly handling multi-process
201                 loading to avoid duplicate data.
202
203                 However, if sharding results in multiple workers having incomplete last batches,
204                 this estimate can still be inaccurate, because (1) an otherwise complete batch can
205                 be broken into multiple ones and (2) more than one batch's worth of samples can be
206                 dropped when :attr:`drop_last` is set. Unfortunately, PyTorch cannot detect such
207                 cases in general.
208
209                 See `Dataset Types`_ for more details on these two types of datasets and how
210                 :class:`~torch.utils.data.IterableDataset` interacts with
211                 `Multi-process data loading`_.
212
213    .. warning:: See the :ref:`reproducibility`, :ref:`dataloader-workers-random-seed`, and
214                 :ref:`data-loading-randomness` notes for random-seed-related questions.
215
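    Example (a minimal usage sketch; any map-style dataset works in place of the
    ``TensorDataset`` used here)::

        >>> import torch
        >>> from torch.utils.data import DataLoader, TensorDataset
        >>> dataset = TensorDataset(torch.arange(10.0).unsqueeze(1), torch.arange(10))
        >>> loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
        >>> for features, labels in loader:
        ...     pass  # ``features`` has shape [4, 1] except possibly in the last batch
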
216    .. _multiprocessing context:
217        https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
218    """
219
220    dataset: Dataset[_T_co]
221    batch_size: Optional[int]
222    num_workers: int
223    pin_memory: bool
224    drop_last: bool
225    timeout: float
226    sampler: Union[Sampler, Iterable]
227    pin_memory_device: str
228    prefetch_factor: Optional[int]
229    _iterator: Optional["_BaseDataLoaderIter"]
230    __initialized = False
231
232    def __init__(
233        self,
234        dataset: Dataset[_T_co],
235        batch_size: Optional[int] = 1,
236        shuffle: Optional[bool] = None,
237        sampler: Union[Sampler, Iterable, None] = None,
238        batch_sampler: Union[Sampler[List], Iterable[List], None] = None,
239        num_workers: int = 0,
240        collate_fn: Optional[_collate_fn_t] = None,
241        pin_memory: bool = False,
242        drop_last: bool = False,
243        timeout: float = 0,
244        worker_init_fn: Optional[_worker_init_fn_t] = None,
245        multiprocessing_context=None,
246        generator=None,
247        *,
248        prefetch_factor: Optional[int] = None,
249        persistent_workers: bool = False,
250        pin_memory_device: str = "",
251    ):
252        torch._C._log_api_usage_once("python.data_loader")
253
254        if num_workers < 0:
255            raise ValueError(
256                "num_workers option should be non-negative; "
257                "use num_workers=0 to disable multiprocessing."
258            )
259
260        if timeout < 0:
261            raise ValueError("timeout option should be non-negative")
262
263        if num_workers == 0 and prefetch_factor is not None:
264            raise ValueError(
265                "prefetch_factor option could only be specified in multiprocessing."
266                "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None."
267            )
268        elif num_workers > 0 and prefetch_factor is None:
269            prefetch_factor = 2
270        elif prefetch_factor is not None and prefetch_factor < 0:
271            raise ValueError("prefetch_factor option should be non-negative")
272
273        if persistent_workers and num_workers == 0:
274            raise ValueError("persistent_workers option needs num_workers > 0")
275
276        self.dataset = dataset
277        self.num_workers = num_workers
278        self.prefetch_factor = prefetch_factor
279        self.pin_memory = pin_memory
280        self.pin_memory_device = pin_memory_device
281        self.timeout = timeout
282        self.worker_init_fn = worker_init_fn
283        self.multiprocessing_context = multiprocessing_context
284
285        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
286        #   _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
287        if isinstance(self.dataset, IterDataPipe):
288            self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
289        elif isinstance(self.dataset, MapDataPipe):
290            self.dataset = _MapDataPipeSerializationWrapper(self.dataset)
291
292        # Arg-check dataset-related options before checking samplers because we want to
293        # tell users that iterable-style datasets are incompatible with custom
294        # samplers first, so that they don't learn that this combo doesn't work
295        # after spending time fixing the custom sampler errors.
296        if isinstance(dataset, IterableDataset):
297            self._dataset_kind = _DatasetKind.Iterable
298            # NOTE [ Custom Samplers and IterableDataset ]
299            #
300            # `IterableDataset` does not support custom `batch_sampler` or
301            # `sampler` since the key is irrelevant (unless we support
302            # generator-style dataset one day...).
303            #
304            # For `sampler`, we always create a dummy sampler. This is an
305            # infinite sampler even when the dataset may have an implemented
306            # finite `__len__` because in multi-process data loading, naive
307            # settings will return duplicated data (which may be desired), and
308            # thus using a sampler with length matching that of the dataset will
309            # cause data loss (you may see duplicates of the first couple of
310            # batches, but never see anything afterwards). Therefore,
311            # `IterableDataset` always uses an infinite sampler, an instance of
312            # `_InfiniteConstantSampler` defined above.
313            #
314            # A custom `batch_sampler` essentially only controls the batch size.
315            # However, it is unclear how useful it would be since an iterable-style
316            # dataset can handle that within itself. Moreover, it is pointless
317            # in multi-process data loading as the assignment order of batches
318            # to workers is an implementation detail, so users cannot control
319            # how to batchify each worker's iterable. Thus, we disable this
320            # option. If this turns out to be useful in the future, we can re-enable
321            # this, and support custom samplers that specify the assignments to
322            # specific workers.
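            # For example (illustrative; `my_iterable_ds`, `my_sampler` and
            # `my_batch_sampler` are placeholders): `DataLoader(my_iterable_ds,
            # sampler=my_sampler)` or `DataLoader(my_iterable_ds,
            # batch_sampler=my_batch_sampler)` raises a `ValueError` below; batching
            # for iterable-style datasets is controlled via `batch_size` or inside
            # the dataset's own `__iter__`.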
323            if isinstance(dataset, IterDataPipe):
324                if shuffle is not None:
325                    dataset = torch.utils.data.graph_settings.apply_shuffle_settings(
326                        dataset, shuffle=shuffle
327                    )
328            # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
329            elif shuffle not in {False, None}:
330                raise ValueError(
331                    f"DataLoader with IterableDataset: expected unspecified shuffle option, but got shuffle={shuffle}"
332                )
333
334            if sampler is not None:
335                # See NOTE [ Custom Samplers and IterableDataset ]
336                raise ValueError(
337                    f"DataLoader with IterableDataset: expected unspecified sampler option, but got sampler={sampler}"
338                )
339            elif batch_sampler is not None:
340                # See NOTE [ Custom Samplers and IterableDataset ]
341                raise ValueError(
342                    "DataLoader with IterableDataset: expected unspecified "
343                    f"batch_sampler option, but got batch_sampler={batch_sampler}"
344                )
345        else:
346            shuffle = bool(shuffle)
347            self._dataset_kind = _DatasetKind.Map
348
349        if sampler is not None and shuffle:
349            raise ValueError("sampler option is mutually exclusive with shuffle")
351
352        if batch_sampler is not None:
353            # auto_collation with custom batch_sampler
354            if batch_size != 1 or shuffle or sampler is not None or drop_last:
355                raise ValueError(
356                    "batch_sampler option is mutually exclusive "
357                    "with batch_size, shuffle, sampler, and "
358                    "drop_last"
359                )
360            batch_size = None
361            drop_last = False
362        elif batch_size is None:
363            # no auto_collation
364            if drop_last:
365                raise ValueError(
366                    "batch_size=None option disables auto-batching "
367                    "and is mutually exclusive with drop_last"
368                )
369
370        if sampler is None:  # give default samplers
371            if self._dataset_kind == _DatasetKind.Iterable:
372                # See NOTE [ Custom Samplers and IterableDataset ]
373                sampler = _InfiniteConstantSampler()
374            else:  # map-style
375                if shuffle:
376                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
377                else:
378                    sampler = SequentialSampler(dataset)  # type: ignore[arg-type]
379
380        if batch_size is not None and batch_sampler is None:
381            # auto_collation without custom batch_sampler
382            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
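        # For example (illustrative): for a map-style dataset `ds`,
        # `DataLoader(ds, batch_size=2, shuffle=True)` ends up iterating over
        # `BatchSampler(RandomSampler(ds), batch_size=2, drop_last=False)`,
        # i.e. lists of 2 indices drawn in shuffled order.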
383
384        self.batch_size = batch_size
385        self.drop_last = drop_last
386        self.sampler = sampler
387        self.batch_sampler = batch_sampler
388        self.generator = generator
389
390        if collate_fn is None:
391            if self._auto_collation:
392                collate_fn = _utils.collate.default_collate
393            else:
394                collate_fn = _utils.collate.default_convert
395
396        self.collate_fn = collate_fn
397        self.persistent_workers = persistent_workers
398
399        self.__initialized = True
400        self._IterableDataset_len_called = (
401            None  # See NOTE [ IterableDataset and __len__ ]
402        )
403
404        self._iterator = None
405
406        self.check_worker_number_rationality()
407
408        torch.set_vital("Dataloader", "enabled", "True")  # type: ignore[attr-defined]
409
410    def _get_iterator(self) -> "_BaseDataLoaderIter":
411        if self.num_workers == 0:
412            return _SingleProcessDataLoaderIter(self)
413        else:
414            self.check_worker_number_rationality()
415            return _MultiProcessingDataLoaderIter(self)
416
417    @property
418    def multiprocessing_context(self):
419        return self.__multiprocessing_context
420
421    @multiprocessing_context.setter
422    def multiprocessing_context(self, multiprocessing_context):
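        # For example (illustrative): `DataLoader(ds, num_workers=2,
        # multiprocessing_context="spawn")` or, equivalently, passing
        # `torch.multiprocessing.get_context("spawn")`; both forms are accepted
        # by the checks below.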
423        if multiprocessing_context is not None:
424            if self.num_workers > 0:
425                if isinstance(multiprocessing_context, str):
426                    valid_start_methods = torch.multiprocessing.get_all_start_methods()
427                    if multiprocessing_context not in valid_start_methods:
428                        raise ValueError(
429                            "multiprocessing_context option "
430                            f"should specify a valid start method in {valid_start_methods!r}, but got "
431                            f"multiprocessing_context={multiprocessing_context!r}"
432                        )
433                    multiprocessing_context = torch.multiprocessing.get_context(
434                        multiprocessing_context
435                    )
436
437                if not isinstance(
438                    multiprocessing_context, python_multiprocessing.context.BaseContext
439                ):
440                    raise TypeError(
441                        "multiprocessing_context option should be a valid context "
442                        "object or a string specifying the start method, but got "
443                        f"multiprocessing_context={multiprocessing_context}"
444                    )
445            else:
446                raise ValueError(
447                    "multiprocessing_context can only be used with "
448                    "multi-process loading (num_workers > 0), but got "
449                    f"num_workers={self.num_workers}"
450                )
451
452        self.__multiprocessing_context = multiprocessing_context
453
454    def __setattr__(self, attr, val):
455        if self.__initialized and attr in (
456            "batch_size",
457            "batch_sampler",
458            "sampler",
459            "drop_last",
460            "dataset",
461            "persistent_workers",
462        ):
463            raise ValueError(
464                f"{attr} attribute should not be set after {self.__class__.__name__} is initialized"
465            )
466
467        super().__setattr__(attr, val)
468
469    # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
470    # since '_BaseDataLoaderIter' references 'DataLoader'.
471    def __iter__(self) -> "_BaseDataLoaderIter":
472        # When using a single worker, the returned iterator should be
473        # created every time to avoid resetting its state.
474        # However, in the case of a multi-worker iterator,
475        # the iterator is only created once in the lifetime of the
476        # DataLoader object so that workers can be reused (see the example below).
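        # For example (illustrative): with `num_workers=2` and
        # `persistent_workers=True`, two consecutive `for batch in loader`
        # loops reuse the same two worker processes, whereas with the default
        # `persistent_workers=False` each loop spawns a fresh set of workers.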
477        if self.persistent_workers and self.num_workers > 0:
478            if self._iterator is None:
479                self._iterator = self._get_iterator()
480            else:
481                self._iterator._reset(self)
482            return self._iterator
483        else:
484            return self._get_iterator()
485
486    @property
487    def _auto_collation(self):
488        return self.batch_sampler is not None
489
490    @property
491    def _index_sampler(self):
492        # The actual sampler used for generating indices for `_DatasetFetcher`
493        # (see _utils/fetch.py) to read data each time. This would be
494        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
495        # We can't change `.sampler` and `.batch_sampler` attributes for BC
496        # reasons.
497        if self._auto_collation:
498            return self.batch_sampler
499        else:
500            return self.sampler
501
502    def __len__(self) -> int:
503        if self._dataset_kind == _DatasetKind.Iterable:
504            # NOTE [ IterableDataset and __len__ ]
505            #
506            # For `IterableDataset`, `__len__` could be inaccurate when one naively
507            # does multi-processing data loading, since the samples will be duplicated.
508            # However, no real use case should actually use that behavior, so
509            # it should count as a user error. We should generally trust user
510            # code to do the proper thing (e.g., configure each replica differently
511            # in `__iter__`), and give us the correct `__len__` if they choose to
512            # implement it (this will still throw if the dataset does not implement
513            # a `__len__`).
514            #
515            # To provide a further warning, we track if `__len__` was called on the
516            # `DataLoader`, save the returned value in `self._IterableDataset_len_called`, and warn
517            # if the iterator ends up yielding more than this number of samples.
518
519            # Cannot statically verify that dataset is Sized
520            length = self._IterableDataset_len_called = len(self.dataset)  # type: ignore[assignment, arg-type]
521            if (
522                self.batch_size is not None
523            ):  # IterableDataset doesn't allow custom sampler or batch_sampler
524                from math import ceil
525
526                if self.drop_last:
527                    length = length // self.batch_size
528                else:
529                    length = ceil(length / self.batch_size)
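                # Worked example (illustrative): if `len(dataset)` is 10 and
                # `batch_size` is 3, `len(dataloader)` is 10 // 3 == 3 with
                # `drop_last=True` and ceil(10 / 3) == 4 otherwise.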
530            return length
531        else:
532            return len(self._index_sampler)
533
534    def check_worker_number_rationality(self):
535        # This function checks whether the dataloader's worker number is reasonable given the
536        # current system's resources. The current rule is that if the number of workers this
537        # DataLoader will create is bigger than the number of logical CPUs it is allowed to
538        # use, then we emit a warning so that the user pays attention.
539        #
540        # e.g. If the current system has 2 physical CPUs with 16 cores each, and each core supports 2
541        #      threads, then the total number of logical CPUs is 2 * 16 * 2 = 64. Let's say the current
542        #      DataLoader process can use half of them, i.e. 32; then the reasonable max number of
543        #      workers initiated from this process is 32.
544        #      Now, let's say the created DataLoader has num_workers = 40, which is bigger than 32,
545        #      so the warning message is triggered to notify the user to lower the worker number if
546        #      necessary.
547        #
548        #
549        # [Note] Please note that this function respects `cpuset` only when os.sched_getaffinity is
550        #        available (which it is on most Linux systems, but not on macOS or Windows).
551        #        When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
552        #        it doesn't respect cpuset.
553        #        We don't take threading into account since each worker process is single-threaded
554        #        at this time.
555        #
556        #        We don't set any threading flags (e.g. OMP_NUM_THREADS, MKL_NUM_THREADS, etc.)
557        #        other than calling `torch.set_num_threads(1)` in the worker process; if the passed-in
558        #        functions use third-party modules that rely on those threading flags to determine
559        #        how many threads to create (e.g. numpy, etc.), then it is the caller's responsibility
560        #        to set those flags correctly.
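        #
        # For example (illustrative): on a Linux machine where this process is
        # pinned to 32 logical CPUs, `len(os.sched_getaffinity(0))` returns 32,
        # so `num_workers=40` triggers the warning below while `num_workers=16`
        # does not.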
561        def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
562            suggested_max_worker_msg = (
563                (
564                    (
565                        "Our suggested max number of workers in the current system is {}{}, which is smaller "
566                        "than what this DataLoader is going to create."
567                    ).format(
568                        num_worker_suggest,
569                        (
570                            ""
571                            if cpuset_checked
572                            else " (`cpuset` is not taken into account)"
573                        ),
574                    )
575                )
576                if num_worker_suggest is not None
577                else (
578                    "DataLoader is not able to compute a suggested max number of workers in the current system."
579                )
580            )
581
582            warn_msg = (
583                f"This DataLoader will create {num_worker_created} worker processes in total. {suggested_max_worker_msg} "
584                "Please be aware that excessive worker creation might make the DataLoader run slowly or even freeze; "
585                "lower the worker number to avoid potential slowness/freezes if necessary."
586            )
587            return warn_msg
588
589        if not self.num_workers:
590            return
591
592        # try to compute a suggested max number of workers based on system resources
593        max_num_worker_suggest = None
594        cpuset_checked = False
595        if hasattr(os, "sched_getaffinity"):
596            try:
597                max_num_worker_suggest = len(os.sched_getaffinity(0))
598                cpuset_checked = True
599            except Exception:
600                pass
601        if max_num_worker_suggest is None:
602            # os.cpu_count() could return Optional[int]
603            # get cpu count first and check None in order to satisfy mypy check
604            cpu_count = os.cpu_count()
605            if cpu_count is not None:
606                max_num_worker_suggest = cpu_count
607
608        if max_num_worker_suggest is None:
609            warnings.warn(
610                _create_warning_msg(
611                    max_num_worker_suggest, self.num_workers, cpuset_checked
612                )
613            )
614            return
615
616        if self.num_workers > max_num_worker_suggest:
617            warnings.warn(
618                _create_warning_msg(
619                    max_num_worker_suggest, self.num_workers, cpuset_checked
620                )
621            )
622
623
624class _BaseDataLoaderIter:
625    def __init__(self, loader: DataLoader) -> None:
626        self._dataset = loader.dataset
627        self._shared_seed = None
628        self._pg = None
629        if isinstance(self._dataset, IterDataPipe):
630            if dist.is_available() and dist.is_initialized():
631                self._pg = dist.new_group(backend="gloo")
632            self._shared_seed = _share_dist_seed(loader.generator, self._pg)
633            shared_rng = torch.Generator()
634            shared_rng.manual_seed(self._shared_seed)
635            self._dataset = torch.utils.data.graph_settings.apply_random_seed(
636                self._dataset, shared_rng
637            )
638        self._dataset_kind = loader._dataset_kind
639        self._IterableDataset_len_called = loader._IterableDataset_len_called
640        self._auto_collation = loader._auto_collation
641        self._drop_last = loader.drop_last
642        self._index_sampler = loader._index_sampler
643        self._num_workers = loader.num_workers
644        ws, rank = _get_distributed_settings()
645        self._world_size = ws
646        self._rank = rank
647        # For other backends, pin_memory_device needs to be set. If it is not set,
648        # the default behaviour is to pin for the CUDA device. If pin_memory_device is
649        # selected but pin_memory is not set, pinning is disabled (see the example below).
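        # For example (illustrative): `DataLoader(ds, pin_memory=True)` pins to CUDA
        # pinned memory when CUDA is available, while
        # `DataLoader(ds, pin_memory=True, pin_memory_device="xpu")` pins memory for
        # the XPU backend instead.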
650        if len(loader.pin_memory_device) == 0:
651            self._pin_memory = loader.pin_memory and torch.cuda.is_available()
652            self._pin_memory_device = None
653        else:
654            if not loader.pin_memory:
655                warn_msg = (
656                    "pin_memory_device is set but the pin_memory flag is not used, so device pinned memory won't be used. "
657                    "Please set pin_memory to True if you need to use the device pinned memory."
658                )
659                warnings.warn(warn_msg)
660
661            self._pin_memory = loader.pin_memory
662            self._pin_memory_device = loader.pin_memory_device
663        self._timeout = loader.timeout
664        self._collate_fn = loader.collate_fn
665        self._sampler_iter = iter(self._index_sampler)
666        self._base_seed = (
667            torch.empty((), dtype=torch.int64)
668            .random_(generator=loader.generator)
669            .item()
670        )
671        self._persistent_workers = loader.persistent_workers
672        self._num_yielded = 0
673        self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__"
674
675    def __iter__(self) -> "_BaseDataLoaderIter":
676        return self
677
678    def _reset(self, loader, first_iter=False):
679        self._sampler_iter = iter(self._index_sampler)
680        self._num_yielded = 0
681        self._IterableDataset_len_called = loader._IterableDataset_len_called
682        if isinstance(self._dataset, IterDataPipe):
683            self._shared_seed = _share_dist_seed(loader.generator, self._pg)
684            shared_rng = torch.Generator()
685            shared_rng.manual_seed(self._shared_seed)
686            self._dataset = torch.utils.data.graph_settings.apply_random_seed(
687                self._dataset, shared_rng
688            )
689
690    def _next_index(self):
691        return next(self._sampler_iter)  # may raise StopIteration
692
693    def _next_data(self):
694        raise NotImplementedError
695
696    def __next__(self) -> Any:
697        with torch.autograd.profiler.record_function(self._profile_name):
698            if self._sampler_iter is None:
699                # TODO(https://github.com/pytorch/pytorch/issues/76750)
700                self._reset()  # type: ignore[call-arg]
701            data = self._next_data()
702            self._num_yielded += 1
703            if (
704                self._dataset_kind == _DatasetKind.Iterable
705                and self._IterableDataset_len_called is not None
706                and self._num_yielded > self._IterableDataset_len_called
707            ):
708                warn_msg = (
709                    f"Length of IterableDataset {self._dataset} was reported to be {self._IterableDataset_len_called} "
710                    f"(when accessing len(dataloader)), but {self._num_yielded} samples have been fetched. "
711                )
712                if self._num_workers > 0:
713                    warn_msg += (
714                        "For multiprocessing data-loading, this could be caused by not properly configuring the "
715                        "IterableDataset replica at each worker. Please see "
716                        "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples."
717                    )
718                warnings.warn(warn_msg)
719            return data
720
721    def __len__(self) -> int:
722        return len(self._index_sampler)
723
724    def __getstate__(self):
725        # TODO: add limited pickling support for sharing an iterator
726        # across multiple threads for HOGWILD.
727        # Probably the best way to do this is by moving the sample pushing
728        # to a separate thread and then just sharing the data queue
729        # but signalling the end is tricky without a non-blocking API
730        raise NotImplementedError(f"{self.__class__.__name__} cannot be pickled")
731
732
733class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
734    def __init__(self, loader):
735        super().__init__(loader)
736        assert self._timeout == 0
737        assert self._num_workers == 0
738
739        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
740        #   Taking care of distributed sharding
741        if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
742            # For BC, use default SHARDING_PRIORITIES
743            torch.utils.data.graph_settings.apply_sharding(
744                self._dataset, self._world_size, self._rank
745            )
746
747        self._dataset_fetcher = _DatasetKind.create_fetcher(
748            self._dataset_kind,
749            self._dataset,
750            self._auto_collation,
751            self._collate_fn,
752            self._drop_last,
753        )
754
755    def _next_data(self):
756        index = self._next_index()  # may raise StopIteration
757        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
758        if self._pin_memory:
759            data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
760        return data
761
762
763class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
764    r"""Iterates once over the DataLoader's dataset, as specified by the sampler."""
765
766    # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
767    #
768    # Preliminary:
769    #
770    # Our data model looks like this (queues are indicated with curly brackets):
771    #
772    #                main process                              ||
773    #                     |                                    ||
774    #               {index_queue}                              ||
775    #                     |                                    ||
776    #              worker processes                            ||     DATA
777    #                     |                                    ||
778    #            {worker_result_queue}                         ||     FLOW
779    #                     |                                    ||
780    #      pin_memory_thread of main process                   ||   DIRECTION
781    #                     |                                    ||
782    #               {data_queue}                               ||
783    #                     |                                    ||
784    #                data output                               \/
785    #
786    # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
787    #      `pin_memory=False`.
788    #
789    #
790    # Terminating multiprocessing logic requires very careful design. In
791    # particular, we need to make sure that
792    #
793    #   1. The iterator gracefully exits the workers when its last reference is
794    #      gone or it is depleted.
795    #
796    #      In this case, the workers should be gracefully exited because the
797    #      main process may still need to continue to run, and we want cleaning
798    #      up code in the workers to be executed (e.g., releasing GPU memory).
799    #      Naturally, we implement the shutdown logic in `__del__` of
800    #      DataLoaderIterator.
801    #
802    #      We delay the discussion on the logic in this case until later.
803    #
804    #   2. The iterator exits the workers when the loader process and/or worker
805    #      processes exit normally or with an error.
806    #
807    #      We set all workers and `pin_memory_thread` to have `daemon=True`.
808    #
809    #      You may ask, why can't we make the workers non-daemonic, and
810    #      gracefully exit using the same logic as we have in `__del__` when the
811    #      iterator gets deleted (see 1 above)?
812    #
813    #      First of all, `__del__` is **not** guaranteed to be called when
814    #      interpreter exits. Even if it is called, by the time it executes,
815    #      many Python core library resources may already be freed, and even
816    #      simple things like acquiring an internal lock of a queue may hang.
817    #      Therefore, in this case, we actually need to prevent `__del__` from
818    #      being executed, and rely on the automatic termination of daemonic
819    #      children.
820    #
821    #      Thus, we register an `atexit` hook that sets a global flag
822    #      `_utils.python_exit_status`. Since `atexit` hooks are executed in the
823    #      reverse order of registration, we are guaranteed that this flag is
824    #      set before library resources we use are freed (which, at least in
825    #      CPython, is done via an `atexit` handler defined in
826    #      `multiprocessing/util.py`
827    #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362
828    #      registered when an object requiring this mechanism is first
829    #      created, e.g., `mp.Queue`
830    #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103
831    #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29
832    #      )
833    #
834    #      So in `__del__`, we check if `_utils.python_exit_status` is set or
835    #      `None` (freed), and perform no-op if so.
836    #
837    #      However, simply letting library clean-up codes run can also be bad,
838    #      because such codes (i.e., `multiprocessing.util._exit_function()`)
839    #      include joining the putting threads of `mp.Queue`, which can block.
840    #      Hence, the queues that the main process puts into have
841    #      `cancel_join_thread` called at creation.  See the later section
842    #      [ 3b. A process won't hang when putting into a queue; ]
843    #      for more details.
844    #
845    #      Here are two example cases where library clean-up codes can run
846    #      before `__del__` is called:
847    #
848    #        1. If we hold onto a reference to the iterator, the interpreter more often
849    #           than not runs `multiprocessing` library clean-up before
850    #           clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666)
851    #           and thus prevents our cleaning-up code from running first.
852    #
853    #        2. A similar issue arises when a `DataLoader` is used in a subprocess.
854    #           When a process ends, it shuts all its daemonic children
855    #           down with a SIGTERM (instead of joining them without a timeout).
856    #           Similarly for threads, but by a different mechanism. This fact,
857    #           together with a few implementation details of multiprocessing, forces
858    #           us to make workers daemonic. All of our problems arise when a
859    #           DataLoader is used in a subprocess, and are caused by multiprocessing
860    #           code which looks more or less like this:
861    #
862    #               try:
863    #                   your_function_using_a_dataloader()
864    #               finally:
865    #                   multiprocessing.util._exit_function()
866    #
867    #           The joining/termination mentioned above happens inside
868    #           `_exit_function()`. Now, if `your_function_using_a_dataloader()`
869    #           throws, the stack trace stored in the exception will prevent the
870    #           frame which uses `DataLoaderIter` to be freed. If the frame has any
871    #           reference to the `DataLoaderIter` (e.g., in a method of the iter),
872    #           its `__del__`, which starts the shutdown procedure, will not be
873    #           called. That, in turn, means that workers aren't notified. Attempting
874    #           to join in `_exit_function` will then result in a hang.
875    #
876    #           For context, `_exit_function` is also registered as an `atexit` call.
877    #           So it is unclear to me (@ssnl) why this is needed in a finally block.
878    #           The code dates back to 2008 and there is no comment on the original
879    #           PEP 371 or patch https://bugs.python.org/issue3050 (containing both
880    #           the finally block and the `atexit` registration) that explains this.
881    #
882    #
883    #      Finally, another choice is to just shut down the workers with the logic in 1
884    #      above whenever we see an error in `next`. This isn't ideal because
885    #        a. It prevents users from using try-catch to resume data loading.
886    #        b. It doesn't prevent hanging if users have references to the
887    #           iterator.
888    #
889    #   3. All processes exit if any of them die unexpectedly by fatal signals.
890    #
891    #      As shown above, the workers are set as daemonic children of the main
892    #      process. However, automatic cleaning-up of such child processes only
893    #      happens if the parent process exits gracefully (e.g., not via fatal
894    #      signals like SIGKILL). So we must ensure that each process will exit
895    #      even if the process that should send/receive data to/from it was
896    #      killed, i.e.,
897    #
898    #        a. A process won't hang when getting from a queue.
899    #
900    #           Even with carefully designed data dependencies (i.e., a `put()`
901    #           always corresponding to a `get()`), hanging on `get()` can still
902    #           happen when data in queue is corrupted (e.g., due to
903    #           `cancel_join_thread` or unexpected exit).
904    #
905    #           For child exit, we set a timeout whenever we try to get data
906    #           from `data_queue`, and check the workers' status on each timeout
907    #           and error.
908    #           See `_DataLoaderiter._get_batch()` and
909    #           `_DataLoaderiter._try_get_data()` for details.
910    #
911    #           Additionally, for child exit on non-Windows platforms, we also
912    #           register a SIGCHLD handler (which is not supported on Windows) on
913    #           the main process, which checks if any of the workers fail in the
914    #           (Python) handler. This is more efficient and faster in detecting
915    #           worker failures, compared to only using the above mechanism.
916    #           See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
917    #
918    #           For `.get()` calls where the sender(s) is not the workers, we
919    #           guard them with timeouts, and check the status of the sender
920    #           when timeout happens:
921    #             + in the workers, the `_utils.worker.ManagerWatchdog` class
922    #               checks the status of the main process.
923    #             + if `pin_memory=True`, when getting from `pin_memory_thread`,
924    #               check `pin_memory_thread` status periodically until `.get()`
925    #               returns or see that `pin_memory_thread` died.
926    #
927    #        b. A process won't hang when putting into a queue;
928    #
929    #           We use `mp.Queue` which has a separate background thread to put
930    #           objects from an unbounded buffer array. The background thread is
931    #           daemonic and usually automatically joined when the process
932    #           *exits*.
933    #
934    #           In case that the receiver has ended abruptly while
935    #           reading from the pipe, the join will hang forever.  The usual
936    #           solution for this in Python is calling  `q.cancel_join_thread`,
937    #           which prevents automatically joining it when finalizing
938    #           (exiting).
939    #
940    #           Nonetheless, `cancel_join_thread` must only be called when the
941    #           queue is **not** going to be read from or written to by another
942    #           process, because it may hold onto a lock or leave corrupted data
943    #           in the queue, leading other readers/writers to hang.
944    #
945    #           Hence,
946    #             + For worker processes, we only do so (for their output
947    #               queues, i.e., `worker_result_queue`) before exiting.
948    #             + For `pin_memory_thread`, its output queue `data_queue` is a
949    #               `queue.Queue` that does blocking `put` if the queue is full.
950    #               So there is no above problem, but as a result, in
951    #               `_pin_memory_loop`, we do need to wrap the `put` in a loop
952    #               that breaks not only upon success, but also when the main
953    #               process stops reading, i.e., is shutting down.
954    #             + For loader process, we `cancel_join_thread()` for all
955    #               `_index_queues` because the whole purpose of workers and
956    #               `pin_memory_thread` is to serve the loader process.  If
957    #               loader process is already exiting, we don't really care if
958    #               the queues are corrupted.
959    #
960    #
961    # Now let's get back to 1:
962    #   how we gracefully exit the workers when the last reference to the
963    #   iterator is gone.
964    #
965    # To achieve this, we implement the following logic along with the design
966    # choices mentioned above:
967    #
968    # `workers_done_event`:
969    #   A `multiprocessing.Event` shared among the main process and all worker
970    #   processes. This is used to signal the workers that the iterator is
971    #   shutting down. After it is set, they will not send processed data to
972    #   queues anymore, and only wait for the final `None` before exiting.
973    #   `done_event` isn't strictly needed. I.e., we can just check for `None`
974    #   from the input queue, but it allows us to skip wasting resources
975    #   processing data if we are already shutting down.
976    #
977    # `pin_memory_thread_done_event`:
978    #   A `threading.Event` for a similar purpose to that of
979    #   `workers_done_event`, but is for the `pin_memory_thread`. The reason
980    #   that separate events are needed is that `pin_memory_thread` reads from
981    #   the output queue of the workers. But the workers, upon seeing that
982    #   `workers_done_event` is set, only want to see the final `None`, and are
983    #   not required to flush all data in the output queue (e.g., a worker may call
984    #   `cancel_join_thread` on that queue if its `IterableDataset` iterator
985    #   happens to exhaust coincidentally, which is out of the control of the
986    #   main process). Thus, since we will exit `pin_memory_thread` before the
987    #   workers (see below), two separate events are used.
988    #
989    # NOTE: In short, the protocol is that the main process will set these
990    #       `done_event`s and then send the corresponding processes/threads a `None`,
991    #       and that they may exit at any time after receiving the `None`.
992    #
993    # NOTE: Using `None` as the final signal is valid, since normal data will
994    #       always be a 2-tuple with the 1st element being the index of the data
995    #       transferred (different from dataset index/key), and the 2nd being
996    #       either the dataset key or the data sample (depending on which part
997    #       of the data model the queue is at).
998    #
999    # [ worker processes ]
1000    #   While loader process is alive:
1001    #     Get from `index_queue`.
1002    #       If an item is received (i.e., the get did not time out),
1003    #          Check `workers_done_event`.
1004    #            If set, continue to next iteration
1005    #                    i.e., keep getting until see the `None`, then exit.
1006    #            Otherwise, process data:
1007    #                If is fetching from an `IterableDataset` and the iterator
1008    #                    is exhausted, send an `_IterableDatasetStopIteration`
1009    #                    object to signal iteration end. The main process, upon
1010    #                    receiving such an object, will send `None` to this
1011    #                    worker and not use the corresponding `index_queue`
1012    #                    anymore.
1013    #       If timed out,
1014    #          Whether or not `workers_done_event` is set (we still need to see the `None`),
1015    #          we must continue to the next iteration.
1016    #   (outside loop)
1017    #   If `workers_done_event` is set (this can be False with `IterableDataset`),
1018    #     `data_queue.cancel_join_thread()`.  (Everything is ending here:
1019    #                                          main process won't read from it;
1020    #                                          other workers will also call
1021    #                                          `cancel_join_thread`.)
1022    #
1023    # [ pin_memory_thread ]
1024    #   # No need to check main thread. If this thread is alive, the main loader
1025    #   # thread must be alive, because this thread is set as daemonic.
1026    #   While `pin_memory_thread_done_event` is not set:
1027    #     Get from `worker_result_queue`.
1028    #       If timed out, continue to get in the next iteration.
1029    #       Otherwise, process data.
1030    #       While `pin_memory_thread_done_event` is not set:
1031    #         Put processed data to `data_queue` (a `queue.Queue` with blocking put)
1032    #         If timed out, continue to put in the next iteration.
1033    #         Otherwise, break, i.e., continue to the outer loop.
1034    #
1035    #   NOTE: we don't check the status of the main thread because
1036    #           1. if the process is killed by fatal signal, `pin_memory_thread`
1037    #              ends.
1038    #           2. in other cases, either the cleaning-up in __del__ or the
1039    #              automatic exit of daemonic thread will take care of it.
1040    #              This won't busy-wait either because `.get(timeout)` does not
1041    #              busy-wait.
1042    #
1043    # [ main process ]
1044    #   In the DataLoader Iter's `__del__`
1045    #     b. Exit `pin_memory_thread`
1046    #          i.   Set `pin_memory_thread_done_event`.
1047    #          ii   Put `None` in `worker_result_queue`.
1048    #          iii. Join the `pin_memory_thread`.
1049    #          iv.  `worker_result_queue.cancel_join_thread()`.
1050    #
1051    #     c. Exit the workers.
1052    #          i.   Set `workers_done_event`.
1053    #          ii.  Put `None` in each worker's `index_queue`.
1054    #          iii. Join the workers.
1055    #          iv.  Call `.cancel_join_thread()` on each worker's `index_queue`.
1056    #
1057    #        NOTE: (c) is better placed after (b) because it may leave corrupted
1058    #              data in `worker_result_queue`, which `pin_memory_thread`
1059    #              reads from, in which case the exit of `pin_memory_thread` can only
1060    #              happen after timing out, which is slow. Nonetheless, the same thing
1061    #              happens if a worker is killed by a signal at an unfortunate time,
1062    #              but in other cases, we are better off having a non-corrupted
1063    #              `worker_result_queue` for `pin_memory_thread`.
1064    #
1065    #   NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
1066    #         can be omitted
1067    #
1068    # NB: The `done_event`s aren't strictly needed. E.g., we can just check for
1069    #     `None` from `index_queue`, but it allows us to skip wasting resources
1070    #     processing indices already in `index_queue` if we are already shutting
1071    #     down.
1072
1073    def __init__(self, loader):
1074        super().__init__(loader)
1075
1076        self._prefetch_factor = loader.prefetch_factor
1077
1078        assert self._num_workers > 0
1079        assert self._prefetch_factor > 0
1080
1081        if loader.multiprocessing_context is None:
1082            multiprocessing_context = torch.multiprocessing
1083        else:
1084            multiprocessing_context = loader.multiprocessing_context
1085
1086        self._worker_init_fn = loader.worker_init_fn
1087
1088        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
1089        #   Additional worker init function will take care of sharding in MP and Distributed
1090        if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
1091            self._worker_init_fn = functools.partial(
1092                _sharding_worker_init_fn,
1093                self._worker_init_fn,
1094                self._world_size,
1095                self._rank,
1096            )
1097
1098        # No certainty which module multiprocessing_context is
1099        self._worker_result_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
1100        self._worker_pids_set = False
1101        self._shutdown = False
1102        self._workers_done_event = multiprocessing_context.Event()
1103
1104        self._index_queues = []
1105        self._workers = []
1106        for i in range(self._num_workers):
1107            # No certainty which module multiprocessing_context is
1108            index_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
1109            # Need to `cancel_join_thread` here!
1110            # See sections (2) and (3b) above.
1111            index_queue.cancel_join_thread()
1112            w = multiprocessing_context.Process(
1113                target=_utils.worker._worker_loop,
1114                args=(
1115                    self._dataset_kind,
1116                    self._dataset,
1117                    index_queue,
1118                    self._worker_result_queue,
1119                    self._workers_done_event,
1120                    self._auto_collation,
1121                    self._collate_fn,
1122                    self._drop_last,
1123                    self._base_seed,
1124                    self._worker_init_fn,
1125                    i,
1126                    self._num_workers,
1127                    self._persistent_workers,
1128                    self._shared_seed,
1129                ),
1130            )
1131            w.daemon = True
1132            # NB: Process.start() actually takes some time as it needs to
1133            #     start a process and pass the arguments over via a pipe.
1134            #     Therefore, we only add a worker to the self._workers list
1135            #     after it has started; otherwise, if the program dies before
1136            #     the worker starts, __del__ would try to join it and get:
1137            #     AssertionError: can only join a started process.
1138            w.start()
1139            self._index_queues.append(index_queue)
1140            self._workers.append(w)
1141
1142        if self._pin_memory:
1143            self._pin_memory_thread_done_event = threading.Event()
1144
1145            # Queue is not type-annotated
1146            self._data_queue = queue.Queue()  # type: ignore[var-annotated]
1147            if self._pin_memory_device == "xpu":
1148                current_device = torch.xpu.current_device()  # type: ignore[attr-defined]
1149            elif self._pin_memory_device == torch._C._get_privateuse1_backend_name():
1150                custom_device_mod = getattr(
1151                    torch, torch._C._get_privateuse1_backend_name()
1152                )
1153                current_device = custom_device_mod.current_device()
1154            else:
1155                current_device = torch.cuda.current_device()  # choose cuda for default
1156            pin_memory_thread = threading.Thread(
1157                target=_utils.pin_memory._pin_memory_loop,
1158                args=(
1159                    self._worker_result_queue,
1160                    self._data_queue,
1161                    current_device,
1162                    self._pin_memory_thread_done_event,
1163                    self._pin_memory_device,
1164                ),
1165            )
1166            pin_memory_thread.daemon = True
1167            pin_memory_thread.start()
1168            # Similar to workers (see comment above), we only register
1169            # pin_memory_thread once it is started.
1170            self._pin_memory_thread = pin_memory_thread
1171        else:
1172            self._data_queue = self._worker_result_queue  # type: ignore[assignment]
1173
1174        # In some rare cases, persistent workers (daemonic processes)
1175        # would be terminated before `__del__` of the iterator is invoked
1176        # when the main process exits. This causes failures when
1177        # pin_memory_thread tries to read corrupted data from
1178        # worker_result_queue. atexit is used to shut down the thread and
1179        # child processes in the right sequence before the main process
1180        # exits.
1181        if self._persistent_workers and self._pin_memory:
1182            import atexit
1183
1184            for w in self._workers:
1185                atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
1186
1187        # .pid can be None only before process is spawned (not the case, so ignore)
1188        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
1189        _utils.signal_handling._set_SIGCHLD_handler()
1190        self._worker_pids_set = True
1191        self._reset(loader, first_iter=True)
1192
1193    def _reset(self, loader, first_iter=False):
1194        super()._reset(loader, first_iter)
1195        self._send_idx = 0  # idx of the next task to be sent to workers
1196        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
1197        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
1198        # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
1199        #                  \ (worker_id, data)   if data is already fetched (out-of-order)
1200        self._task_info = {}
1201        self._tasks_outstanding = (
1202            0  # always equal to sum(1 for v in task_info.values() if len(v) == 1)
1203        )
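        # For example (a hypothetical snapshot, not produced by this code):
        #   _task_info == {3: (1,), 4: (0, data_4), 5: (1,)}
        # means task 4 came back out of order from worker 0 and is buffered until
        # `_rcvd_idx` reaches 4, while tasks 3 and 5 are still outstanding on
        # worker 1, so `_tasks_outstanding == 2`.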
1204        # A list of booleans representing whether each worker still has work to
1205        # do, i.e., not having exhausted its iterable dataset object. It always
1206        # contains all `True`s if not using an iterable-style dataset
1207        # (i.e., if kind != Iterable).
1208        # Note that this indicates that a worker still has work to do *for this epoch*.
1209        # It does not mean that a worker is dead. In case of `_persistent_workers`,
1210        # the worker will be reset to available in the next epoch.
1211        self._workers_status = [True for i in range(self._num_workers)]
1212        # Reset the worker queue cycle so it resumes next epoch at worker 0
1213        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
1214        # We resume the prefetching in case it was enabled
1215        if not first_iter:
1216            for idx in range(self._num_workers):
1217                self._index_queues[idx].put(
1218                    _utils.worker._ResumeIteration(self._shared_seed)
1219                )
1220            resume_iteration_cnt = self._num_workers
1221            while resume_iteration_cnt > 0:
1222                return_idx, return_data = self._get_data()
1223                if isinstance(return_idx, _utils.worker._ResumeIteration):
1224                    assert return_data is None
1225                    resume_iteration_cnt -= 1
1226        # prime the prefetch loop
1227        for _ in range(self._prefetch_factor * self._num_workers):
1228            self._try_put_index()
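        # E.g., with num_workers=2 and prefetch_factor=2, four (send_idx, index)
        # pairs are queued up front, two per active worker (assuming the sampler
        # is not exhausted first).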
1229
1230    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
1231        # Tries to fetch data from `self._data_queue` once for a given timeout.
1232        # This can also be used as the inner loop when fetching without a
1233        # timeout, using the sender's status as the loop condition.
1234        #
1235        # This raises a `RuntimeError` if any worker died unexpectedly. This error
1236        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
1237        # (only for non-Windows platforms), or the manual check below on errors
1238        # and timeouts.
1239        #
1240        # Returns a 2-tuple:
1241        #   (bool: whether data was successfully fetched, any: the data if successful, else None)
1242        try:
1243            data = self._data_queue.get(timeout=timeout)
1244            return (True, data)
1245        except Exception as e:
1246            # At timeout and error, we manually check whether any worker has
1247            # failed. Note that this is the only mechanism for Windows to detect
1248            # worker failures.
1249            failed_workers = []
1250            for worker_id, w in enumerate(self._workers):
1251                if self._workers_status[worker_id] and not w.is_alive():
1252                    failed_workers.append(w)
1253                    self._mark_worker_as_unavailable(worker_id)
1254            if len(failed_workers) > 0:
1255                pids_str = ", ".join(str(w.pid) for w in failed_workers)
1256                raise RuntimeError(
1257                    f"DataLoader worker (pid(s) {pids_str}) exited unexpectedly"
1258                ) from e
1259            if isinstance(e, queue.Empty):
1260                return (False, None)
1261
1262            import errno
1263            import tempfile
1264
1265            try:
1266                # Raise an exception if we are this close to the FDs limit.
1267                # Apparently, trying to open only one file is not a sufficient
1268                # test.
1269                # See NOTE [ DataLoader on Linux and open files limit ]
1270                fds_limit_margin = 10
1271                fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
1272            except OSError as e:
1273                if e.errno == errno.EMFILE:
1274                    raise RuntimeError(
1275                        "Too many open files. Communication with the"
1276                        " workers is no longer possible. Please increase the"
1277                        " limit using `ulimit -n` in the shell or change the"
1278                        " sharing strategy by calling"
1279                        " `torch.multiprocessing.set_sharing_strategy('file_system')`"
1280                        " at the beginning of your code"
1281                    ) from None
1282            raise
1283
1284    # NOTE [ DataLoader on Linux and open files limit ]
1285    #
1286    # On Linux when DataLoader is used with multiprocessing we pass the data between
1287    # the root process and the workers through SHM files. We remove those files from
1288    # the filesystem as soon as they are created and keep them alive by
1289    # passing around their file descriptors through AF_UNIX sockets. (See
1290    # docs/source/multiprocessing.rst and `Multiprocessing Technical Notes` in
1291    # the wiki (https://github.com/pytorch/pytorch/wiki).)
1292    #
1293    # This sometimes leads us to exceeding the open files limit. When that happens,
1294    # and the offending file descriptor is coming over a socket, the `socket` Python
1295    # package silently strips the file descriptor from the message, setting only the
1296    # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that
1297    # it _indicates that some control data were discarded due to lack of space in
1298    # the buffer for ancillary data_). This might reflect the C implementation of
1299    # AF_UNIX sockets.
1300    #
1301    # This behaviour can be reproduced with the script and instructions at the
1302    # bottom of this note.
1303    #
1304    # When that happens, the standard Python `multiprocessing` (and not
1305    # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata`
1306    #
1307    # Sometimes, instead of the FD being stripped, you may get an `OSError:
1308    # Too many open files`, both in the script below and in DataLoader. However,
1309    # this is rare and seems to be nondeterministic.
1310    #
1311    #
1312    #   #!/usr/bin/env python3
1313    #   import sys
1314    #   import socket
1315    #   import os
1316    #   import array
1317    #   import shutil
1319    #
1320    #
1321    #   if len(sys.argv) != 4:
1322    #       print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)")
1323    #       sys.exit(1)
1324    #
1325    #   if __name__ == '__main__':
1326    #       dirname = sys.argv[1]
1327    #       sock_path = dirname + "/sock"
1328    #       iterations = int(sys.argv[2])
1329    #       def dummy_path(i):
1330    #           return dirname + "/" + str(i) + ".dummy"
1331    #
1332    #
1333    #       if sys.argv[3] == 'send':
1334    #           while not os.path.exists(sock_path):
1335    #               pass
1336    #           client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
1337    #           client.connect(sock_path)
1338    #           for i in range(iterations):
1339    #               fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT)
1340    #               ancdata = array.array('i', [fd])
1341    #               msg = bytes([i % 256])
1342    #               print("Sending fd ", fd, " (iteration #", i, ")")
1343    #               client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)])
1344    #
1345    #
1346    #       else:
1347    #           assert sys.argv[3] == 'recv'
1348    #
1349    #           if os.path.exists(dirname):
1350    #               raise Exception("Directory exists")
1351    #
1352    #           os.mkdir(dirname)
1353    #
1354    #           print("Opening socket...")
1355    #           server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
1356    #           server.bind(sock_path)
1357    #
1358    #           print("Listening...")
1359    #           for i in range(iterations):
1360    #               a = array.array('i')
1361    #               msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize))
1362    #               assert(len(ancdata) == 1)
1363    #               cmsg_level, cmsg_type, cmsg_data = ancdata[0]
1364    #               a.frombytes(cmsg_data)
1365    #               print("Received fd ", a[0], " (iteration #", i, ")")
1366    #
1367    #           shutil.rmtree(dirname)
1368    #
1369    # Steps to reproduce:
1370    #
1371    # 1. Run two shells and set a lower file descriptor limit in the receiving one:
1372    # (shell1) ulimit -n 1020
1373    # (shell2) ulimit -n 1022
1374    #
1375    # 2. Run the script above with the `recv` option in the first shell:
1376    # (shell1) ./test_socket.py sock_tmp 1017 recv
1377    #
1378    # 3. Run the script with the `send` option in the second shell:
1379    # (shell2) ./test_socket.py sock_tmp 1017 send
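    #
    # A typical mitigation (a sketch, not part of this module) is to raise the
    # open files limit or to switch the sharing strategy at program start:
    #
    #   import torch.multiprocessing
    #   # 'file_system' shares tensors through named files instead of passing
    #   # file descriptors over sockets, sidestepping the FD limit above.
    #   torch.multiprocessing.set_sharing_strategy('file_system')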
1380
1381    def _get_data(self):
1382        # Fetches data from `self._data_queue`.
1383        #
1384        # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
1385        # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
1386        # in a loop. This is the only mechanism to detect worker failures for
1387        # Windows. For other platforms, a SIGCHLD handler is also used for
1388        # worker failure detection.
1389        #
1390        # If `pin_memory=True`, we also need to check whether `pin_memory_thread`
1391        # has died at each timeout.
1392        if self._timeout > 0:
1393            success, data = self._try_get_data(self._timeout)
1394            if success:
1395                return data
1396            else:
1397                raise RuntimeError(
1398                    f"DataLoader timed out after {self._timeout} seconds"
1399                )
1400        elif self._pin_memory:
1401            while self._pin_memory_thread.is_alive():
1402                success, data = self._try_get_data()
1403                if success:
1404                    return data
1405            else:
1406                # while condition is false, i.e., pin_memory_thread died.
1407                raise RuntimeError("Pin memory thread exited unexpectedly")
1408            # In this case, `self._data_queue` is a `queue.Queue`. But we don't
1409            # need to call `.task_done()` because we don't use `.join()`.
1410        else:
1411            while True:
1412                success, data = self._try_get_data()
1413                if success:
1414                    return data
1415
1416    def _next_data(self):
1417        while True:
1418            # If the worker responsible for `self._rcvd_idx` has already ended
1419            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
1420            # we try to advance `self._rcvd_idx` to find the next valid index.
1421            #
1422            # This part needs to run in the loop because both the `self._get_data()`
1423            # call and `_IterableDatasetStopIteration` check below can mark
1424            # extra worker(s) as dead.
1425            while self._rcvd_idx < self._send_idx:
1426                info = self._task_info[self._rcvd_idx]
1427                worker_id = info[0]
1428                if (
1429                    len(info) == 2 or self._workers_status[worker_id]
1430                ):  # has data or is still active
1431                    break
1432                del self._task_info[self._rcvd_idx]
1433                self._rcvd_idx += 1
1434            else:
1435                # no valid `self._rcvd_idx` is found (i.e., didn't break)
1436                if not self._persistent_workers:
1437                    self._shutdown_workers()
1438                raise StopIteration
1439
1440            # Now `self._rcvd_idx` is the batch index we want to fetch
1441
1442            # Check if the next sample has already been generated
1443            if len(self._task_info[self._rcvd_idx]) == 2:
1444                data = self._task_info.pop(self._rcvd_idx)[1]
1445                return self._process_data(data)
1446
1447            assert not self._shutdown and self._tasks_outstanding > 0
1448            idx, data = self._get_data()
1449            self._tasks_outstanding -= 1
1450            if self._dataset_kind == _DatasetKind.Iterable:
1451                # Check for _IterableDatasetStopIteration
1452                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
1453                    if self._persistent_workers:
1454                        self._workers_status[data.worker_id] = False
1455                    else:
1456                        self._mark_worker_as_unavailable(data.worker_id)
1457                    self._try_put_index()
1458                    continue
1459
1460            if idx != self._rcvd_idx:
1461                # store out-of-order samples
1462                self._task_info[idx] += (data,)
1463            else:
1464                del self._task_info[idx]
1465                return self._process_data(data)
1466
1467    def _try_put_index(self):
1468        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
1469
1470        try:
1471            index = self._next_index()
1472        except StopIteration:
1473            return
1474        for _ in range(self._num_workers):  # find the next active worker, if any
1475            worker_queue_idx = next(self._worker_queue_idx_cycle)
1476            if self._workers_status[worker_queue_idx]:
1477                break
1478        else:
1479            # not found (i.e., didn't break)
1480            return
1481
1482        self._index_queues[worker_queue_idx].put((self._send_idx, index))  # type: ignore[possibly-undefined]
1483        self._task_info[self._send_idx] = (worker_queue_idx,)
1484        self._tasks_outstanding += 1
1485        self._send_idx += 1
1486
1487    def _process_data(self, data):
1488        self._rcvd_idx += 1
1489        self._try_put_index()
1490        if isinstance(data, ExceptionWrapper):
1491            data.reraise()
1492        return data
1493
1494    def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
1495        # Mark a worker as having finished its work, e.g., due to
1496        # exhausting an `IterableDataset`. This should be used only when this
1497        # `_MultiProcessingDataLoaderIter` is going to continue running.
1498
1499        assert self._workers_status[worker_id] or (
1500            self._persistent_workers and shutdown
1501        )
1502
1503        # Signal termination to that specific worker.
1504        q = self._index_queues[worker_id]
1505        # Indicate that no more data will be put on this queue by the current
1506        # process.
1507        q.put(None)
1508
1509        # Note that we don't actually join the worker here, nor do we remove the
1510        # worker's pid from the C side struct, because (1) joining may be slow,
1511        # and (2) since we don't join, the worker may still raise errors, and we
1512        # prefer capturing those, rather than ignoring them, even though they
1513        # are raised after the worker has finished its job.
1514        # Joining is deferred to `_shutdown_workers`, which is called when
1515        # all workers finish their jobs (e.g., `IterableDataset` replicas) or
1516        # when this iterator is garbage collected.
1517
1518        self._workers_status[worker_id] = False
1519
1520        assert self._workers_done_event.is_set() == shutdown
1521
1522    def _shutdown_workers(self):
1523        # Called when shutting down this `_MultiProcessingDataLoaderIter`.
1524        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
1525        # the logic of this function.
1526        if (
1527            _utils is None
1528            or _utils.python_exit_status is True
1529            or _utils.python_exit_status is None
1530        ):
1531            # See (2) of the note. If Python is shutting down, this is a no-op.
1532            return
1533        # Normal exit when last reference is gone / iterator is depleted.
1534        # See (1) and the second half of the note.
1535        if not self._shutdown:
1536            self._shutdown = True
1537            try:
1541                # Exit `pin_memory_thread` first because exiting workers may leave
1542                # corrupted data in `worker_result_queue` which `pin_memory_thread`
1543                # reads from.
1544                if hasattr(self, "_pin_memory_thread"):
1545                    # Use hasattr in case an error happened before we set the attribute.
1546                    self._pin_memory_thread_done_event.set()
1547                    # Send something to pin_memory_thread in case it is waiting
1548                    # so that it can wake up and check `pin_memory_thread_done_event`
1549                    self._worker_result_queue.put((None, None))
1550                    self._pin_memory_thread.join()
1551                    self._worker_result_queue.cancel_join_thread()
1552                    self._worker_result_queue.close()
1553
1554                # Exit workers now.
1555                self._workers_done_event.set()
1556                for worker_id in range(len(self._workers)):
1557                    # Get number of workers from `len(self._workers)` instead of
1558                    # `self._num_workers` in case we error before starting all
1559                    # workers.
1560                    # If we are using `workers_status` with persistent workers,
1561                    # we have to shut the worker down because it is paused.
1562                    if self._persistent_workers or self._workers_status[worker_id]:
1563                        self._mark_worker_as_unavailable(worker_id, shutdown=True)
1564                for w in self._workers:
1565                    # We should be able to join here, but in case anything went
1566                    # wrong, we set a timeout and if the workers fail to join,
1567                    # they are killed in the `finally` block.
1568                    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
1569                for q in self._index_queues:
1570                    q.cancel_join_thread()
1571                    q.close()
1572            finally:
1573                # Even though all this function does is put into queues that
1574                # we have called `cancel_join_thread` on, weird things can
1575                # happen when a worker is killed by a signal, e.g., hanging in
1576                # `Event.set()`. So we need to guard this with SIGCHLD handler,
1577                # and remove pids from the C side data structure only at the
1578                # end.
1579                #
1580                # FIXME: Unfortunately, for Windows, we are missing a worker
1581                #        error detection mechanism here in this function, as it
1582                #        doesn't provide a SIGCHLD handler.
1583                if self._worker_pids_set:
1584                    _utils.signal_handling._remove_worker_pids(id(self))
1585                    self._worker_pids_set = False
1586                for w in self._workers:
1587                    if w.is_alive():
1588                        # Existing mechanisms try to make the workers exit
1589                        # peacefully, but in case we unfortunately reach here,
1590                        # which we shouldn't (e.g., pytorch/pytorch#39570),
1591                        # we kill the worker.
1592                        w.terminate()
1593
1594    # A staticmethod is used to avoid keeping a reference to `_MultiProcessingDataLoaderIter`
1595    @staticmethod
1596    def _clean_up_worker(w):
1597        try:
1598            w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
1599        finally:
1600            if w.is_alive():
1601                w.terminate()
1602
1603    def __del__(self):
1604        self._shutdown_workers()
1605