xref: /aosp_15_r20/external/pytorch/torch/utils/data/_utils/pin_memory.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2r"""Contains definitions of the methods used by the _BaseDataLoaderIter to put fetched tensors into pinned memory.
3
These **need** to be in global scope since Py2 doesn't support serializing
5static methods.
6"""
7
8import collections
9import copy
10import queue
11
12import torch
13from torch._utils import ExceptionWrapper
14
15from . import MP_STATUS_CHECK_INTERVAL
16
17
def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device):
    """Thread target that pins fetched batches and forwards them downstream.

    Runs until ``done_event`` is set: repeatedly takes ``(idx, data)`` pairs
    from ``in_queue``, copies ``data`` into pinned (page-locked) memory via
    the module-level ``pin_memory`` helper, and puts the result on
    ``out_queue``.  Any failure during pinning is converted into an
    ``ExceptionWrapper`` so the error is re-raised in the consumer, not lost
    in this background thread.

    Args:
        in_queue: queue producing ``(idx, data)`` tuples (or exception
            wrappers) from the fetch stage.
        out_queue: queue the pinned ``(idx, data)`` tuples are forwarded to.
        device_id: ordinal of the device to make current in this thread.
        device: device type string ("cuda", "xpu", or the privateuse1
            backend name) selecting which backend's ``set_device`` is called.
        done_event: event signalling the loop (and any blocked put) to exit.
    """
    # This setting is thread local, and prevents the copy in pin_memory from
    # consuming all CPU cores.
    torch.set_num_threads(1)

    # Name the thread for debuggers/profilers.
    torch.multiprocessing._set_thread_name("pt_data_pin")

    # Bind this thread to the target device so pinning/copies land on the
    # right accelerator context.
    if device == "cuda":
        torch.cuda.set_device(device_id)
    elif device == "xpu":
        torch.xpu.set_device(device_id)  # type: ignore[attr-defined]
    elif device == torch._C._get_privateuse1_backend_name():
        # Out-of-tree backend: look up its module by name and defer to its
        # own set_device implementation.
        custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
        custom_device_mod.set_device(device_id)

    def do_one_step():
        # Process at most one item; the timeout keeps us responsive to
        # done_event instead of blocking forever on an idle queue.
        try:
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            return
        idx, data = r
        # Skip pinning when shutting down or when the fetch stage already
        # failed (ExceptionWrapper is forwarded as-is for the consumer to
        # re-raise).
        if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
            try:
                data = pin_memory(data, device)
            except Exception:
                # Capture the traceback here so it can be re-raised with
                # context on the consumer side.
                data = ExceptionWrapper(
                    where=f"in pin memory thread for device {device_id}"
                )
            r = (idx, data)
        # Bounded retry loop: keep trying to hand off the result, but give
        # up as soon as shutdown is requested (the item is then dropped).
        while not done_event.is_set():
            try:
                out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
                break
            except queue.Full:
                continue

    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    while not done_event.is_set():
        # Make sure that we don't preserve any object from one iteration
        # to the next
        do_one_step()
60
61
def pin_memory(data, device=None):
    """Recursively copy every tensor contained in ``data`` into pinned memory.

    Containers (mappings, namedtuples, sequences) are rebuilt with their
    elements pinned; strings/bytes and unrecognized types pass through
    unchanged.

    Args:
        data: an arbitrarily nested structure of tensors and containers.
        device: optional device argument forwarded to ``Tensor.pin_memory``.

    Returns:
        A structure mirroring ``data`` with all tensors pinned.  Plain tuples
        come back as lists (kept for backwards compatibility).
    """
    if isinstance(data, torch.Tensor):
        return data.pin_memory(device)
    if isinstance(data, (str, bytes)):
        # str/bytes are Sequences but must never be recursed into.
        return data
    if isinstance(data, collections.abc.Mapping):
        try:
            if isinstance(data, collections.abc.MutableMapping):
                # A mapping subclass may carry extra state, so clone the
                # original rather than constructing type(data)(...) anew.
                pinned = copy.copy(data)
                pinned.update({k: pin_memory(v, device) for k, v in data.items()})
                return pinned
            return type(data)({k: pin_memory(v, device) for k, v in data.items()})  # type: ignore[call-arg]
        except TypeError:
            # The mapping type supports neither copy()/update(mapping) nor
            # __init__(iterable); degrade to a plain dict.
            return {k: pin_memory(v, device) for k, v in data.items()}
    if isinstance(data, tuple) and hasattr(data, "_fields"):
        # namedtuple: rebuild positionally so the field names survive.
        return type(data)(*(pin_memory(entry, device) for entry in data))
    if isinstance(data, tuple):
        # Plain tuples historically come back as lists; keep that behavior.
        return [pin_memory(entry, device) for entry in data]
    if isinstance(data, collections.abc.Sequence):
        try:
            if isinstance(data, collections.abc.MutableSequence):
                # Clone first, for the same reason as mutable mappings above.
                pinned = copy.copy(data)  # type: ignore[arg-type]
                for idx, entry in enumerate(data):
                    pinned[idx] = pin_memory(entry, device)
                return pinned
            return type(data)([pin_memory(entry, device) for entry in data])  # type: ignore[call-arg]
        except TypeError:
            # e.g. range: neither mutable nor constructible from an iterable.
            return [pin_memory(entry, device) for entry in data]
    if hasattr(data, "pin_memory"):
        # Custom types may implement their own pin_memory().
        return data.pin_memory()
    return data
109