# mypy: allow-untyped-defs
import bisect
import itertools
import math
import warnings
from typing import (
    cast,
    Dict,
    Generic,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
)
from typing_extensions import deprecated

# No 'default_generator' in torch/__init__.pyi
from torch import default_generator, Generator, randperm, Tensor


__all__ = [
    "Dataset",
    "IterableDataset",
    "TensorDataset",
    "StackDataset",
    "ConcatDataset",
    "ChainDataset",
    "Subset",
    "random_split",
]


_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
_T_dict = Dict[str, _T_co]
_T_tuple = Tuple[_T_co, ...]
_T_stack = TypeVar("_T_stack", _T_tuple, _T_dict)


class Dataset(Generic[_T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`. Subclasses could also
    optionally implement :meth:`__getitems__`, to speed up batched sample
    loading. This method accepts a list of indices of the samples in a batch
    and returns the list of samples.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> _T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    # def __getitems__(self, indices: List) -> List[_T_co]:
    # Not implemented to prevent false-positives in fetcher check in
    # torch.utils.data._utils.fetch._MapDatasetFetcher

    def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]":
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py


class IterableDataset(Dataset[_T_co], Iterable[_T_co]):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    This form of dataset is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> # xdoctest: +SKIP("Fails on MacOS12")
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [tensor([3]), tensor([4]), tensor([5]), tensor([6])]

        >>> # xdoctest: +REQUIRES(POSIX)
        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

        >>> # With even more workers
        >>> # xdoctest: +IGNORE_WANT("non deterministic")
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12)))
        [tensor([3]), tensor([5]), tensor([4]), tensor([6])]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_DATALOADER)
        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 3, 4, 4, 5, 5, 6, 6]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...
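        >>> # ``worker_init_fn`` runs once in every worker process, after the
        >>> # worker's dataset copy is created and before iteration starts, so
        >>> # each copy below is trimmed to its own shard of ``[start, end)``.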

        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn)))
        [3, 4, 5, 6]
    """

    def __add__(self, other: Dataset[_T_co]):
        return ChainDataset([self, other])

    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]


class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size in the first dimension.
    """

    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(
            tensors[0].size(0) == tensor.size(0) for tensor in tensors
        ), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)


class StackDataset(Dataset[_T_stack]):
    r"""Dataset as a stacking of multiple datasets.

    This class is useful to assemble different parts of complex input data, given as datasets.

    Example:
        >>> # xdoctest: +SKIP
        >>> images = ImageDataset()
        >>> texts = TextDataset()
        >>> tuple_stack = StackDataset(images, texts)
        >>> tuple_stack[0] == (images[0], texts[0])
        >>> dict_stack = StackDataset(image=images, text=texts)
        >>> dict_stack[0] == {'image': images[0], 'text': texts[0]}

    Args:
        *args (Dataset): Datasets for stacking returned as tuple.
        **kwargs (Dataset): Datasets for stacking returned as dict.
    """

    datasets: Union[tuple, dict]

    def __init__(self, *args: Dataset[_T_co], **kwargs: Dataset[_T_co]) -> None:
        if args:
            if kwargs:
                raise ValueError(
                    "Supported either ``tuple``- (via ``args``) or "
                    "``dict``- (via ``kwargs``) like input/output, but both types are given."
                )
            self._length = len(args[0])  # type: ignore[arg-type]
            if any(self._length != len(dataset) for dataset in args):  # type: ignore[arg-type]
                raise ValueError("Size mismatch between datasets")
            self.datasets = args
        elif kwargs:
            tmp = list(kwargs.values())
            self._length = len(tmp[0])  # type: ignore[arg-type]
            if any(self._length != len(dataset) for dataset in tmp):  # type: ignore[arg-type]
                raise ValueError("Size mismatch between datasets")
            self.datasets = kwargs
        else:
            raise ValueError("At least one dataset should be passed")

    def __getitem__(self, index):
        if isinstance(self.datasets, dict):
            return {k: dataset[index] for k, dataset in self.datasets.items()}
        return tuple(dataset[index] for dataset in self.datasets)

    def __getitems__(self, indices: list):
        # add batched sampling support when the parent datasets support it.
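        # The DataLoader's fetcher (torch.utils.data._utils.fetch._MapDatasetFetcher)
        # probes for a ``__getitems__`` attribute.  For each wrapped dataset we
        # either forward the whole index list in a single ``__getitems__`` call
        # or fall back to per-index ``__getitem__``, then regroup the results
        # sample by sample into dicts (keyword datasets) or tuples (positional).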
        if isinstance(self.datasets, dict):
            dict_batch: List[_T_dict] = [{} for _ in indices]
            for k, dataset in self.datasets.items():
                if callable(getattr(dataset, "__getitems__", None)):
                    items = dataset.__getitems__(indices)  # type: ignore[attr-defined]
                    if len(items) != len(indices):
                        raise ValueError(
                            "Nested dataset's output size mismatch."
                            f" Expected {len(indices)}, got {len(items)}"
                        )
                    for data, d_sample in zip(items, dict_batch):
                        d_sample[k] = data
                else:
                    for idx, d_sample in zip(indices, dict_batch):
                        d_sample[k] = dataset[idx]
            return dict_batch

        # tuple data
        list_batch: List[list] = [[] for _ in indices]
        for dataset in self.datasets:
            if callable(getattr(dataset, "__getitems__", None)):
                items = dataset.__getitems__(indices)  # type: ignore[attr-defined]
                if len(items) != len(indices):
                    raise ValueError(
                        "Nested dataset's output size mismatch."
                        f" Expected {len(indices)}, got {len(items)}"
                    )
                for data, t_sample in zip(items, list_batch):
                    t_sample.append(data)
            else:
                for idx, t_sample in zip(indices, list_batch):
                    t_sample.append(dataset[idx])
        tuple_batch: List[_T_tuple] = [tuple(sample) for sample in list_batch]
        return tuple_batch

    def __len__(self):
        return self._length


class ConcatDataset(Dataset[_T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """

    datasets: List[Dataset[_T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, "datasets should not be an empty iterable"  # type: ignore[arg-type]
        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    @deprecated(
        "`cummulative_sizes` attribute is renamed to `cumulative_sizes`",
        category=FutureWarning,
    )
    def cummulative_sizes(self):
        return self.cumulative_sizes


class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.
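
    Example:
        >>> # xdoctest: +SKIP
        >>> # Illustrative sketch reusing the ``MyIterableDataset`` class defined
        >>> # in the :class:`IterableDataset` example above.
        >>> ds1 = MyIterableDataset(start=0, end=3)
        >>> ds2 = MyIterableDataset(start=3, end=6)
        >>> list(ChainDataset([ds1, ds2]))
        [0, 1, 2, 3, 4, 5]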

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super().__init__()
        self.datasets = datasets

    def __iter__(self):
        for d in self.datasets:
            assert isinstance(
                d, IterableDataset
            ), "ChainDataset only supports IterableDataset"
            yield from d

    def __len__(self):
        total = 0
        for d in self.datasets:
            assert isinstance(
                d, IterableDataset
            ), "ChainDataset only supports IterableDataset"
            total += len(d)  # type: ignore[arg-type]
        return total


class Subset(Dataset[_T_co]):
    r"""
    Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """

    dataset: Dataset[_T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[_T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        if isinstance(idx, list):
            return self.dataset[[self.indices[i] for i in idx]]
        return self.dataset[self.indices[idx]]

    def __getitems__(self, indices: List[int]) -> List[_T_co]:
        # add batched sampling support when parent dataset supports it.
        # see torch.utils.data._utils.fetch._MapDatasetFetcher
        if callable(getattr(self.dataset, "__getitems__", None)):
            return self.dataset.__getitems__([self.indices[idx] for idx in indices])  # type: ignore[attr-defined]
        else:
            return [self.dataset[self.indices[idx]] for idx in indices]

    def __len__(self):
        return len(self.indices)


def random_split(
    dataset: Dataset[_T],
    lengths: Sequence[Union[int, float]],
    generator: Optional[Generator] = default_generator,
) -> List[Subset[_T]]:
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    If a list of fractions that sum up to 1 is given,
    the lengths will be computed automatically as
    floor(frac * len(dataset)) for each fraction provided.

    After computing the floor lengths, any remaining items are distributed one
    at a time, in round-robin fashion, across the splits until none are left.
    For example, splitting a dataset of length 10 with fractions
    ``[0.35, 0.35, 0.3]`` first yields floor lengths ``[3, 3, 3]``; the one
    remaining item then goes to the first split, giving ``[4, 3, 3]``.

    Optionally fix the generator for reproducible results, e.g.:

    Example:
        >>> # xdoctest: +SKIP
        >>> generator1 = torch.Generator().manual_seed(42)
        >>> generator2 = torch.Generator().manual_seed(42)
        >>> random_split(range(10), [3, 7], generator=generator1)
        >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2)

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths or fractions of splits to be produced
        generator (Generator): Generator used for the random permutation.
455 """ 456 if math.isclose(sum(lengths), 1) and sum(lengths) <= 1: 457 subset_lengths: List[int] = [] 458 for i, frac in enumerate(lengths): 459 if frac < 0 or frac > 1: 460 raise ValueError(f"Fraction at index {i} is not between 0 and 1") 461 n_items_in_split = int( 462 math.floor(len(dataset) * frac) # type: ignore[arg-type] 463 ) 464 subset_lengths.append(n_items_in_split) 465 remainder = len(dataset) - sum(subset_lengths) # type: ignore[arg-type] 466 # add 1 to all the lengths in round-robin fashion until the remainder is 0 467 for i in range(remainder): 468 idx_to_add_at = i % len(subset_lengths) 469 subset_lengths[idx_to_add_at] += 1 470 lengths = subset_lengths 471 for i, length in enumerate(lengths): 472 if length == 0: 473 warnings.warn( 474 f"Length of split at index {i} is 0. " 475 f"This might result in an empty dataset." 476 ) 477 478 # Cannot verify that dataset is Sized 479 if sum(lengths) != len(dataset): # type: ignore[arg-type] 480 raise ValueError( 481 "Sum of input lengths does not equal the length of the input dataset!" 482 ) 483 484 indices = randperm(sum(lengths), generator=generator).tolist() # type: ignore[arg-type, call-overload] 485 lengths = cast(Sequence[int], lengths) 486 return [ 487 Subset(dataset, indices[offset - length : offset]) 488 for offset, length in zip(itertools.accumulate(lengths), lengths) 489 ] 490