# mypy: allow-untyped-defs
import collections
import dataclasses
import io
import operator
import os
import pickle
import queue
import threading
import uuid
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    IO,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
)

import torch
from torch import Tensor
from torch._utils import _get_available_device_type, _get_device_module
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint.metadata import (
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    StorageMeta,
)
from torch.distributed.checkpoint.planner import (
    LoadItemType,
    LoadPlan,
    LoadPlanner,
    ReadItem,
    SavePlan,
    SavePlanner,
    WriteItem,
    WriteItemType,
)
from torch.distributed.checkpoint.staging import BlockingAsyncStager
from torch.distributed.checkpoint.storage import (
    StorageReader,
    StorageWriter,
    WriteResult,
)
from torch.distributed.checkpoint.utils import _create_file_view
from torch.futures import Future


__all__ = ["FileSystemWriter", "FileSystemReader", "FileSystem", "FileSystemBase"]

_metadata_fn: str = ".metadata"


@dataclass
class _StorageInfo:
    """This is the per entry storage info."""

    relative_path: str
    offset: int
    length: int


@dataclass
class _StoragePrefix:
    prefix: str


DEFAULT_SUFFIX = ".distcp"


def _generate_uuid() -> str:
    return str(uuid.uuid4())


class _TensorLoader(ABC):
    @abstractmethod
    def add(self, size: int, obj: object) -> None:
        pass

    @abstractmethod
    def start_loading(self) -> None:
        pass

    @abstractmethod
    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
        pass


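# Fallback loader: resolves each item and copies it to CPU synchronously, one
# tensor at a time. Used when no accelerator is available or when multiple
# writer threads are in use.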
class _SerialCpuLoader(_TensorLoader):
    def __init__(self, resolve_fun: Callable) -> None:
        self.resolve_fun = resolve_fun
        self.items: List[Tuple[int, object]] = []

    def add(self, size: int, obj: object) -> None:
        self.items.append((size, obj))

    def start_loading(self) -> None:
        pass

    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
        for _, obj in self.items:
            tensor = self.resolve_fun(obj).detach()
            tensor = tensor.cpu()
            if tensor.storage().size() != tensor.numel():
                tensor = tensor.clone()
            yield (
                tensor,
                obj,
            )


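# Loader that overlaps device-to-host copies with consumption of the results:
# tensors are resolved and copied to CPU with non_blocking=True on the given
# stream (the current stream by default), keeping up to `inflight_threshhold`
# bytes in flight. `values()` drains finished copies while refilling the queue.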
class _OverlappingCpuLoader(_TensorLoader):
    def __init__(
        self,
        resolve_fun: Callable,
        stream: Optional[torch.Stream] = None,
        inflight_threshhold: int = 1_000_000,
    ) -> None:
        self.resolve_fun = resolve_fun
        self.items: List[Tuple[int, object]] = []
        self.inflight_threshhold = inflight_threshhold
        self.in_flight_data = 0
        self.current_items: collections.deque = collections.deque()
        self.idx = 0
        self.started = False
        self.device_type = (
            stream.device_type if stream else _get_available_device_type()
        )
        self.device_module = _get_device_module(self.device_type)
        self.stream = cast(
            torch.cuda.Stream, stream or self.device_module.current_stream()
        )
        if self.stream != self.device_module.current_stream():
            self.stream.wait_stream(self.device_module.current_stream())

    @property
    def _done(self) -> bool:
        return self.idx >= len(self.items)

    def _drain(self) -> List[Tuple[torch.Tensor, object]]:
        drained = []
        if self.in_flight_data >= self.inflight_threshhold:
            self.stream.synchronize()
        while self.in_flight_data >= self.inflight_threshhold:
            val = self.current_items.popleft()
            self.in_flight_data -= val[0].numel() * val[0].element_size()
            drained.append(val)
        return drained

    def _refill(self) -> None:
        with self.device_module.stream(self.stream):
            while not self._done and self.in_flight_data < self.inflight_threshhold:
                _, obj = self.items[self.idx]
                self.idx += 1
                tensor = self.resolve_fun(obj).detach()
                if tensor.device.type == self.device_type:
                    tensor = tensor.to(device="cpu", non_blocking=True)
                elif tensor.device == torch.device("cpu"):
                    if (
                        tensor.untyped_storage().size()
                        != tensor.numel() * tensor.itemsize
                    ):
                        # this forces the tensor to be contiguous and to use minimal storage
                        tensor = tensor.clone()

                self.current_items.append(
                    (
                        tensor,
                        obj,
                    )
                )
                self.in_flight_data += tensor.numel() * tensor.element_size()

    def _finish(self) -> Iterable[Tuple[torch.Tensor, object]]:
        assert self._done
        if len(self.current_items) > 0:
            self.stream.synchronize()
        return self.current_items

    def add(self, size: int, obj: object) -> None:
        if self.started:
            raise RuntimeError("cannot add items after loading started")
        self.items.append((size, obj))

    def start_loading(self) -> None:
        if self.started:
            return
        self.started = True
        self.items.sort(key=operator.itemgetter(0))
        self._refill()

    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
        self.start_loading()
        while not self._done:
            drained = self._drain()
            self._refill()
            yield from drained

        yield from self._finish()


def _item_size(item: WriteItem) -> int:
    size = 1
    assert item.tensor_data is not None
    # can't use math.prod as PT needs to support older python
    for s in item.tensor_data.size:
        size *= s

    dtype = item.tensor_data.properties.dtype
    return size * torch._utils._element_size(dtype)


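# Greedy partition of the write items into `bins` buckets so that each writer
# thread gets a similar amount of work: byte-IO items are distributed
# round-robin, while tensors are sorted by size (largest first) and each is
# placed into the currently smallest bucket.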
def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
    if bins == 1:
        return [items]

    bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
    tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]

    buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
    bucket_sizes = [0 for _ in range(bins)]

    tensor_w.sort(key=_item_size, reverse=True)

    for i, wi in enumerate(bytes_w):
        buckets[i % bins].append(wi)

    for wi in tensor_w:
        # TODO replace with heapq
        idx = min(enumerate(bucket_sizes), key=operator.itemgetter(1))[0]
        buckets[idx].append(wi)
        bucket_sizes[idx] += _item_size(wi)

    return buckets


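# Serialize a single write item into the already-open stream and record the
# (relative path, offset, length) it occupies, so the reader can later seek
# directly to it.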
def _write_item(
    stream: io.IOBase,
    data: Union[io.BytesIO, torch.Tensor],
    write_item: WriteItem,
    storage_key: str,
) -> WriteResult:
    offset = stream.tell()

    if write_item.type == WriteItemType.BYTE_IO:
        assert isinstance(data, io.BytesIO)
        stream.write(data.getbuffer())
    else:
        assert isinstance(data, torch.Tensor)
        assert data.device == torch.device("cpu")
        torch.save(data, cast(IO[bytes], stream))
    length = stream.tell() - offset

    return WriteResult(
        index=write_item.index,
        size_in_bytes=length,
        storage_data=_StorageInfo(storage_key, offset, length),
    )


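# Worker body shared by the writer threads and the calling thread: repeatedly
# pops a (path, storage_key, write_items) tuple from `file_queue`, streams the
# corresponding data into that file, and pushes the resulting WriteResults to
# `result_queue` until the file queue is empty.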
def _write_files_from_queue(
    create_stream: Callable,
    file_queue: queue.Queue,
    result_queue: queue.Queue,
    planner: SavePlanner,
    inflight_threshhold: int,
    use_fsync: bool,
    thread_count: int,
) -> None:
    try:
        while True:
            file_name, storage_key, write_items = file_queue.get_nowait()
            loader: _TensorLoader

            custom_backend_name = torch._C._get_privateuse1_backend_name()
            custom_device_mod = getattr(torch, custom_backend_name, None)

            # TODO: Using the OverlappingCpuLoader with multiple threads causes significant
            # performance degradation, which appears to be related to CUDA stream syncs. We
            # should try to fix this and use _OverlappingCpuLoader for all threaded cases
            if (
                thread_count == 1
                and (
                    torch.cuda.is_available()
                    or (custom_device_mod and custom_device_mod.is_available())
                )
                and inflight_threshhold > 0
            ):
                loader = _OverlappingCpuLoader(
                    planner.resolve_data,
                    inflight_threshhold=inflight_threshhold,
                )
            else:
                loader = _SerialCpuLoader(
                    planner.resolve_data,
                )

            tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
            for write_item in tensor_w:
                loader.add(_item_size(write_item), write_item)
            loader.start_loading()

            bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
            write_results = []

            with create_stream(file_name, "wb") as stream:
                for write_item in bytes_w:
                    data = planner.resolve_data(write_item)
                    write_results.append(
                        _write_item(stream, data, write_item, storage_key)
                    )

                for tensor, write_item in loader.values():
                    assert tensor.is_cpu
                    write_results.append(
                        _write_item(stream, tensor, write_item, storage_key)
                    )

                if use_fsync:
                    try:
                        os.fsync(stream.fileno())
                    except AttributeError:
                        os.sync()
            result_queue.put(write_results)
    except queue.Empty:
        pass


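# Abstraction over the underlying storage so the writer/reader logic below can
# be reused with non-local filesystems by supplying a different FileSystemBase
# implementation.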
class FileSystemBase(ABC):
    @contextmanager
    @abstractmethod
    def create_stream(
        self, path: Union[str, os.PathLike], mode: str
    ) -> Generator[io.IOBase, None, None]:
        ...

    @abstractmethod
    def concat_path(
        self, path: Union[str, os.PathLike], suffix: str
    ) -> Union[str, os.PathLike]:
        ...

    @abstractmethod
    def rename(
        self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
    ) -> None:
        ...

    @abstractmethod
    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
        ...

    @abstractmethod
    def mkdir(self, path: Union[str, os.PathLike]) -> None:
        ...

    @classmethod
    @abstractmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        ...

    @abstractmethod
    def exists(self, path: Union[str, os.PathLike]) -> bool:
        ...

    @abstractmethod
    def rm_file(self, path: Union[str, os.PathLike]) -> None:
        ...


class FileSystem(FileSystemBase):
    @contextmanager
    def create_stream(
        self, path: Union[str, os.PathLike], mode: str
    ) -> Generator[io.IOBase, None, None]:
        with cast(Path, path).open(mode) as stream:
            yield cast(io.IOBase, stream)

    def concat_path(
        self, path: Union[str, os.PathLike], suffix: str
    ) -> Union[str, os.PathLike]:
        return cast(Path, path) / suffix

    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
        if not isinstance(path, Path):
            path = Path(path)
        return path

    def rename(
        self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
    ) -> None:
        cast(Path, path).rename(cast(Path, new_path))

    def mkdir(self, path: Union[str, os.PathLike]) -> None:
        cast(Path, path).mkdir(parents=True, exist_ok=True)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        if isinstance(checkpoint_id, Path):
            return True

        if "://" in str(checkpoint_id):
            return False

        for p in Path(checkpoint_id).parents:
            if p.exists() and os.access(str(p), os.W_OK):
                return True

        return False

    def exists(self, path: Union[str, os.PathLike]) -> bool:
        return cast(Path, path).exists()

    def rm_file(self, path: Union[str, os.PathLike]) -> None:
        cast(Path, path).unlink()


class _FileSystemWriter(StorageWriter):
    """
    Basic implementation of StorageWriter using file IO.

    This implementation makes the following assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic

    The checkpoint consists of one file per write request plus
    a `.metadata` file with the serialized metadata.

    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        overwrite: bool = True,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the writer pointing to `path`.

        Args:
            path: directory where the checkpoint will be written to.
            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to True.
            sync_files: force files to be synced to permanent storage. Defaults to True.
            thread_count: Number of IO threads to use to write. Defaults to 1.
            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Defaults to 10MB.
            overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.

        N.B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
        """
        super().__init__()
        self.fs = FileSystem()
        self.path = self.fs.init_path(path)
        self.single_file_per_rank = single_file_per_rank
        self.sync_files = sync_files
        self.thread_count = thread_count
        self.per_thread_copy_ahead = per_thread_copy_ahead
        self.save_id = _generate_uuid()
        self.overwrite = overwrite

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        if checkpoint_id:
            self.path = self.fs.init_path(checkpoint_id)
        self.save_id = _generate_uuid()

    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        pass

    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        self.fs.mkdir(self.path)
        if self.fs.exists(self.metadata_path):
            if self.overwrite:
                warnings.warn(
                    f"Detected an existing checkpoint in {self.metadata_path}, overwriting since {self.overwrite=}."
                    " Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to"
                    " maintain this functionality or False to raise when an existing checkpoint is found."
                )
            else:
                raise RuntimeError(f"Checkpoint already exists and {self.overwrite=}.")

        return plan

    def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
        new_plans = [
            dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
            for i, plan in enumerate(plans)
        ]
        return new_plans

    def write_data(
        self,
        plan: SavePlan,
        planner: SavePlanner,
    ) -> Future[List[WriteResult]]:
        storage_plan: _StoragePrefix = plan.storage_data
        file_count = 0

        def gen_file():
            nonlocal file_count
            file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}"
            file_count += 1
            return file_name

        file_queue: queue.Queue = queue.Queue()
        if self.single_file_per_rank:
            for bucket in _split_by_size_and_type(self.thread_count, plan.items):
                file_name = gen_file()
                path = self.fs.concat_path(self.path, file_name)
                file_queue.put((path, file_name, bucket))
        else:
            for item in plan.items:
                file_name = gen_file()
                path = self.fs.concat_path(self.path, file_name)
                file_queue.put((path, file_name, [item]))

        result_queue: queue.Queue = queue.Queue()

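        # Spawn thread_count - 1 worker threads; the calling thread also drains the
        # queue below, so the total number of concurrent writers equals thread_count.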
        threads = []
        for _ in range(1, self.thread_count):
            t = threading.Thread(
                target=_write_files_from_queue,
                args=(
                    self.fs.create_stream,
                    file_queue,
                    result_queue,
                    planner,
                    self.per_thread_copy_ahead,
                    self.sync_files,
                    self.thread_count,
                ),
            )
            t.start()
            threads.append(t)

        _write_files_from_queue(
            create_stream=self.fs.create_stream,
            file_queue=file_queue,
            result_queue=result_queue,
            planner=planner,
            inflight_threshhold=self.per_thread_copy_ahead,
            use_fsync=self.sync_files,
            thread_count=self.thread_count,
        )

        for t in threads:
            t.join()

        res = []
        try:
            while True:
                res += result_queue.get_nowait()
        except queue.Empty:
            fut: Future[List[WriteResult]] = Future()
            fut.set_result(res)
            return fut

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        storage_md = {}
        for wr_list in results:
            storage_md.update({wr.index: wr.storage_data for wr in wr_list})
        metadata.storage_data = storage_md

        metadata.storage_meta = self.storage_meta()

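        # Write the metadata to a temporary file first and rename it into place, so a
        # partially written `.metadata` file is never observed as a valid checkpoint.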
        tmp_path = cast(Path, self.fs.concat_path(self.path, f"{_metadata_fn}.tmp"))
        with self.fs.create_stream(tmp_path, "wb") as metadata_file:
            pickle.dump(metadata, metadata_file)
            if self.sync_files:
                try:
                    os.fsync(metadata_file.fileno())
                except AttributeError:
                    os.sync()

        # delete in case other checkpoints were present.
        if self.fs.exists(self.metadata_path):
            self.fs.rm_file(self.metadata_path)

        self.fs.rename(tmp_path, self.metadata_path)

    def storage_meta(self) -> Optional[StorageMeta]:
        return StorageMeta(checkpoint_id=self.checkpoint_id, save_id=self.save_id)

    @property
    def metadata_path(self) -> Union[str, os.PathLike]:
        return cast(Path, self.fs.concat_path(self.path, _metadata_fn))

    @property
    def checkpoint_id(self) -> Union[str, os.PathLike]:
        """
        Return the checkpoint_id that will be used to save the checkpoint.
        """
        return self.path

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return FileSystem.validate_checkpoint_id(checkpoint_id)


class FileSystemReader(StorageReader):
    def __init__(self, path: Union[str, os.PathLike]) -> None:
        super().__init__()
        self.fs = FileSystem()
        self.path = self.fs.init_path(path)
        self.storage_data: Dict[MetadataIndex, _StorageInfo] = {}
        self.load_id = _generate_uuid()

    def _slice_file(self, file, sinfo: _StorageInfo) -> io.IOBase:
        return _create_file_view(file, sinfo.offset, sinfo.length)

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        self.storage_data = {}
        if checkpoint_id:
            self.path = self.fs.init_path(checkpoint_id)
        self.load_id = _generate_uuid()

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        # group requests by file
        per_file: Dict[str, List[ReadItem]] = {}
        for read_item in plan.items:
            item_md = self.storage_data[read_item.storage_index]
            path = item_md.relative_path
            per_file.setdefault(path, []).append(read_item)

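        # Open each file once and satisfy all of its requests, slicing out the byte
        # range recorded for every item at save time.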
        for relative_path, reqs in per_file.items():
            new_path = self.fs.concat_path(self.path, relative_path)
            with self.fs.create_stream(new_path, "rb") as stream:
                # TODO sort by offset and cache the reading
                for req in reqs:
                    item_md = self.storage_data[req.storage_index]
                    file_slice = self._slice_file(stream, item_md)
                    if req.type == LoadItemType.BYTE_IO:
                        read_bytes = io.BytesIO(file_slice.read(item_md.length))
                        read_bytes.seek(0)
                        planner.load_bytes(req, read_bytes)
                    else:
                        tensor = cast(
                            Tensor,
                            torch.load(
                                cast(IO[bytes], file_slice),
                                map_location="cpu",
                                weights_only=True,
                            ),
                        )
                        tensor = narrow_tensor_by_index(
                            tensor, req.storage_offsets, req.lengths
                        )
                        target_tensor = planner.resolve_tensor(req).detach()

                        assert (
                            target_tensor.size() == tensor.size()
                        ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
                        target_tensor.copy_(tensor)
                        planner.commit_tensor(req, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut

    # Implementing the abstract function in StorageReader
    def read_metadata(self) -> Metadata:
        path = self.fs.concat_path(self.path, ".metadata")
        with self.fs.create_stream(path, "rb") as metadata_file:
            metadata = pickle.load(metadata_file)

        if getattr(metadata, "storage_meta", None) is None:
            metadata.storage_meta = StorageMeta()
        metadata.storage_meta.load_id = self.load_id

        return metadata

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        self.storage_data = metadata.storage_data
        assert self.storage_data is not None

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        return plan

    def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
        return plans

    @property
    def checkpoint_id(self) -> Union[str, os.PathLike]:
        """
        Return the checkpoint_id that will be used to load the checkpoint.
        """
        return self.path

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        return FileSystem.validate_checkpoint_id(checkpoint_id)


class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager):
    """
    Basic implementation of StorageWriter using file IO.

    This implementation makes the following assumptions and simplifications:

    * The checkpoint path is an empty or non-existing directory.
    * File creation is atomic

    The checkpoint consists of one file per write request plus
    a `.metadata` file with the serialized metadata.

    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        single_file_per_rank: bool = True,
        sync_files: bool = True,
        thread_count: int = 1,
        per_thread_copy_ahead: int = 10_000_000,
        cache_staged_state_dict: bool = False,
        overwrite: bool = True,
    ) -> None:
        """
        Initialize the writer pointing to `path`.

        Args:
            path: directory where the checkpoint will be written to.
            single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Defaults to True.
            sync_files: force files to be synced to permanent storage. Defaults to True.
            thread_count: Number of IO threads to use to write. Defaults to 1.
            per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving them. Defaults to 10MB.
            cache_staged_state_dict: Whether to cache the staged state_dict. This option decreases staging latency
                at the cost of increased memory usage. Additionally, if this parameter is set to True, it is expected
                that the stager is maintained and re-used for multiple dcp.async_save calls. Defaults to False.
            overwrite: Whether to allow overwriting existing checkpoints. Defaults to True.

        N.B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure.
        """
        super().__init__(
            path=path,
            single_file_per_rank=single_file_per_rank,
            sync_files=sync_files,
            thread_count=thread_count,
            per_thread_copy_ahead=per_thread_copy_ahead,
            cache_staged_state_dict=cache_staged_state_dict,
            overwrite=overwrite,
        )

    def stage(self, state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
        """Override of AsyncStager.stage"""
        # in the async case, the state dict is already on CPU, so maintaining this
        # buffer makes no sense
        self.per_thread_copy_ahead = 0
        return super().stage(state_dict)
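

# A minimal usage sketch (not part of this module): these classes are normally
# passed to the torch.distributed.checkpoint save/load entry points rather than
# called directly. The checkpoint directory, `model`, and `state_dict` below are
# illustrative.
#
#   import torch.distributed.checkpoint as dcp
#
#   writer = FileSystemWriter("/tmp/ckpt", thread_count=4)
#   dcp.save({"model": model.state_dict()}, storage_writer=writer)
#
#   reader = FileSystemReader("/tmp/ckpt")
#   dcp.load(state_dict, storage_reader=reader)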