# mypy: allow-untyped-defs
import collections
import dataclasses
import io
import operator
import os
import pickle
import queue
import threading
import uuid
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    IO,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
)

import torch
from torch import Tensor
from torch._utils import _get_available_device_type, _get_device_module
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint.metadata import (
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    StorageMeta,
)
from torch.distributed.checkpoint.planner import (
    LoadItemType,
    LoadPlan,
    LoadPlanner,
    ReadItem,
    SavePlan,
    SavePlanner,
    WriteItem,
    WriteItemType,
)
from torch.distributed.checkpoint.staging import BlockingAsyncStager
from torch.distributed.checkpoint.storage import (
    StorageReader,
    StorageWriter,
    WriteResult,
)
from torch.distributed.checkpoint.utils import _create_file_view
from torch.futures import Future


__all__ = ["FileSystemWriter", "FileSystemReader", "FileSystem", "FileSystemBase"]

_metadata_fn: str = ".metadata"


@dataclass
class _StorageInfo:
    """This is the per entry storage info."""

    relative_path: str
    offset: int
    length: int


@dataclass
class _StoragePrefix:
    prefix: str


DEFAULT_SUFFIX = ".distcp"


def _generate_uuid() -> str:
    return str(uuid.uuid4())


class _TensorLoader(ABC):
    @abstractmethod
    def add(self, size: int, obj: object) -> None:
        pass

    @abstractmethod
    def start_loading(self) -> None:
        pass

    @abstractmethod
    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
        pass


class _SerialCpuLoader(_TensorLoader):
    def __init__(self, resolve_fun: Callable) -> None:
        self.resolve_fun = resolve_fun
        self.items: List[Tuple[int, object]] = []

    def add(self, size: int, obj: object) -> None:
        self.items.append((size, obj))

    def start_loading(self) -> None:
        pass

    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
        for _, obj in self.items:
            tensor = self.resolve_fun(obj).detach()
            tensor = tensor.cpu()
            if tensor.storage().size() != tensor.numel():
                tensor = tensor.clone()
            yield (
                tensor,
                obj,
            )


class _OverlappingCpuLoader(_TensorLoader):
    def __init__(
        self,
        resolve_fun: Callable,
        stream: Optional[torch.Stream] = None,
        inflight_threshhold: int = 1_000_000,
    ) -> None:
        self.resolve_fun = resolve_fun
        self.items: List[Tuple[int, object]] = []
        self.inflight_threshhold = inflight_threshhold
        self.in_flight_data = 0
        self.current_items: collections.deque = collections.deque()
        self.idx = 0
        self.started = False
        self.device_type = (
            stream.device_type if stream else _get_available_device_type()
        )
        self.device_module = _get_device_module(self.device_type)
        self.stream = cast(
            torch.cuda.Stream, stream or self.device_module.current_stream()
        )
        if self.stream != self.device_module.current_stream():
            self.stream.wait_stream(self.device_module.current_stream())

    @property
    def _done(self) -> bool:
        return self.idx >= len(self.items)

    def _drain(self) -> List[Tuple[torch.Tensor, object]]:
        drained = []
        if self.in_flight_data >= self.inflight_threshhold:
            self.stream.synchronize()
        while self.in_flight_data >= self.inflight_threshhold:
            val = self.current_items.popleft()
            self.in_flight_data -= val[0].numel() * val[0].element_size()
            drained.append(val)
        return drained

    def _refill(self) -> None:
        with self.device_module.stream(self.stream):
            while not self._done and self.in_flight_data < self.inflight_threshhold:
                _, obj = self.items[self.idx]
                self.idx += 1
                tensor = self.resolve_fun(obj).detach()
                if tensor.device.type == self.device_type:
                    tensor = tensor.to(device="cpu", non_blocking=True)
                elif tensor.device == torch.device("cpu"):
                    if (
                        tensor.untyped_storage().size()
                        != tensor.numel() * tensor.itemsize
                    ):
                        # This forces the tensor to be both contiguous and backed by minimal storage.
                        tensor = tensor.clone()

                self.current_items.append(
                    (
                        tensor,
                        obj,
                    )
                )
                self.in_flight_data += tensor.numel() * tensor.element_size()

    def _finish(self) -> Iterable[Tuple[torch.Tensor, object]]:
        assert self._done
        if len(self.current_items) > 0:
            self.stream.synchronize()
        return self.current_items

    def add(self, size: int, obj: object) -> None:
        if self.started:
            raise RuntimeError("cannot add items after loading started")
        self.items.append((size, obj))

    def start_loading(self) -> None:
        if self.started:
            return
        self.started = True
        self.items.sort(key=operator.itemgetter(0))
        self._refill()

    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
        self.start_loading()
        while not self._done:
            drained = self._drain()
            self._refill()
            yield from drained

        yield from self._finish()


def _item_size(item: WriteItem) -> int:
    size = 1
    assert item.tensor_data is not None
    # can't use math.prod as PT needs to support older Python versions
    for s in item.tensor_data.size:
        size *= s

    dtype = item.tensor_data.properties.dtype
    return size * torch._utils._element_size(dtype)


def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
    if bins == 1:
        return [items]

    bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
    tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]

    buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
    bucket_sizes = [0 for _ in range(bins)]

    tensor_w.sort(key=_item_size, reverse=True)

    for i, wi in enumerate(bytes_w):
        buckets[i % bins].append(wi)

    for wi in tensor_w:
        # TODO replace with heapq
        idx = min(enumerate(bucket_sizes), key=operator.itemgetter(1))[0]
        buckets[idx].append(wi)
        bucket_sizes[idx] += _item_size(wi)

    return buckets


def _write_item(
    stream: io.IOBase,
    data: Union[io.BytesIO, torch.Tensor],
    write_item: WriteItem,
    storage_key: str,
) -> WriteResult:
    offset = stream.tell()

    if write_item.type == WriteItemType.BYTE_IO:
        assert isinstance(data, io.BytesIO)
        stream.write(data.getbuffer())
    else:
        assert isinstance(data, torch.Tensor)
        assert data.device == torch.device("cpu")
        torch.save(data, cast(IO[bytes], stream))
    length = stream.tell() - offset

    return WriteResult(
        index=write_item.index,
        size_in_bytes=length,
        storage_data=_StorageInfo(storage_key, offset, length),
    )
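

# Illustrative sketch, not part of this module: _greedy_buckets is a hypothetical
# helper that mirrors the strategy used by _split_by_size_and_type on plain
# integers. Sizes are sorted in descending order and each item is placed into the
# bucket with the smallest running total. It is never called by this file.
def _greedy_buckets(bins: int, sizes: List[int]) -> List[List[int]]:
    buckets: List[List[int]] = [[] for _ in range(bins)]
    totals = [0] * bins
    for size in sorted(sizes, reverse=True):
        # pick the currently smallest bucket, as _split_by_size_and_type does
        idx = min(range(bins), key=lambda i: totals[i])
        buckets[idx].append(size)
        totals[idx] += size
    return buckets


# Example: _greedy_buckets(2, [5, 4, 3, 2, 1]) -> [[5, 2, 1], [4, 3]] (totals 8 and 7).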


def _write_files_from_queue(
    create_stream: Callable,
    file_queue: queue.Queue,
    result_queue: queue.Queue,
    planner: SavePlanner,
    inflight_threshhold: int,
    use_fsync: bool,
    thread_count: int,
) -> None:
    try:
        while True:
            file_name, storage_key, write_items = file_queue.get_nowait()
            loader: _TensorLoader

            custom_backend_name = torch._C._get_privateuse1_backend_name()
            custom_device_mod = getattr(torch, custom_backend_name, None)

            # TODO: Using the OverlappingCpuLoader with multiple threads creates significant
            # performance degradation, observed as being related to cuda stream syncs. We
            # should try to fix this and use _OverlappingCpuLoader for all threaded cases
            if (
                thread_count == 1
                and (
                    torch.cuda.is_available()
                    or (custom_device_mod and custom_device_mod.is_available())
                )
                and inflight_threshhold > 0
            ):
                loader = _OverlappingCpuLoader(
                    planner.resolve_data,
                    inflight_threshhold=inflight_threshhold,
                )
            else:
                loader = _SerialCpuLoader(
                    planner.resolve_data,
                )

            tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
            for write_item in tensor_w:
                loader.add(_item_size(write_item), write_item)
            loader.start_loading()

            bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
            write_results = []

            with create_stream(file_name, "wb") as stream:
                for write_item in bytes_w:
                    data = planner.resolve_data(write_item)
                    write_results.append(
                        _write_item(stream, data, write_item, storage_key)
                    )

                for tensor, write_item in loader.values():
                    assert tensor.is_cpu
                    write_results.append(
                        _write_item(stream, tensor, write_item, storage_key)
                    )

                if use_fsync:
                    try:
                        os.fsync(stream.fileno())
                    except AttributeError:
                        os.sync()
            result_queue.put(write_results)
    except queue.Empty:
        pass


class FileSystemBase(ABC):
    @contextmanager
    @abstractmethod
    def create_stream(
        self, path: Union[str, os.PathLike], mode: str
    ) -> Generator[io.IOBase, None, None]:
        ...

    @abstractmethod
    def concat_path(
        self, path: Union[str, os.PathLike], suffix: str
    ) -> Union[str, os.PathLike]:
        ...

    @abstractmethod
    def rename(
        self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
    ) -> None:
        ...

    @abstractmethod
    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
        ...

    @abstractmethod
    def mkdir(self, path: Union[str, os.PathLike]) -> None:
        ...

    @classmethod
    @abstractmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        ...

    @abstractmethod
    def exists(self, path: Union[str, os.PathLike]) -> bool:
        ...

    @abstractmethod
    def rm_file(self, path: Union[str, os.PathLike]) -> None:
        ...
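

# Hypothetical helper, shown only to illustrate the call sequence the writer drives
# against the FileSystemBase abstraction (init_path -> mkdir -> concat_path ->
# create_stream). The file name and payload are made up; this function is not
# called anywhere in this module.
def _example_backend_usage(fs: FileSystemBase, root: Union[str, os.PathLike]) -> None:
    path = fs.init_path(root)
    fs.mkdir(path)
    file_path = fs.concat_path(path, f"__0_0{DEFAULT_SUFFIX}")
    with fs.create_stream(file_path, "wb") as stream:
        # a real save writes serialized tensors/byte blobs here via _write_item
        stream.write(b"payload")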


class FileSystem(FileSystemBase):
    @contextmanager
    def create_stream(
        self, path: Union[str, os.PathLike], mode: str
    ) -> Generator[io.IOBase, None, None]:
        with cast(Path, path).open(mode) as stream:
            yield cast(io.IOBase, stream)

    def concat_path(
        self, path: Union[str, os.PathLike], suffix: str
    ) -> Union[str, os.PathLike]:
        return cast(Path, path) / suffix

    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
        if not isinstance(path, Path):
            path = Path(path)
        return path

    def rename(
        self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
    ) -> None:
        cast(Path, path).rename(cast(Path, new_path))

    def mkdir(self, path: Union[str, os.PathLike]) -> None:
        cast(Path, path).mkdir(parents=True, exist_ok=True)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        if isinstance(checkpoint_id, Path):
            return True

        if "://" in str(checkpoint_id):
            return False

        for p in Path(checkpoint_id).parents:
            if p.exists() and os.access(str(p), os.W_OK):
                return True

        return False

    def exists(self, path: Union[str, os.PathLike]) -> bool:
        return cast(Path, path).exists()

    def rm_file(self, path: Union[str, os.PathLike]) -> None:
        cast(Path, path).unlink()


class _FileSystemWriter(StorageWriter):
    """
    Basic implementation of StorageWriter using file IO.

    This implementation makes the following assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic.

    The checkpoint consists of one file per write request plus
    a `.metadata` file with the serialized metadata.
    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        overwrite: bool = True,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the writer pointing to `path`.

        Args:
            path: directory where the checkpoint will be written to.
            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to True.
            sync_files: force files to be synced to permanent storage. Defaults to True.
            thread_count: Number of IO threads to use to write. Defaults to 1.
            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Defaults to 10MB.
            overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.

        N.B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
467 """ 468 super().__init__() 469 self.fs = FileSystem() 470 self.path = self.fs.init_path(path) 471 self.single_file_per_rank = single_file_per_rank 472 self.sync_files = sync_files 473 self.thread_count = thread_count 474 self.per_thread_copy_ahead = per_thread_copy_ahead 475 self.save_id = _generate_uuid() 476 self.overwrite = overwrite 477 478 def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: 479 if checkpoint_id: 480 self.path = self.fs.init_path(checkpoint_id) 481 self.save_id = _generate_uuid() 482 483 def set_up_storage_writer(self, is_coordinator: bool) -> None: 484 pass 485 486 def prepare_local_plan(self, plan: SavePlan) -> SavePlan: 487 self.fs.mkdir(self.path) 488 if self.fs.exists(self.metadata_path): 489 if self.overwrite: 490 warnings.warn( 491 f"Detected an existing checkpoint in {self.metadata_path}, overwriting since {self.overwrite=}." 492 " Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to" 493 " maintain this functionality or False to raise when an existing checkpoint is found." 494 ) 495 else: 496 raise RuntimeError(f"Checkpoint already exists and {self.overwrite=}.") 497 498 return plan 499 500 def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]: 501 new_plans = [ 502 dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_")) 503 for i, plan in enumerate(plans) 504 ] 505 return new_plans 506 507 def write_data( 508 self, 509 plan: SavePlan, 510 planner: SavePlanner, 511 ) -> Future[List[WriteResult]]: 512 storage_plan: _StoragePrefix = plan.storage_data 513 file_count = 0 514 515 def gen_file(): 516 nonlocal file_count 517 file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}" 518 file_count += 1 519 return file_name 520 521 file_queue: queue.Queue = queue.Queue() 522 if self.single_file_per_rank: 523 for bucket in _split_by_size_and_type(self.thread_count, plan.items): 524 file_name = gen_file() 525 path = self.fs.concat_path(self.path, file_name) 526 file_queue.put((path, file_name, bucket)) 527 else: 528 for item in plan.items: 529 file_name = gen_file() 530 path = self.fs.concat_path(self.path, file_name) 531 file_queue.put((path, file_name, [item])) 532 533 result_queue: queue.Queue = queue.Queue() 534 535 threads = [] 536 for _ in range(1, self.thread_count): 537 t = threading.Thread( 538 target=_write_files_from_queue, 539 args=( 540 self.fs.create_stream, 541 file_queue, 542 result_queue, 543 planner, 544 self.per_thread_copy_ahead, 545 self.sync_files, 546 self.thread_count, 547 ), 548 ) 549 t.start() 550 threads.append(t) 551 552 _write_files_from_queue( 553 create_stream=self.fs.create_stream, 554 file_queue=file_queue, 555 result_queue=result_queue, 556 planner=planner, 557 inflight_threshhold=self.per_thread_copy_ahead, 558 use_fsync=self.sync_files, 559 thread_count=self.thread_count, 560 ) 561 562 for t in threads: 563 t.join() 564 565 res = [] 566 try: 567 while True: 568 res += result_queue.get_nowait() 569 except queue.Empty: 570 fut: Future[List[WriteResult]] = Future() 571 fut.set_result(res) 572 return fut 573 574 def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: 575 storage_md = {} 576 for wr_list in results: 577 storage_md.update({wr.index: wr.storage_data for wr in wr_list}) 578 metadata.storage_data = storage_md 579 580 metadata.storage_meta = self.storage_meta() 581 582 tmp_path = cast(Path, self.fs.concat_path(self.path, f"{_metadata_fn}.tmp")) 583 with self.fs.create_stream(tmp_path, "wb") as 
            pickle.dump(metadata, metadata_file)
            if self.sync_files:
                try:
                    os.fsync(metadata_file.fileno())
                except AttributeError:
                    os.sync()

        # Delete in case other checkpoints were present.
        if self.fs.exists(self.metadata_path):
            self.fs.rm_file(self.metadata_path)

        self.fs.rename(tmp_path, self.metadata_path)

    def storage_meta(self) -> Optional[StorageMeta]:
        return StorageMeta(checkpoint_id=self.checkpoint_id, save_id=self.save_id)

    @property
    def metadata_path(self) -> Union[str, os.PathLike]:
        return cast(Path, self.fs.concat_path(self.path, _metadata_fn))

    @property
    def checkpoint_id(self) -> Union[str, os.PathLike]:
        """
        Return the checkpoint_id that will be used to save the checkpoint.
        """
        return self.path

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return FileSystem.validate_checkpoint_id(checkpoint_id)


class FileSystemReader(StorageReader):
    def __init__(self, path: Union[str, os.PathLike]) -> None:
        super().__init__()
        self.fs = FileSystem()
        self.path = self.fs.init_path(path)
        self.storage_data: Dict[MetadataIndex, _StorageInfo] = {}
        self.load_id = _generate_uuid()

    def _slice_file(self, file, sinfo: _StorageInfo) -> io.IOBase:
        return _create_file_view(file, sinfo.offset, sinfo.length)

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        self.storage_data = {}
        if checkpoint_id:
            self.path = self.fs.init_path(checkpoint_id)
        self.load_id = _generate_uuid()

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        # group requests by file
        per_file: Dict[str, List[ReadItem]] = {}
        for read_item in plan.items:
            item_md = self.storage_data[read_item.storage_index]
            path = item_md.relative_path
            per_file.setdefault(path, []).append(read_item)

        for relative_path, reqs in per_file.items():
            new_path = self.fs.concat_path(self.path, relative_path)
            with self.fs.create_stream(new_path, "rb") as stream:
                # TODO sort by offset and cache the reading
                for req in reqs:
                    item_md = self.storage_data[req.storage_index]
                    file_slice = self._slice_file(stream, item_md)
                    if req.type == LoadItemType.BYTE_IO:
                        read_bytes = io.BytesIO(file_slice.read(item_md.length))
                        read_bytes.seek(0)
                        planner.load_bytes(req, read_bytes)
                    else:
                        tensor = cast(
                            Tensor,
                            torch.load(
                                cast(IO[bytes], file_slice),
                                map_location="cpu",
                                weights_only=True,
                            ),
                        )
                        tensor = narrow_tensor_by_index(
                            tensor, req.storage_offsets, req.lengths
                        )
                        target_tensor = planner.resolve_tensor(req).detach()

                        assert (
                            target_tensor.size() == tensor.size()
                        ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
                        target_tensor.copy_(tensor)
                        planner.commit_tensor(req, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut

    # Implementing the abstract function in StorageReader
    def read_metadata(self) -> Metadata:
        path = self.fs.concat_path(self.path, ".metadata")
        with self.fs.create_stream(path, "rb") as metadata_file:
            metadata = pickle.load(metadata_file)

        if getattr(metadata, "storage_meta", None) is None:
            metadata.storage_meta = StorageMeta()
        metadata.storage_meta.load_id = self.load_id

        return metadata

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        self.storage_data = metadata.storage_data
        assert self.storage_data is not None

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        return plan

    def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
        return plans

    @property
    def checkpoint_id(self) -> Union[str, os.PathLike]:
        """
        Return the checkpoint_id that will be used to load the checkpoint.
        """
        return self.path

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return FileSystem.validate_checkpoint_id(checkpoint_id)


class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager):
    """
    Basic implementation of StorageWriter using file IO.

    This implementation makes the following assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic.

    The checkpoint consists of one file per write request plus
    a `.metadata` file with the serialized metadata.
    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        cache_staged_state_dict: bool = False,
        overwrite: bool = True,
    ) -> None:
        """
        Initialize the writer pointing to `path`.

        Args:
            path: directory where the checkpoint will be written to.
            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to True.
            sync_files: force files to be synced to permanent storage. Defaults to True.
            thread_count: Number of IO threads to use to write. Defaults to 1.
            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Defaults to 10MB.
            cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency
                at the cost of increased memory usage. Additionally, if this parameter is set to True, it's the expectation
                that the stager is maintained and re-used for multiple dcp.async_save calls. Defaults to False.
            overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.

        N.B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
        """
        super().__init__(
            path=path,
            single_file_per_rank=single_file_per_rank,
            sync_files=sync_files,
            thread_count=thread_count,
            per_thread_copy_ahead=per_thread_copy_ahead,
            cache_staged_state_dict=cache_staged_state_dict,
            overwrite=overwrite,
        )

    def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
        """Override of AsyncStager.stage"""
        # In the async case, the state dict is already on CPU, so maintaining this
        # buffer makes no sense.
        self.per_thread_copy_ahead = 0
        return super().stage(state_dict)
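

# Hypothetical end-to-end sketch (not part of this module): FileSystemWriter and
# FileSystemReader are normally handed to the high-level torch.distributed.checkpoint
# save/load entry points rather than driven directly. The function name, checkpoint
# directory, and state_dict contents below are illustrative only, and nothing here
# is executed at import time.
def _example_round_trip(
    checkpoint_dir: Union[str, os.PathLike] = "/tmp/dcp_example"
) -> None:
    import torch.distributed.checkpoint as dcp

    state_dict = {"weights": torch.rand(4, 4)}
    # In a distributed job, every rank calls save/load against the same directory.
    dcp.save(state_dict, storage_writer=FileSystemWriter(checkpoint_dir))
    dcp.load(state_dict, storage_reader=FileSystemReader(checkpoint_dir))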