# mypy: allow-untyped-defs
r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.

These methods are used to collate samples fetched from dataset into Tensor(s).
These **need** to be in global scope since Py2 doesn't support serializing
static methods.

`default_collate` and `default_convert` are exposed to users via 'dataloader.py'.
"""

import collections
import contextlib
import copy
import re
from typing import Callable, Dict, Optional, Tuple, Type, Union

import torch


np_str_obj_array_pattern = re.compile(r"[SaUO]")

# Names of NumPy string-like scalar classes that must not be handed to
# `torch.as_tensor`. On Python 3 the bytes scalar class is named `bytes_`;
# `string_` is kept for older spellings.
_np_string_scalar_names = ("str_", "string_", "bytes_")


def _rebuild_mapping(template, converted):
    r"""
    Pack an already converted plain ``dict`` into a mapping shaped like ``template``.

    Args:
        template: the original mapping whose type (and extra state) should be preserved
        converted: plain ``dict`` holding the converted values, keyed like ``template``

    Returns:
        A shallow copy of ``template`` updated with ``converted`` when ``template`` is a
        ``MutableMapping``, otherwise ``type(template)(converted)``. If either attempt
        raises ``TypeError``, ``converted`` itself is returned.
    """
    try:
        if isinstance(template, collections.abc.MutableMapping):
            # The mapping type may have extra properties, so we can't just
            # use `type(template)(...)` to create the new mapping.
            # Create a clone and update it if the mapping type is mutable.
            clone = copy.copy(template)
            clone.update(converted)
            return clone
        return type(template)(converted)
    except TypeError:
        # The mapping type may not support `copy()` / `update(mapping)`
        # or `__init__(iterable)`.
        return converted


def _rebuild_sequence(template, converted):
    r"""
    Pack an already converted ``list`` into a sequence shaped like ``template``.

    Args:
        template: the original sequence whose type (and extra state) should be preserved
        converted: ``list`` holding the converted items, in order

    Returns:
        A shallow copy of ``template`` with its items overwritten index by index when
        ``template`` is a ``MutableSequence``, otherwise ``type(template)(converted)``.
        If either attempt raises ``TypeError``, ``converted`` itself is returned.
    """
    try:
        if isinstance(template, collections.abc.MutableSequence):
            # The sequence type may have extra properties, so we can't just
            # use `type(template)(...)` to create the new sequence.
            # Create a clone and update it if the sequence type is mutable.
            clone = copy.copy(template)
            for index, item in enumerate(converted):
                clone[index] = item
            return clone
        return type(template)(converted)
    except TypeError:
        # The sequence type may not support `copy()` / `__setitem__(index, item)`
        # or `__init__(iterable)` (e.g., `range`).
        return converted


def default_convert(data):
    r"""
    Convert each NumPy array element into a :class:`torch.Tensor`.

    If the input is a `Sequence`, `Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`.
    If the input is not a NumPy array, it is left unchanged.
    This is used as the default function for collation when both `batch_sampler` and `batch_size`
    are NOT defined in :class:`~torch.utils.data.DataLoader`.

    The general input type to output type mapping is similar to that
    of :func:`~torch.utils.data.default_collate`. See the description there for more details.

    Args:
        data: a single data point to be converted

    Examples:
        >>> # xdoctest: +SKIP
        >>> # Example with `int`
        >>> default_convert(0)
        0
        >>> # Example with NumPy array
        >>> default_convert(np.array([0, 1]))
        tensor([0, 1])
        >>> # Example with NamedTuple
        >>> Point = namedtuple('Point', ['x', 'y'])
        >>> default_convert(Point(0, 0))
        Point(x=0, y=0)
        >>> default_convert(Point(np.array(0), np.array(0)))
        Point(x=tensor(0), y=tensor(0))
        >>> # Example with List
        >>> default_convert([np.array([0, 1]), np.array([2, 3])])
        [tensor([0, 1]), tensor([2, 3])]
    """
    elem_type = type(data)
    if isinstance(data, torch.Tensor):
        return data
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ not in _np_string_scalar_names
    ):
        # array of string classes and object
        if (
            elem_type.__name__ == "ndarray"
            and np_str_obj_array_pattern.search(data.dtype.str) is not None
        ):
            return data
        return torch.as_tensor(data)
    elif isinstance(data, collections.abc.Mapping):
        # Convert once, then repackage; a `TypeError` raised while converting a
        # nested value propagates directly instead of triggering a second pass.
        return _rebuild_mapping(
            data, {key: default_convert(data[key]) for key in data}
        )
    elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
        return elem_type(*(default_convert(d) for d in data))
    elif isinstance(data, tuple):
        return [default_convert(d) for d in data]  # Backwards compatibility.
    elif isinstance(data, collections.abc.Sequence) and not isinstance(
        data, (str, bytes)
    ):
        return _rebuild_sequence(data, [default_convert(d) for d in data])
    else:
        return data


default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}"
)


def collate(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    r"""
    General collate function that handles collection type of element within each batch.

    The function also opens function registry to deal with specific element types. `default_collate_fn_map`
    provides default collate functions for tensors, numpy arrays, numbers and strings.

    Args:
        batch: a single batch to be collated
        collate_fn_map: Optional dictionary mapping from element type to the corresponding collate function.
            If the element type isn't present in this dictionary,
            this function will go through each key of the dictionary in the insertion order to
            invoke the corresponding collate function if the element type is a subclass of the key.

    Raises:
        RuntimeError: if the elements of a batch of sequences do not all have the same length.
        TypeError: if the element type is neither handled by ``collate_fn_map`` nor a
            `Mapping`, namedtuple or (non-`str`) `Sequence`.

    Examples:
        >>> def collate_tensor_fn(batch, *, collate_fn_map):
        ...     # Extend this function to handle batch of tensors
        ...     return torch.stack(batch, 0)
        >>> def custom_collate(batch):
        ...     collate_map = {torch.Tensor: collate_tensor_fn}
        ...     return collate(batch, collate_fn_map=collate_map)
        >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
        >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})

    Note:
        Each collate function requires a positional argument for batch and a keyword argument
        for the dictionary of collate functions as `collate_fn_map`.
    """
    elem = batch[0]
    elem_type = type(elem)

    if collate_fn_map is not None:
        # An exact type match wins over the insertion-ordered isinstance scan.
        if elem_type in collate_fn_map:
            return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)

        for collate_type in collate_fn_map:
            if isinstance(elem, collate_type):
                return collate_fn_map[collate_type](
                    batch, collate_fn_map=collate_fn_map
                )

    if isinstance(elem, collections.abc.Mapping):
        # Collate once, then repackage; a `TypeError` raised while collating a
        # nested value propagates directly instead of triggering a second pass.
        collated_items = {
            key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map)
            for key in elem
        }
        return _rebuild_mapping(elem, collated_items)
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(
            *(
                collate(samples, collate_fn_map=collate_fn_map)
                for samples in zip(*batch)
            )
        )
    elif isinstance(elem, collections.abc.Sequence) and not isinstance(elem, str):
        # `str` is excluded on purpose: a one-character string is a sequence
        # containing itself, so transposing strings would recurse without end.
        # Strings with no registered collate function fall through to the
        # `TypeError` below.

        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(other) == elem_size for other in it):
            raise RuntimeError("each element in list of batch should be of equal size")

        collated = [
            collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)
        ]
        if isinstance(elem, tuple):
            return collated  # Backwards compatibility.
        return _rebuild_sequence(elem, collated)

    raise TypeError(default_collate_err_msg_format.format(elem_type))


def collate_tensor_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    r"""
    Stack a batch of tensors along a new leading dimension.

    Args:
        batch: a sequence of :class:`torch.Tensor`
        collate_fn_map: unused; accepted to satisfy the collate function signature.

    Raises:
        RuntimeError: if the first element is a nested tensor or has a sparse layout.
    """
    elem = batch[0]
    out = None
    if elem.is_nested:
        raise RuntimeError(
            "Batches of nested tensors are not currently supported by the default collate_fn; "
            "please provide a custom collate_fn to handle them appropriately."
        )
    if elem.layout in {
        torch.sparse_coo,
        torch.sparse_csr,
        torch.sparse_bsr,
        torch.sparse_csc,
        torch.sparse_bsc,
    }:
        raise RuntimeError(
            "Batches of sparse tensors are not currently supported by the default collate_fn; "
            "please provide a custom collate_fn to handle them appropriately."
        )
    if torch.utils.data.get_worker_info() is not None:
        # If we're in a background process, concatenate directly into a
        # shared memory tensor to avoid an extra copy
        numel = sum(x.numel() for x in batch)
        storage = elem._typed_storage()._new_shared(numel, device=elem.device)
        out = elem.new(storage).resize_(len(batch), *list(elem.size()))
    return torch.stack(batch, 0, out=out)


def collate_numpy_array_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    r"""
    Convert a batch of NumPy arrays to tensors and collate those.

    Raises:
        TypeError: if the dtype of the first array is a string or object dtype.
    """
    elem = batch[0]
    # array of string classes and object
    if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
        raise TypeError(default_collate_err_msg_format.format(elem.dtype))

    return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)


def collate_numpy_scalar_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    r"""Convert a batch of NumPy scalars into a single tensor via :func:`torch.as_tensor`."""
    return torch.as_tensor(batch)


def collate_float_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    r"""Convert a batch of Python floats into a ``torch.float64`` tensor."""
    return torch.tensor(batch, dtype=torch.float64)


def collate_int_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    r"""Convert a batch of Python ints into a tensor via :func:`torch.tensor`."""
    return torch.tensor(batch)


def collate_str_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    r"""Return a batch of `str` / `bytes` unchanged."""
    return batch


# Registry consulted by `default_collate`. Insertion order matters: when there
# is no exact type match, `collate` scans the keys in this order with `isinstance`.
default_collate_fn_map: Dict[Union[Type, Tuple[Type, ...]], Callable] = {
    torch.Tensor: collate_tensor_fn
}
with contextlib.suppress(ImportError):
    import numpy as np

    # For both ndarray and memmap (subclass of ndarray)
    default_collate_fn_map[np.ndarray] = collate_numpy_array_fn
    # See scalars hierarchy: https://numpy.org/doc/stable/reference/arrays.scalars.html
    # Skip string scalars
    default_collate_fn_map[(np.bool_, np.number, np.object_)] = collate_numpy_scalar_fn
default_collate_fn_map[float] = collate_float_fn
default_collate_fn_map[int] = collate_int_fn
default_collate_fn_map[str] = collate_str_fn
default_collate_fn_map[bytes] = collate_str_fn


def default_collate(batch):
    r"""
    Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size.

    The exact output type can be a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a
    Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type.
    This is used as the default function for collation when
    `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`.

    Here is the general input type (based on the type of the element within the batch) to output type mapping:

        * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
        * NumPy Arrays -> :class:`torch.Tensor`
        * `float` -> :class:`torch.Tensor`
        * `int` -> :class:`torch.Tensor`
        * `str` -> `str` (unchanged)
        * `bytes` -> `bytes` (unchanged)
        * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
        * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]),
          default_collate([V2_1, V2_2, ...]), ...]`
        * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]),
          default_collate([V2_1, V2_2, ...]), ...]`

    Args:
        batch: a single batch to be collated

    Examples:
        >>> # xdoctest: +SKIP
        >>> # Example with a batch of `int`s:
        >>> default_collate([0, 1, 2, 3])
        tensor([0, 1, 2, 3])
        >>> # Example with a batch of `str`s:
        >>> default_collate(['a', 'b', 'c'])
        ['a', 'b', 'c']
        >>> # Example with `Map` inside the batch:
        >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
        {'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
        >>> # Example with `NamedTuple` inside the batch:
        >>> Point = namedtuple('Point', ['x', 'y'])
        >>> default_collate([Point(0, 0), Point(1, 1)])
        Point(x=tensor([0, 1]), y=tensor([0, 1]))
        >>> # Example with `Tuple` inside the batch:
        >>> default_collate([(0, 1), (2, 3)])
        [tensor([0, 2]), tensor([1, 3])]
        >>> # Example with `List` inside the batch:
        >>> default_collate([[0, 1], [2, 3]])
        [tensor([0, 2]), tensor([1, 3])]
        >>> # Two options to extend `default_collate` to handle specific type
        >>> # Option 1: Write custom collate function and invoke `default_collate`
        >>> def custom_collate(batch):
        ...     elem = batch[0]
        ...     if isinstance(elem, CustomType):  # Some custom condition
        ...         return ...
        ...     else:  # Fall back to `default_collate`
        ...         return default_collate(batch)
        >>> # Option 2: In-place modify `default_collate_fn_map`
        >>> def collate_customtype_fn(batch, *, collate_fn_map=None):
        ...     return ...
        >>> default_collate_fn_map.update({CustomType: collate_customtype_fn})
        >>> default_collate(batch)  # Handle `CustomType` automatically
    """
    return collate(batch, collate_fn_map=default_collate_fn_map)