# mypy: allow-untyped-defs
"""torch.multiprocessing is a wrapper around the native :mod:`multiprocessing` module.

It registers custom reducers that use shared memory to provide shared
views on the same data across processes. Once a tensor/storage is moved
to shared memory (see :func:`~torch.Tensor.share_memory_`), it can be
sent to other processes without making any copies.

The API is 100% compatible with the original module: it is enough to change
``import multiprocessing`` to ``import torch.multiprocessing`` to have all
tensors sent through queues or shared via other mechanisms moved to shared
memory.

Because the APIs are so similar, we do not document most of this package's
contents, and we recommend referring to the excellent documentation of the
original module.
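
Example::

    >>> # A minimal sketch; ``worker`` is an illustrative function, not
    >>> # part of this module. The child's in-place update is visible to
    >>> # the parent because the tensor's storage lives in shared memory.
    >>> import torch
    >>> import torch.multiprocessing as mp
    >>> def worker(t):
    ...     t.add_(1)
    >>> t = torch.zeros(2).share_memory_()
    >>> p = mp.Process(target=worker, args=(t,))
    >>> p.start()
    >>> p.join()
    >>> t
    tensor([1., 1.])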
16"""
import multiprocessing
import sys

import torch

from .reductions import init_reductions


__all__ = ["set_sharing_strategy", "get_sharing_strategy", "get_all_sharing_strategies"]


from multiprocessing import *  # noqa: F403


__all__ += multiprocessing.__all__  # noqa: PLE0605 type: ignore[attr-defined]


# This call adds a Linux-specific prctl(2) wrapper function to this module.
# See https://github.com/pytorch/pytorch/pull/14391 for more information.
torch._C._multiprocessing_init()


39"""Add helper function to spawn N processes and wait for completion of any of
40them. This depends `mp.get_context` which was added in Python 3.4."""
from .spawn import (
    ENV_VAR_PARALLEL_START,
    ProcessContext,
    ProcessExitedException,
    ProcessRaisedException,
    spawn,
    SpawnContext,
    start_processes,
)
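
# A minimal sketch of the ``spawn`` helper imported above (``_work`` and its
# arguments are illustrative, not part of this module). ``spawn`` runs
# ``fn(i, *args)`` in ``nprocs`` child processes and, with the default
# ``join=True``, blocks until they finish:
#
#     import torch.multiprocessing as mp
#
#     def _work(rank, total):
#         print(f"process {rank} of {total}")
#
#     if __name__ == "__main__":
#         mp.spawn(_work, args=(4,), nprocs=4)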


if sys.platform == "darwin" or sys.platform == "win32":
    _sharing_strategy = "file_system"
    _all_sharing_strategies = {"file_system"}
else:
    _sharing_strategy = "file_descriptor"
    _all_sharing_strategies = {"file_descriptor", "file_system"}


def set_sharing_strategy(new_strategy):
    """Set the strategy for sharing CPU tensors.

    Args:
        new_strategy (str): Name of the selected strategy. Should be one of
            the values returned by :func:`get_all_sharing_strategies()`.
    """
    global _sharing_strategy
    assert new_strategy in _all_sharing_strategies
    _sharing_strategy = new_strategy


def get_sharing_strategy():
    """Return the current strategy for sharing CPU tensors."""
    return _sharing_strategy


def get_all_sharing_strategies():
    """Return the set of sharing strategies supported on the current system."""
    return _all_sharing_strategies
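
# A minimal sketch of the accessors above (output shown for Linux, where both
# strategies are available; on macOS and Windows only "file_system" exists):
#
#     import torch.multiprocessing as mp
#
#     mp.get_all_sharing_strategies()  # {'file_descriptor', 'file_system'}
#     mp.get_sharing_strategy()        # 'file_descriptor'
#     mp.set_sharing_strategy("file_system")
#     mp.get_sharing_strategy()        # 'file_system'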


def _set_thread_name(name: str) -> None:
    """Set the name of the current thread.

    Args:
        name (str): Name to assign to the current thread.
    """
    torch._C._set_thread_name(name)


def _get_thread_name() -> str:
    """Get the name of the current thread.

    Returns:
        str: Name of the current thread.
    """
    return torch._C._get_thread_name()
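
# A minimal sketch of the private helpers above (internal API; names and
# behavior are subject to change). The thread name round-trips through the
# C++ binding:
#
#     from torch.multiprocessing import _get_thread_name, _set_thread_name
#
#     _set_thread_name("pt_worker")
#     assert _get_thread_name() == "pt_worker"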


init_reductions()