# mypy: allow-untyped-defs
import threading
from typing import Any, Dict

import torch._C._lazy


class DeviceContext:
    """Context object associated with a single device identifier.

    The class also owns the process-wide cache used by
    ``get_device_context``: a dict of contexts keyed by device string and
    the lock that guards lookups and insertions into it.
    """

    # Cache of already-created contexts, keyed by the device string.
    _CONTEXTS: Dict[str, Any] = {}
    # Held by get_device_context for the whole lookup-or-create sequence.
    _CONTEXTS_LOCK = threading.Lock()

    def __init__(self, device):
        """Remember ``device`` on the instance; no validation is performed."""
        self.device = device


def get_device_context(device=None):
    """Return the cached ``DeviceContext`` for ``device``, creating it once.

    Args:
        device: Device identifier. When ``None``, the value returned by
            ``torch._C._lazy._get_default_device_type()`` is used as the
            cache key as-is; any other value is converted with ``str()``.

    Returns:
        The ``DeviceContext`` stored in ``DeviceContext._CONTEXTS`` under the
        resulting key. Repeated calls with the same key return the same
        object.
    """
    key = (
        torch._C._lazy._get_default_device_type() if device is None else str(device)
    )
    # The lookup and the insertion happen under one lock acquisition so that
    # a given key is only ever bound to a single context instance.
    with DeviceContext._CONTEXTS_LOCK:
        registry = DeviceContext._CONTEXTS
        existing = registry.get(key)
        if existing is not None:
            return existing
        created = DeviceContext(key)
        registry[key] = created
        return created