xref: /aosp_15_r20/external/pytorch/torch/_lazy/device_context.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import threading
3from typing import Any, Dict
4
5import torch._C._lazy
6
7
class DeviceContext:
    """Holder for per-device state of a lazy device.

    The class doubles as a process-wide registry: ``_CONTEXTS`` maps a
    device key string to the single ``DeviceContext`` created for it, and
    ``_CONTEXTS_LOCK`` serializes the lookup-or-create done by
    ``get_device_context``.
    """

    # Lock held while ``_CONTEXTS`` is read and, if needed, populated.
    _CONTEXTS_LOCK = threading.Lock()
    # Registry of already-created contexts, keyed by device string.
    _CONTEXTS: Dict[str, Any] = dict()

    def __init__(self, device):
        """Create a context bound to ``device`` (stored unchanged)."""
        self.device = device
14
15
def get_device_context(device=None):
    """Return the shared ``DeviceContext`` for ``device``, creating it once.

    Args:
        device: Device to look up. When ``None``, the key is whatever
            ``torch._C._lazy._get_default_device_type()`` returns;
            otherwise the key is ``str(device)``.

    Returns:
        The ``DeviceContext`` registered under that key. The lookup and the
        insertion of a new context both happen while holding
        ``DeviceContext._CONTEXTS_LOCK``, so concurrent callers asking for
        the same key receive the same object.
    """
    key = (
        torch._C._lazy._get_default_device_type()
        if device is None
        else str(device)
    )
    with DeviceContext._CONTEXTS_LOCK:
        registry = DeviceContext._CONTEXTS
        found = registry.get(key)
        if found is None:
            # First request for this key: build and remember the context.
            found = registry[key] = DeviceContext(key)
        return found
27