1# mypy: allow-untyped-defs 2from typing import ( 3 Generic, 4 Iterable, 5 Iterator, 6 List, 7 Optional, 8 Sequence, 9 Sized, 10 TypeVar, 11 Union, 12) 13 14import torch 15 16 17__all__ = [ 18 "BatchSampler", 19 "RandomSampler", 20 "Sampler", 21 "SequentialSampler", 22 "SubsetRandomSampler", 23 "WeightedRandomSampler", 24] 25 26 27_T_co = TypeVar("_T_co", covariant=True) 28 29 30class Sampler(Generic[_T_co]): 31 r"""Base class for all Samplers. 32 33 Every Sampler subclass has to provide an :meth:`__iter__` method, providing a 34 way to iterate over indices or lists of indices (batches) of dataset elements, 35 and may provide a :meth:`__len__` method that returns the length of the returned iterators. 36 37 Args: 38 data_source (Dataset): This argument is not used and will be removed in 2.2.0. 39 You may still have custom implementation that utilizes it. 40 41 Example: 42 >>> # xdoctest: +SKIP 43 >>> class AccedingSequenceLengthSampler(Sampler[int]): 44 >>> def __init__(self, data: List[str]) -> None: 45 >>> self.data = data 46 >>> 47 >>> def __len__(self) -> int: 48 >>> return len(self.data) 49 >>> 50 >>> def __iter__(self) -> Iterator[int]: 51 >>> sizes = torch.tensor([len(x) for x in self.data]) 52 >>> yield from torch.argsort(sizes).tolist() 53 >>> 54 >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]): 55 >>> def __init__(self, data: List[str], batch_size: int) -> None: 56 >>> self.data = data 57 >>> self.batch_size = batch_size 58 >>> 59 >>> def __len__(self) -> int: 60 >>> return (len(self.data) + self.batch_size - 1) // self.batch_size 61 >>> 62 >>> def __iter__(self) -> Iterator[List[int]]: 63 >>> sizes = torch.tensor([len(x) for x in self.data]) 64 >>> for batch in torch.chunk(torch.argsort(sizes), len(self)): 65 >>> yield batch.tolist() 66 67 .. note:: The :meth:`__len__` method isn't strictly required by 68 :class:`~torch.utils.data.DataLoader`, but is expected in any 69 calculation involving the length of a :class:`~torch.utils.data.DataLoader`. 70 """ 71 72 def __init__(self, data_source: Optional[Sized] = None) -> None: 73 if data_source is not None: 74 import warnings 75 76 warnings.warn( 77 "`data_source` argument is not used and will be removed in 2.2.0." 78 "You may still have custom implementation that utilizes it." 79 ) 80 81 def __iter__(self) -> Iterator[_T_co]: 82 raise NotImplementedError 83 84 # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] 85 # 86 # Many times we have an abstract class representing a collection/iterable of 87 # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally 88 # implementing a `__len__` method. In such cases, we must make sure to not 89 # provide a default implementation, because both straightforward default 90 # implementations have their issues: 91 # 92 # + `return NotImplemented`: 93 # Calling `len(subclass_instance)` raises: 94 # TypeError: 'NotImplementedType' object cannot be interpreted as an integer 95 # 96 # + `raise NotImplementedError`: 97 # This prevents triggering some fallback behavior. E.g., the built-in 98 # `list(X)` tries to call `len(X)` first, and executes a different code 99 # path if the method is not found or `NotImplemented` is returned, while 100 # raising a `NotImplementedError` will propagate and make the call fail 101 # where it could have used `__iter__` to complete the call. 102 # 103 # Thus, the only two sensible things to do are 104 # 105 # + **not** provide a default `__len__`. 106 # 107 # + raise a `TypeError` instead, which is what Python uses when users call 108 # a method that is not defined on an object. 109 # (@ssnl verifies that this works on at least Python 3.7.) 110 111 112class SequentialSampler(Sampler[int]): 113 r"""Samples elements sequentially, always in the same order. 114 115 Args: 116 data_source (Dataset): dataset to sample from 117 """ 118 119 data_source: Sized 120 121 def __init__(self, data_source: Sized) -> None: 122 self.data_source = data_source 123 124 def __iter__(self) -> Iterator[int]: 125 return iter(range(len(self.data_source))) 126 127 def __len__(self) -> int: 128 return len(self.data_source) 129 130 131class RandomSampler(Sampler[int]): 132 r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. 133 134 If with replacement, then user can specify :attr:`num_samples` to draw. 135 136 Args: 137 data_source (Dataset): dataset to sample from 138 replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False`` 139 num_samples (int): number of samples to draw, default=`len(dataset)`. 140 generator (Generator): Generator used in sampling. 141 """ 142 143 data_source: Sized 144 replacement: bool 145 146 def __init__( 147 self, 148 data_source: Sized, 149 replacement: bool = False, 150 num_samples: Optional[int] = None, 151 generator=None, 152 ) -> None: 153 self.data_source = data_source 154 self.replacement = replacement 155 self._num_samples = num_samples 156 self.generator = generator 157 158 if not isinstance(self.replacement, bool): 159 raise TypeError( 160 f"replacement should be a boolean value, but got replacement={self.replacement}" 161 ) 162 163 if not isinstance(self.num_samples, int) or self.num_samples <= 0: 164 raise ValueError( 165 f"num_samples should be a positive integer value, but got num_samples={self.num_samples}" 166 ) 167 168 @property 169 def num_samples(self) -> int: 170 # dataset size might change at runtime 171 if self._num_samples is None: 172 return len(self.data_source) 173 return self._num_samples 174 175 def __iter__(self) -> Iterator[int]: 176 n = len(self.data_source) 177 if self.generator is None: 178 seed = int(torch.empty((), dtype=torch.int64).random_().item()) 179 generator = torch.Generator() 180 generator.manual_seed(seed) 181 else: 182 generator = self.generator 183 184 if self.replacement: 185 for _ in range(self.num_samples // 32): 186 yield from torch.randint( 187 high=n, size=(32,), dtype=torch.int64, generator=generator 188 ).tolist() 189 yield from torch.randint( 190 high=n, 191 size=(self.num_samples % 32,), 192 dtype=torch.int64, 193 generator=generator, 194 ).tolist() 195 else: 196 for _ in range(self.num_samples // n): 197 yield from torch.randperm(n, generator=generator).tolist() 198 yield from torch.randperm(n, generator=generator).tolist()[ 199 : self.num_samples % n 200 ] 201 202 def __len__(self) -> int: 203 return self.num_samples 204 205 206class SubsetRandomSampler(Sampler[int]): 207 r"""Samples elements randomly from a given list of indices, without replacement. 208 209 Args: 210 indices (sequence): a sequence of indices 211 generator (Generator): Generator used in sampling. 212 """ 213 214 indices: Sequence[int] 215 216 def __init__(self, indices: Sequence[int], generator=None) -> None: 217 self.indices = indices 218 self.generator = generator 219 220 def __iter__(self) -> Iterator[int]: 221 for i in torch.randperm(len(self.indices), generator=self.generator): 222 yield self.indices[i] 223 224 def __len__(self) -> int: 225 return len(self.indices) 226 227 228class WeightedRandomSampler(Sampler[int]): 229 r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights). 230 231 Args: 232 weights (sequence) : a sequence of weights, not necessary summing up to one 233 num_samples (int): number of samples to draw 234 replacement (bool): if ``True``, samples are drawn with replacement. 235 If not, they are drawn without replacement, which means that when a 236 sample index is drawn for a row, it cannot be drawn again for that row. 237 generator (Generator): Generator used in sampling. 238 239 Example: 240 >>> # xdoctest: +IGNORE_WANT("non-deterministic") 241 >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)) 242 [4, 4, 1, 4, 5] 243 >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False)) 244 [0, 1, 4, 3, 2] 245 """ 246 247 weights: torch.Tensor 248 num_samples: int 249 replacement: bool 250 251 def __init__( 252 self, 253 weights: Sequence[float], 254 num_samples: int, 255 replacement: bool = True, 256 generator=None, 257 ) -> None: 258 if ( 259 not isinstance(num_samples, int) 260 or isinstance(num_samples, bool) 261 or num_samples <= 0 262 ): 263 raise ValueError( 264 f"num_samples should be a positive integer value, but got num_samples={num_samples}" 265 ) 266 if not isinstance(replacement, bool): 267 raise ValueError( 268 f"replacement should be a boolean value, but got replacement={replacement}" 269 ) 270 271 weights_tensor = torch.as_tensor(weights, dtype=torch.double) 272 if len(weights_tensor.shape) != 1: 273 raise ValueError( 274 "weights should be a 1d sequence but given " 275 f"weights have shape {tuple(weights_tensor.shape)}" 276 ) 277 278 self.weights = weights_tensor 279 self.num_samples = num_samples 280 self.replacement = replacement 281 self.generator = generator 282 283 def __iter__(self) -> Iterator[int]: 284 rand_tensor = torch.multinomial( 285 self.weights, self.num_samples, self.replacement, generator=self.generator 286 ) 287 yield from iter(rand_tensor.tolist()) 288 289 def __len__(self) -> int: 290 return self.num_samples 291 292 293class BatchSampler(Sampler[List[int]]): 294 r"""Wraps another sampler to yield a mini-batch of indices. 295 296 Args: 297 sampler (Sampler or Iterable): Base sampler. Can be any iterable object 298 batch_size (int): Size of mini-batch. 299 drop_last (bool): If ``True``, the sampler will drop the last batch if 300 its size would be less than ``batch_size`` 301 302 Example: 303 >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) 304 [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 305 >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) 306 [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 307 """ 308 309 def __init__( 310 self, 311 sampler: Union[Sampler[int], Iterable[int]], 312 batch_size: int, 313 drop_last: bool, 314 ) -> None: 315 # Since collections.abc.Iterable does not check for `__getitem__`, which 316 # is one way for an object to be an iterable, we don't do an `isinstance` 317 # check here. 318 if ( 319 not isinstance(batch_size, int) 320 or isinstance(batch_size, bool) 321 or batch_size <= 0 322 ): 323 raise ValueError( 324 f"batch_size should be a positive integer value, but got batch_size={batch_size}" 325 ) 326 if not isinstance(drop_last, bool): 327 raise ValueError( 328 f"drop_last should be a boolean value, but got drop_last={drop_last}" 329 ) 330 self.sampler = sampler 331 self.batch_size = batch_size 332 self.drop_last = drop_last 333 334 def __iter__(self) -> Iterator[List[int]]: 335 # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951 336 if self.drop_last: 337 sampler_iter = iter(self.sampler) 338 while True: 339 try: 340 batch = [next(sampler_iter) for _ in range(self.batch_size)] 341 yield batch 342 except StopIteration: 343 break 344 else: 345 batch = [0] * self.batch_size 346 idx_in_batch = 0 347 for idx in self.sampler: 348 batch[idx_in_batch] = idx 349 idx_in_batch += 1 350 if idx_in_batch == self.batch_size: 351 yield batch 352 idx_in_batch = 0 353 batch = [0] * self.batch_size 354 if idx_in_batch > 0: 355 yield batch[:idx_in_batch] 356 357 def __len__(self) -> int: 358 # Can only be called if self.sampler has __len__ implemented 359 # We cannot enforce this condition, so we turn off typechecking for the 360 # implementation below. 361 # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] 362 if self.drop_last: 363 return len(self.sampler) // self.batch_size # type: ignore[arg-type] 364 else: 365 return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore[arg-type] 366