# mypy: allow-untyped-defs
r"""Contains definitions of the methods used by the _BaseDataLoaderIter to put fetched tensors into pinned memory.

These **need** to be in global scope since Py2 doesn't support serializing
static methods.
"""

import collections
import copy
import queue

import torch
from torch._utils import ExceptionWrapper

from . import MP_STATUS_CHECK_INTERVAL


def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device):
    # This setting is thread local, and prevents the copy in pin_memory from
    # consuming all CPU cores.
    torch.set_num_threads(1)

    torch.multiprocessing._set_thread_name("pt_data_pin")

    if device == "cuda":
        torch.cuda.set_device(device_id)
    elif device == "xpu":
        torch.xpu.set_device(device_id)  # type: ignore[attr-defined]
    elif device == torch._C._get_privateuse1_backend_name():
        custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name())
        custom_device_mod.set_device(device_id)

    def do_one_step():
        try:
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            return
        idx, data = r
        if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
            try:
                data = pin_memory(data, device)
            except Exception:
                data = ExceptionWrapper(
                    where=f"in pin memory thread for device {device_id}"
                )
            r = (idx, data)
        while not done_event.is_set():
            try:
                out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
                break
            except queue.Full:
                continue

    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    while not done_event.is_set():
        # Make sure that we don't preserve any object from one iteration
        # to the next.
        do_one_step()
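
# Illustrative sketch (not executed here): `_pin_memory_loop` is meant to run
# on a background thread owned by the DataLoader iterator, draining worker
# results from one queue and re-emitting pinned copies on another. Assuming
# hypothetical `worker_result_queue` and `data_queue` objects, the wiring
# would look roughly like:
#
#     import threading
#
#     done_event = threading.Event()
#     pin_memory_thread = threading.Thread(
#         target=_pin_memory_loop,
#         args=(worker_result_queue, data_queue,
#               torch.cuda.current_device(), done_event, "cuda"),
#     )
#     pin_memory_thread.daemon = True  # don't block interpreter shutdown
#     pin_memory_thread.start()
#     ...
#     done_event.set()  # signals the loop to exit cleanly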


def pin_memory(data, device=None):
    if isinstance(data, torch.Tensor):
        return data.pin_memory(device)
    elif isinstance(data, (str, bytes)):
        return data
    elif isinstance(data, collections.abc.Mapping):
        try:
            if isinstance(data, collections.abc.MutableMapping):
                # The mapping type may have extra properties, so we can't just
                # use `type(data)(...)` to create the new mapping.
                # Create a clone and update it if the mapping type is mutable.
                clone = copy.copy(data)
                clone.update(
                    {k: pin_memory(sample, device) for k, sample in data.items()}
                )
                return clone
            else:
                return type(data)({k: pin_memory(sample, device) for k, sample in data.items()})  # type: ignore[call-arg]
        except TypeError:
            # The mapping type may not support `copy()` / `update(mapping)`
            # or `__init__(iterable)`.
            return {k: pin_memory(sample, device) for k, sample in data.items()}
    elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
        return type(data)(*(pin_memory(sample, device) for sample in data))
    elif isinstance(data, tuple):
        return [
            pin_memory(sample, device) for sample in data
        ]  # Backwards compatibility.
    elif isinstance(data, collections.abc.Sequence):
        try:
            if isinstance(data, collections.abc.MutableSequence):
                # The sequence type may have extra properties, so we can't just
                # use `type(data)(...)` to create the new sequence.
                # Create a clone and update it if the sequence type is mutable.
                clone = copy.copy(data)  # type: ignore[arg-type]
                for i, item in enumerate(data):
                    clone[i] = pin_memory(item, device)
                return clone
            return type(data)([pin_memory(sample, device) for sample in data])  # type: ignore[call-arg]
        except TypeError:
            # The sequence type may not support `copy()` / `__setitem__(index, item)`
            # or `__init__(iterable)` (e.g., `range`).
            return [pin_memory(sample, device) for sample in data]
    elif hasattr(data, "pin_memory"):
        return data.pin_memory()
    else:
        return data
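

# A minimal, illustrative demo of `pin_memory` (not part of the DataLoader
# API surface): it recurses through common containers and pins every tensor
# leaf. The `__main__` guard keeps importing this module side-effect free;
# the availability check reflects the assumption that pinning requires a
# CUDA context on this machine.
if __name__ == "__main__":
    if torch.cuda.is_available():
        batch = {"input": torch.randn(4, 3), "labels": [torch.zeros(4)]}
        pinned = pin_memory(batch)
        # Nested containers come back with every tensor leaf pinned.
        assert pinned["input"].is_pinned()
        assert pinned["labels"][0].is_pinned()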