# mypy: allow-untyped-defs
r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.

These methods are used to collate samples fetched from a dataset into Tensor(s).
These **need** to be in global scope since Py2 doesn't support serializing
static methods.

`default_collate` and `default_convert` are exposed to users via 'dataloader.py'.
"""

import collections
import contextlib
import copy
import re
from typing import Callable, Dict, Optional, Tuple, Type, Union

import torch


np_str_obj_array_pattern = re.compile(r"[SaUO]")


def default_convert(data):
    r"""
    Convert each NumPy array element into a :class:`torch.Tensor`.

    If the input is a `Sequence`, `Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`.
    If the input is not a NumPy array, it is left unchanged.
    This is used as the default function for collation when both `batch_sampler` and `batch_size`
    are NOT defined in :class:`~torch.utils.data.DataLoader`.

    The general input type to output type mapping is similar to that
    of :func:`~torch.utils.data.default_collate`. See the description there for more details.

    Args:
        data: a single data point to be converted

    Examples:
        >>> # xdoctest: +SKIP
        >>> # Example with `int`
        >>> default_convert(0)
        0
        >>> # Example with NumPy array
        >>> default_convert(np.array([0, 1]))
        tensor([0, 1])
        >>> # Example with NamedTuple
        >>> Point = namedtuple('Point', ['x', 'y'])
        >>> default_convert(Point(0, 0))
        Point(x=0, y=0)
        >>> default_convert(Point(np.array(0), np.array(0)))
        Point(x=tensor(0), y=tensor(0))
        >>> # Example with List
        >>> default_convert([np.array([0, 1]), np.array([2, 3])])
        [tensor([0, 1]), tensor([2, 3])]
    """
    elem_type = type(data)
    if isinstance(data, torch.Tensor):
        return data
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        # array of string classes and object
        if (
            elem_type.__name__ == "ndarray"
            and np_str_obj_array_pattern.search(data.dtype.str) is not None
        ):
            return data
        return torch.as_tensor(data)
    elif isinstance(data, collections.abc.Mapping):
        try:
            if isinstance(data, collections.abc.MutableMapping):
                # The mapping type may have extra properties, so we can't just
                # use `type(data)(...)` to create the new mapping.
                # Create a clone and update it if the mapping type is mutable.
                clone = copy.copy(data)
                clone.update({key: default_convert(data[key]) for key in data})
                return clone
            else:
                return elem_type({key: default_convert(data[key]) for key in data})
        except TypeError:
            # The mapping type may not support `copy()` / `update(mapping)`
            # or `__init__(iterable)`.
            return {key: default_convert(data[key]) for key in data}
    elif isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
        return elem_type(*(default_convert(d) for d in data))
    elif isinstance(data, tuple):
        return [default_convert(d) for d in data]  # Backwards compatibility.
    elif isinstance(data, collections.abc.Sequence) and not isinstance(
        data, (str, bytes)
    ):
        try:
            if isinstance(data, collections.abc.MutableSequence):
                # The sequence type may have extra properties, so we can't just
                # use `type(data)(...)` to create the new sequence.
                # Create a clone and update it if the sequence type is mutable.
                clone = copy.copy(data)  # type: ignore[arg-type]
                for i, d in enumerate(data):
                    clone[i] = default_convert(d)
                return clone
            else:
                return elem_type([default_convert(d) for d in data])
        except TypeError:
            # The sequence type may not support `copy()` / `__setitem__(index, item)`
            # or `__init__(iterable)` (e.g., `range`).
            return [default_convert(d) for d in data]
    else:
        return data


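# Illustrative sketch, not part of the upstream module: `default_convert`
# clones mutable mappings with `copy.copy` before converting their values, so
# a dict subclass keeps both its type and any extra attributes it carries.
# `MetaDict` is a hypothetical type introduced only for this demonstration.
def _example_default_convert_preserves_mapping_type():
    import numpy as np

    class MetaDict(dict):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.source = "sensor"  # extra state beyond the dict entries

    converted = default_convert(MetaDict(a=np.array([0, 1])))
    assert isinstance(converted, MetaDict)  # subclass type is preserved
    assert converted.source == "sensor"  # extra attribute survives the clone
    assert isinstance(converted["a"], torch.Tensor)  # values are converted
    return converted

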
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}"
)


def collate(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    r"""
    General collate function that handles collection-type elements within each batch.

    The function also exposes a function registry for handling specific element types. `default_collate_fn_map`
    provides default collate functions for tensors, numpy arrays, numbers and strings.

    Args:
        batch: a single batch to be collated
        collate_fn_map: Optional dictionary mapping from element type to the corresponding collate function.
            If the element type isn't present in this dictionary,
            this function will go through each key of the dictionary in the insertion order to
            invoke the corresponding collate function if the element type is a subclass of the key.

    Examples:
        >>> def collate_tensor_fn(batch, *, collate_fn_map):
        ...     # Extend this function to handle batch of tensors
        ...     return torch.stack(batch, 0)
        >>> def custom_collate(batch):
        ...     collate_map = {torch.Tensor: collate_tensor_fn}
        ...     return collate(batch, collate_fn_map=collate_map)
        >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
        >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})

    Note:
        Each collate function requires a positional argument for the batch and a keyword argument
        for the dictionary of collate functions as `collate_fn_map`.
    """
    elem = batch[0]
    elem_type = type(elem)

    if collate_fn_map is not None:
        if elem_type in collate_fn_map:
            return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)

        for collate_type in collate_fn_map:
            if isinstance(elem, collate_type):
                return collate_fn_map[collate_type](
                    batch, collate_fn_map=collate_fn_map
                )

    if isinstance(elem, collections.abc.Mapping):
        try:
            if isinstance(elem, collections.abc.MutableMapping):
                # The mapping type may have extra properties, so we can't just
                # use `type(data)(...)` to create the new mapping.
                # Create a clone and update it if the mapping type is mutable.
                clone = copy.copy(elem)
                clone.update(
                    {
                        key: collate(
                            [d[key] for d in batch], collate_fn_map=collate_fn_map
                        )
                        for key in elem
                    }
                )
                return clone
            else:
                return elem_type(
                    {
                        key: collate(
                            [d[key] for d in batch], collate_fn_map=collate_fn_map
                        )
                        for key in elem
                    }
                )
        except TypeError:
            # The mapping type may not support `copy()` / `update(mapping)`
            # or `__init__(iterable)`.
            return {
                key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map)
                for key in elem
            }
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(
            *(
                collate(samples, collate_fn_map=collate_fn_map)
                for samples in zip(*batch)
            )
        )
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError("each element in list of batch should be of equal size")
        transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.

        if isinstance(elem, tuple):
            return [
                collate(samples, collate_fn_map=collate_fn_map)
                for samples in transposed
            ]  # Backwards compatibility.
        else:
            try:
                if isinstance(elem, collections.abc.MutableSequence):
                    # The sequence type may have extra properties, so we can't just
                    # use `type(data)(...)` to create the new sequence.
                    # Create a clone and update it if the sequence type is mutable.
                    clone = copy.copy(elem)  # type: ignore[arg-type]
                    for i, samples in enumerate(transposed):
                        clone[i] = collate(samples, collate_fn_map=collate_fn_map)
                    return clone
                else:
                    return elem_type(
                        [
                            collate(samples, collate_fn_map=collate_fn_map)
                            for samples in transposed
                        ]
                    )
            except TypeError:
                # The sequence type may not support `copy()` / `__setitem__(index, item)`
                # or `__init__(iterable)` (e.g., `range`).
                return [
                    collate(samples, collate_fn_map=collate_fn_map)
                    for samples in transposed
                ]

    raise TypeError(default_collate_err_msg_format.format(elem_type))


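# Illustrative sketch, not part of the upstream module: `collate` dispatches
# through `collate_fn_map` by exact type first, then by `isinstance` in key
# insertion order. `Box` and `collate_box_fn` are hypothetical names used only
# for this demonstration of registering a custom element type.
def _example_collate_with_custom_map():
    class Box:
        def __init__(self, value):
            self.value = value

    def collate_box_fn(batch, *, collate_fn_map=None):
        # Stack the wrapped scalars into one tensor, keeping the Box wrapper.
        return Box(torch.tensor([b.value for b in batch]))

    # Copy the default map instead of mutating it in place.
    collate_map = dict(default_collate_fn_map)
    collate_map[Box] = collate_box_fn
    out = collate([Box(0), Box(1), Box(2)], collate_fn_map=collate_map)
    assert isinstance(out, Box) and out.value.tolist() == [0, 1, 2]
    return out

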
def collate_tensor_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    elem = batch[0]
    out = None
    if elem.is_nested:
        raise RuntimeError(
            "Batches of nested tensors are not currently supported by the default collate_fn; "
            "please provide a custom collate_fn to handle them appropriately."
        )
    if elem.layout in {
        torch.sparse_coo,
        torch.sparse_csr,
        torch.sparse_bsr,
        torch.sparse_csc,
        torch.sparse_bsc,
    }:
        raise RuntimeError(
            "Batches of sparse tensors are not currently supported by the default collate_fn; "
            "please provide a custom collate_fn to handle them appropriately."
        )
    if torch.utils.data.get_worker_info() is not None:
        # If we're in a background process, concatenate directly into a
        # shared memory tensor to avoid an extra copy
        numel = sum(x.numel() for x in batch)
        storage = elem._typed_storage()._new_shared(numel, device=elem.device)
        out = elem.new(storage).resize_(len(batch), *list(elem.size()))
    return torch.stack(batch, 0, out=out)


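# Illustrative sketch, not part of the upstream module: in the main process
# `collate_tensor_fn` reduces to a plain `torch.stack`, which prepends a batch
# dimension; the shared-memory branch above only runs inside DataLoader worker
# processes, where `torch.utils.data.get_worker_info()` is non-None.
def _example_collate_tensor_fn_shape():
    batch = [torch.zeros(3, 4) for _ in range(8)]
    out = collate_tensor_fn(batch)
    assert out.shape == (8, 3, 4)  # batch dimension is prepended
    return out

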
def collate_numpy_array_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    elem = batch[0]
    # array of string classes and object
    if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
        raise TypeError(default_collate_err_msg_format.format(elem.dtype))

    return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)


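# Illustrative sketch, not part of the upstream module: numpy arrays are
# converted with `torch.as_tensor` and re-dispatched through `collate`, so the
# result matches collating the equivalent tensors; object/string dtypes are
# rejected up front with the shared error message.
def _example_collate_numpy_array_fn():
    import numpy as np

    out = collate_numpy_array_fn(
        [np.array([0, 1]), np.array([2, 3])],
        collate_fn_map=default_collate_fn_map,
    )
    assert out.tolist() == [[0, 1], [2, 3]]  # stacked into a (2, 2) tensor
    return out

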
def collate_numpy_scalar_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    return torch.as_tensor(batch)


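# Illustrative sketch, not part of the upstream module: numpy scalars are
# batched with a single `torch.as_tensor` call over the whole list.
def _example_collate_numpy_scalar_fn():
    import numpy as np

    out = collate_numpy_scalar_fn([np.int64(1), np.int64(2)])
    assert out.tolist() == [1, 2]
    return out

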
def collate_float_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    return torch.tensor(batch, dtype=torch.float64)


def collate_int_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    return torch.tensor(batch)


def collate_str_fn(
    batch,
    *,
    collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None,
):
    return batch


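# Illustrative sketch, not part of the upstream module: Python floats are
# promoted to a float64 tensor, ints to the default integer dtype, while
# strings (and bytes) pass through as the unchanged list.
def _example_scalar_and_str_collate():
    floats = collate_float_fn([0.5, 1.5])
    ints = collate_int_fn([0, 1, 2])
    strs = collate_str_fn(["a", "b"])
    assert floats.dtype == torch.float64
    assert ints.tolist() == [0, 1, 2]
    assert strs == ["a", "b"]  # returned untouched
    return floats, ints, strs

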
default_collate_fn_map: Dict[Union[Type, Tuple[Type, ...]], Callable] = {
    torch.Tensor: collate_tensor_fn
}
with contextlib.suppress(ImportError):
    import numpy as np

    # For both ndarray and memmap (subclass of ndarray)
    default_collate_fn_map[np.ndarray] = collate_numpy_array_fn
    # See scalars hierarchy: https://numpy.org/doc/stable/reference/arrays.scalars.html
    # Skip string scalars
    default_collate_fn_map[(np.bool_, np.number, np.object_)] = collate_numpy_scalar_fn
default_collate_fn_map[float] = collate_float_fn
default_collate_fn_map[int] = collate_int_fn
default_collate_fn_map[str] = collate_str_fn
default_collate_fn_map[bytes] = collate_str_fn


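# Illustrative sketch, not part of the upstream module: keys of
# `default_collate_fn_map` may be a single type or a tuple of types. `collate`
# first looks up the element's exact type, then walks the keys in insertion
# order with `isinstance`, which is how the tuple key
# `(np.bool_, np.number, np.object_)` catches every numpy scalar subclass.
def _example_fn_map_subclass_dispatch():
    import numpy as np

    # np.float32 is not itself a key, but it is an instance of np.number, so
    # the numpy-scalar collate function is selected via the tuple key.
    out = collate(
        [np.float32(0.5), np.float32(1.5)], collate_fn_map=default_collate_fn_map
    )
    assert isinstance(out, torch.Tensor) and out.shape == (2,)
    return out

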
def default_collate(batch):
    r"""
    Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension (batch size).

    The exact output type can be a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a
    Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type.
    This is used as the default function for collation when
    `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`.

    Here is the general input type (based on the type of the element within the batch) to output type mapping:

        * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
        * NumPy Arrays -> :class:`torch.Tensor`
        * `float` -> :class:`torch.Tensor`
        * `int` -> :class:`torch.Tensor`
        * `str` -> `str` (unchanged)
        * `bytes` -> `bytes` (unchanged)
        * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
        * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]),
          default_collate([V2_1, V2_2, ...]), ...]`
        * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]),
          default_collate([V2_1, V2_2, ...]), ...]`

    Args:
        batch: a single batch to be collated

    Examples:
        >>> # xdoctest: +SKIP
        >>> # Example with a batch of `int`s:
        >>> default_collate([0, 1, 2, 3])
        tensor([0, 1, 2, 3])
        >>> # Example with a batch of `str`s:
        >>> default_collate(['a', 'b', 'c'])
        ['a', 'b', 'c']
        >>> # Example with `Map` inside the batch:
        >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
        {'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
        >>> # Example with `NamedTuple` inside the batch:
        >>> Point = namedtuple('Point', ['x', 'y'])
        >>> default_collate([Point(0, 0), Point(1, 1)])
        Point(x=tensor([0, 1]), y=tensor([0, 1]))
        >>> # Example with `Tuple` inside the batch:
        >>> default_collate([(0, 1), (2, 3)])
        [tensor([0, 2]), tensor([1, 3])]
        >>> # Example with `List` inside the batch:
        >>> default_collate([[0, 1], [2, 3]])
        [tensor([0, 2]), tensor([1, 3])]
        >>> # Two options to extend `default_collate` to handle a specific type
        >>> # Option 1: Write custom collate function and invoke `default_collate`
        >>> def custom_collate(batch):
        ...     elem = batch[0]
        ...     if isinstance(elem, CustomType):  # Some custom condition
        ...         return ...
        ...     else:  # Fall back to `default_collate`
        ...         return default_collate(batch)
        >>> # Option 2: In-place modify `default_collate_fn_map`
        >>> def collate_customtype_fn(batch, *, collate_fn_map=None):
        ...     return ...
        >>> default_collate_fn_map.update({CustomType: collate_customtype_fn})
        >>> default_collate(batch)  # Handle `CustomType` automatically
    """
    return collate(batch, collate_fn_map=default_collate_fn_map)
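

# Illustrative sketch, not part of the upstream module: `default_collate` is
# exactly what `DataLoader` applies per batch when `batch_size` is set and no
# custom `collate_fn` is given, so the two calls below agree.
def _example_default_collate_matches_dataloader():
    from torch.utils.data import DataLoader

    dataset = [(torch.tensor([i, i + 1]), i) for i in range(4)]
    loader = DataLoader(dataset, batch_size=2)  # uses default_collate
    manual = default_collate(dataset[:2])
    first = next(iter(loader))
    assert torch.equal(first[0], manual[0]) and torch.equal(first[1], manual[1])
    return first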