from __future__ import annotations

import json
import os
import typing
from abc import abstractmethod
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union
from typing_extensions import override, TypeAlias

from torch._inductor import config


# redis is an optional dependency: when it cannot be imported the module-level
# name is bound to None so later code can test for its availability.
try:
    import redis
except ImportError:
    redis = None  # type: ignore[assignment]


# ``Sample`` is the type of the per-operation sample object passed through the
# RemoteCache hooks. Inside fbcode it is ``Sample`` from
# ``rfe.scubadata.scubadata_py3``; elsewhere a generic placeholder type is used.
if config.is_fbcode():
    from rfe.scubadata.scubadata_py3 import (  # type: ignore[import-not-found]
        Sample as Sample_,
    )

    Sample: TypeAlias = Sample_
else:
    Sample: TypeAlias = Type[object]  # type: ignore[misc,no-redef]


# Generic type parameters shared by the backend, serde and cache classes.
_T = TypeVar("_T")
_U = TypeVar("_U")


class RemoteCacheBackend(Generic[_T]):
    """
    Raw storage layer of a remote/distributed cache, parameterized by the type
    ``_T`` of the values it stores as-is (the Redis backend uses ``bytes``).

    A backend performs no (de)serialization of its own; wrap it in a
    RemoteCache to work with structured data.

    Note: this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers document the interface but are not enforced
    when a subclass is instantiated.
    """

    @abstractmethod
    def get(self, key: str) -> Optional[_T]:
        """Return the value stored under ``key``, or None if there is none."""

    @abstractmethod
    def put(self, key: str, data: _T) -> None:
        """Store ``data`` under ``key``."""


# Serde that encodes from _T to _U and decodes from _U to _T.
class RemoteCacheSerde(Generic[_T, _U]):
    """
    Serializer/deserializer pair: ``encode`` maps a ``_T`` to a ``_U`` and
    ``decode`` maps a ``_U`` back to a ``_T``.
    """

    @abstractmethod
    def encode(self, data: _T) -> _U:
        pass

    @abstractmethod
    def decode(self, data: _U) -> _T:
        pass


# JSON-representable values. The recursive members are written as string
# forward references so the alias can refer to itself.
JsonDataTy = Optional[
    Union[int, float, str, bool, Dict[str, "JsonDataTy"], List["JsonDataTy"]]
]


class RemoteCacheJsonSerde(RemoteCacheSerde[JsonDataTy, bytes]):
    """Serde that stores JSON-compatible values as ASCII-encoded JSON bytes."""

    def encode(self, data: JsonDataTy) -> bytes:
        # json.dumps escapes non-ASCII characters by default (ensure_ascii=True),
        # so encoding its output as "ascii" cannot fail.
        return bytes(json.dumps(data), "ascii")

    def decode(self, data: bytes) -> JsonDataTy:
        return json.loads(data)


class RemoteCachePassthroughSerde(RemoteCacheSerde[_T, _T]):
    """Identity serde for backends that can store ``_T`` values directly."""

    def encode(self, data: _T) -> _T:
        return data

    def decode(self, data: _T) -> _T:
        return data


class RemoteCache(Generic[_T]):
    """
    Structured front end to a RemoteCacheBackend.

    On ``put`` a value of type ``_T`` is converted by ``serde`` into the
    backend's representation ``_U`` and stored; on ``get`` it is read back and
    decoded. Every get/put is bracketed by the ``_create_sample`` /
    ``_log_sample`` hooks, which are no-ops in this class.
    """

    # Support for testing: when set, ``__init__`` ignores its ``backend``
    # argument and uses ``backend_override_cls()`` instead.
    backend_override_cls: Optional[Callable[[], RemoteCacheBackend[Any]]] = None

    def __init__(
        self, backend: RemoteCacheBackend[_U], serde: RemoteCacheSerde[_T, _U]
    ) -> None:
        # Support for testing.
        if (override_cls := self.__class__.backend_override_cls) is not None:
            self.backend = override_cls()
        else:
            self.backend = backend
        self.serde = serde

    def get(self, key: str) -> Optional[_T]:
        """Return the decoded value stored under ``key``, or None on a miss."""
        sample = self._create_sample()
        result = self._get(key, sample)
        self._log_sample(sample)
        return result

    def put(self, key: str, value: _T) -> None:
        """Encode ``value`` with the serde and store it under ``key``."""
        sample = self._create_sample()
        self._put(key, value, sample)
        self._log_sample(sample)

    def _decode(self, data: _U, sample: Optional[Sample]) -> _T:
        return self.serde.decode(data)

    def _encode(self, value: _T, sample: Optional[Sample]) -> Any:  # returns _U
        return self.serde.encode(value)

    def _get(self, key: str, sample: Optional[Sample]) -> Optional[_T]:
        data = self.backend.get(key)
        # Only None signals a miss. A truthiness test would also discard stored
        # falsy payloads (e.g. 0, {}, "" with RemoteCachePassthroughSerde) and
        # wrongly report them as misses.
        if data is None:
            return None
        return self._decode(data, sample)

    def _put(self, key: str, value: _T, sample: Optional[Sample]) -> None:
        data = self._encode(value, sample)
        self.backend.put(key, data)

    def _create_sample(self) -> Optional[Sample]:
        # Hook for subclasses; no sample is created by default.
        return None

    def _log_sample(self, sample: Optional[Sample]) -> None:
        # Hook for subclasses; nothing is logged by default.
        pass


class RedisRemoteCacheBackend(RemoteCacheBackend[bytes]):
    """
    A Redis implementation of a remote/distributed cache.

    The server is taken from TORCHINDUCTOR_REDIS_HOST (default "localhost")
    and TORCHINDUCTOR_REDIS_PORT (default 6379). Entries are stored under
    ``pt2:<cache_id>:<key>``.

    The backend is best-effort. If the redis package could not be imported,
    or a connection error is raised, it disables itself: from then on ``get``
    returns None and ``put`` does nothing.
    """

    _key_fmt: str
    # None when redis is unavailable or after a connection error.
    _redis: Optional[redis.Redis] = None

    def __init__(self, cache_id: str) -> None:
        if redis is None:
            # We had trouble importing redis - just skip init.
            return

        self._key_fmt = f"pt2:{cache_id}:{{key}}"
        self._redis = redis.Redis(
            host=os.environ.get("TORCHINDUCTOR_REDIS_HOST", "localhost"),
            port=int(os.environ.get("TORCHINDUCTOR_REDIS_PORT", 6379)),
        )

    def __get_key(self, key: str) -> str:
        return self._key_fmt.format(key=key)

    @override
    def get(self, key: str) -> Optional[bytes]:
        if self._redis is None:
            # Either redis wasn't found or we already had some trouble...
            return None

        try:
            value = self._redis.get(self.__get_key(key))
        except redis.exceptions.ConnectionError:
            # Redis is lazy and doesn't actually attempt to connect until the
            # first use. Mark it as unavailable now.
            self._redis = None
            return None

        # In theory redis.get() can return an Awaitable as well...
        assert value is None or isinstance(value, bytes)
        return value

    @override
    def put(self, key: str, data: bytes) -> None:
        if self._redis is None:
            # Either redis wasn't found or we already had some trouble...
            return

        try:
            self._redis.set(self.__get_key(key), data)
        except redis.exceptions.ConnectionError:
            # Redis is lazy and doesn't actually attempt to connect until the
            # first use. Mark it as unavailable now.
            self._redis = None


class RedisRemoteCache(RemoteCache[JsonDataTy]):
    """
    RemoteCache that stores JSON-compatible values in Redis.

    ``key`` is passed to RedisRemoteCacheBackend as its cache id, which
    namespaces every entry this cache writes.
    """

    def __init__(self, key: str) -> None:
        # Special test handling: If we're just going to override the backend
        # anyway don't require redis
        if self.__class__.backend_override_cls:
            # RemoteCache.__init__ ignores ``backend`` when an override is set,
            # so a None placeholder (cast to satisfy the type checker) is passed
            # instead of constructing a real Redis backend.
            backend = typing.cast(RemoteCacheBackend[bytes], None)
        else:
            backend = RedisRemoteCacheBackend(key)
        serde = RemoteCacheJsonSerde()
        super().__init__(backend, serde)


class RemoteAutotuneCache(RedisRemoteCache):
    # Distinct subclass so its ``backend_override_cls`` can be set
    # independently; presumably used for autotuning results (name-based).
    pass


class RemoteFxGraphCache(RedisRemoteCache):
    # Distinct subclass so its ``backend_override_cls`` can be set
    # independently; presumably used for FX graph cache entries (name-based).
    pass