xref: /aosp_15_r20/external/pytorch/torch/_inductor/remote_cache.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import json
4import os
5import typing
6from abc import abstractmethod
7from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union
8from typing_extensions import override, TypeAlias
9
10from torch._inductor import config
11
12
13try:
14    import redis
15except ImportError:
16    redis = None  # type: ignore[assignment]
17
18
# ``Sample`` is imported from ``rfe.scubadata`` only when ``config.is_fbcode()``
# is true; otherwise it is aliased to ``Type[object]`` as a stand-in so the
# ``Optional[Sample]`` annotations used by the cache classes in this module
# still have a name to refer to.
if config.is_fbcode():
    from rfe.scubadata.scubadata_py3 import (  # type: ignore[import-not-found]
        Sample as Sample_,
    )

    Sample: TypeAlias = Sample_
else:
    Sample: TypeAlias = Type[object]  # type: ignore[misc,no-redef]


# Type variables shared by the generic backend / serde / cache classes:
# _T is the value type a class works with, _U the encoded/storage type a
# serde converts to and from.
_T = TypeVar("_T")
_U = TypeVar("_U")
31
32
class RemoteCacheBackend(Generic[_T]):
    """
    Raw storage interface for a remote/distributed cache.

    A backend moves opaque payloads of type ``_T`` in and out of the remote
    store (``RedisRemoteCacheBackend``, for example, works with ``bytes``).
    For structured values, wrap a backend in a ``RemoteCache`` together with
    a ``RemoteCacheSerde``.

    NOTE: this class derives from ``Generic`` only, not ``abc.ABC``, so the
    ``@abstractmethod`` markers below are not enforced at instantiation time;
    the base implementations simply return None.
    """

    @abstractmethod
    def get(self, key: str) -> Optional[_T]:
        """Return the payload stored under ``key``, or None if there is none."""

    @abstractmethod
    def put(self, key: str, data: _T) -> None:
        """Store ``data`` under ``key``."""
46
47
class RemoteCacheSerde(Generic[_T, _U]):
    """
    Converter between an in-memory value type ``_T`` and a storage type ``_U``.

    ``encode`` maps ``_T -> _U`` before a value is handed to a backend, and
    ``decode`` maps ``_U -> _T`` after a payload is read back.

    NOTE: this class derives from ``Generic`` only, not ``abc.ABC``, so the
    ``@abstractmethod`` markers are not enforced; the base implementations
    simply return None.
    """

    @abstractmethod
    def encode(self, data: _T) -> _U:
        """Convert ``data`` into its storage representation."""

    @abstractmethod
    def decode(self, data: _U) -> _T:
        """Convert a stored payload back into a value."""
57
58
# Recursive type describing any JSON-representable value: null (None), numbers,
# strings, booleans, objects with string keys, and arrays. The quoted
# "JsonDataTy" entries are forward references to this alias itself.
JsonDataTy = Optional[
    Union[int, float, str, bool, Dict[str, "JsonDataTy"], List["JsonDataTy"]]
]
62
63
class RemoteCacheJsonSerde(RemoteCacheSerde[JsonDataTy, bytes]):
    """Serde that stores JSON-compatible values as ASCII-encoded JSON bytes."""

    def encode(self, data: JsonDataTy) -> bytes:
        """Serialize ``data`` to JSON and return it as ASCII bytes."""
        # json.dumps escapes non-ASCII characters by default (ensure_ascii),
        # so the resulting text always fits the ASCII codec.
        text = json.dumps(data)
        return text.encode("ascii")

    def decode(self, data: bytes) -> JsonDataTy:
        """Parse JSON bytes produced by ``encode`` back into a value."""
        value = json.loads(data)
        return value
70
71
class RemoteCachePassthroughSerde(RemoteCacheSerde[_T, _T]):
    """Identity serde: values are handed to and read from the backend as-is."""

    def encode(self, data: _T) -> _T:
        """Return ``data`` unchanged."""
        return data

    def decode(self, data: _T) -> _T:
        """Return ``data`` unchanged."""
        return data
78
79
class RemoteCache(Generic[_T]):
    """
    Typed front-end that pairs a ``RemoteCacheBackend`` with a serde.

    ``put`` encodes a ``_T`` value with ``serde.encode`` and stores the result
    through the backend; ``get`` fetches the stored payload and decodes it
    with ``serde.decode``. Each ``get``/``put`` also obtains a sample from
    ``_create_sample`` and passes it to ``_log_sample`` afterwards. In this
    base class both hooks are no-ops (no sample is created, nothing is
    logged); subclasses may override them, as well as ``_encode``/``_decode``,
    which also receive the sample.
    """

    # Support for testing: when not None, this zero-argument callable builds
    # the backend and the ``backend`` passed to ``__init__`` is ignored. It is
    # looked up through the instance's class, so it can be set on a specific
    # subclass.
    backend_override_cls: Optional[Callable[[], RemoteCacheBackend[Any]]] = None

    def __init__(
        self, backend: RemoteCacheBackend[_U], serde: RemoteCacheSerde[_T, _U]
    ) -> None:
        """
        Args:
            backend: Storage for encoded payloads; replaced by
                ``backend_override_cls()`` when that override is set.
            serde: Converts between ``_T`` values and the backend payload type.
        """
        if (override_cls := self.__class__.backend_override_cls) is not None:
            self.backend = override_cls()
        else:
            self.backend = backend
        self.serde = serde

    def get(self, key: str) -> Optional[_T]:
        """
        Return the decoded value stored under ``key``.

        Returns None when the backend reports a miss (its ``get`` returned
        None). A stored value that itself decodes to None is
        indistinguishable from a miss.
        """
        sample = self._create_sample()
        result = self._get(key, sample)
        self._log_sample(sample)
        return result

    def put(self, key: str, value: _T) -> None:
        """Encode ``value`` and store it in the backend under ``key``."""
        sample = self._create_sample()
        self._put(key, value, sample)
        self._log_sample(sample)

    def _decode(self, data: _U, sample: Optional[Sample]) -> _T:
        """Decode a backend payload; ``sample`` is unused in the base class."""
        return self.serde.decode(data)

    def _encode(self, value: _T, sample: Optional[Sample]) -> Any:  # returns _U
        """Encode a value for the backend; ``sample`` is unused here."""
        return self.serde.encode(value)

    def _get(self, key: str, sample: Optional[Sample]) -> Optional[_T]:
        data = self.backend.get(key)
        # The backend signals a miss with None. Test for None explicitly
        # rather than truthiness: a stored payload that happens to be falsy
        # (0, "", b"", [], {} -- e.g. with a passthrough serde) is still a
        # hit and must be decoded, not treated as missing.
        if data is None:
            return None
        return self._decode(data, sample)

    def _put(self, key: str, value: _T, sample: Optional[Sample]) -> None:
        data = self._encode(value, sample)
        self.backend.put(key, data)

    def _create_sample(self) -> Optional[Sample]:
        """Hook for subclasses; the base class records no sample."""
        return None

    def _log_sample(self, sample: Optional[Sample]) -> None:
        """Hook for subclasses; the base class logs nothing."""
        pass
124
125
class RedisRemoteCacheBackend(RemoteCacheBackend[bytes]):
    """
    Redis implementation of a remote/distributed cache backend (raw bytes).

    The server address comes from the ``TORCHINDUCTOR_REDIS_HOST`` and
    ``TORCHINDUCTOR_REDIS_PORT`` environment variables (defaults
    ``localhost`` and ``6379``), and every key is namespaced as
    ``pt2:<cache_id>:<key>``. The backend degrades to a no-op (``get``
    returns None, ``put`` does nothing) if the ``redis`` package could not be
    imported, or after the first ``redis.exceptions.ConnectionError``.
    """

    _key_fmt: str
    # None means "unavailable": redis is not importable, or a connection
    # error has already been seen.
    _redis: Optional[redis.Redis] = None

    def __init__(self, cache_id: str) -> None:
        # Without the redis package there is nothing to set up; ``_redis``
        # keeps its class-level default of None, which disables get/put.
        if not redis:
            return

        self._key_fmt = f"pt2:{cache_id}:{{key}}"
        host = os.environ.get("TORCHINDUCTOR_REDIS_HOST", "localhost")
        port = int(os.environ.get("TORCHINDUCTOR_REDIS_PORT", 6379))
        self._redis = redis.Redis(host=host, port=port)

    def __get_key(self, key: str) -> str:
        return self._key_fmt.format(key=key)

    @override
    def get(self, key: str) -> Optional[bytes]:
        client = self._redis
        if not client:
            # Redis is missing, or was disabled by an earlier failure.
            return None

        try:
            value = client.get(self.__get_key(key))
        except redis.exceptions.ConnectionError:
            # The client connects lazily, so an unreachable server only shows
            # up on first use. Mark the backend as unavailable from now on.
            self._redis = None
            return None

        # redis-py's typing allows get() to return an Awaitable as well; this
        # backend only expects bytes or None, which the assert checks.
        assert value is None or isinstance(value, bytes)
        return value

    @override
    def put(self, key: str, data: bytes) -> None:
        client = self._redis
        if not client:
            # Redis is missing, or was disabled by an earlier failure.
            return

        try:
            client.set(self.__get_key(key), data)
        except redis.exceptions.ConnectionError:
            # The client connects lazily, so an unreachable server only shows
            # up on first use. Mark the backend as unavailable from now on.
            self._redis = None
178
179
class RedisRemoteCache(RemoteCache[JsonDataTy]):
    """
    JSON-valued ``RemoteCache`` backed by a ``RedisRemoteCacheBackend``.

    ``key`` is used as the backend's cache id, so entries live under the
    ``pt2:<key>:`` key prefix.
    """

    def __init__(self, key: str) -> None:
        # Special test handling: when backend_override_cls is set, RemoteCache
        # replaces whatever backend it is given, so skip building a real Redis
        # backend and pass a placeholder. The cast only satisfies the type
        # checker; the value really is None.
        backend = (
            typing.cast(RemoteCacheBackend[bytes], None)
            if self.__class__.backend_override_cls
            else RedisRemoteCacheBackend(key)
        )
        super().__init__(backend, RemoteCacheJsonSerde())
191
192
class RemoteAutotuneCache(RedisRemoteCache):
    """
    ``RedisRemoteCache`` subclass with its own class identity.

    Judging by the name, it is intended for autotuning results. Because
    ``backend_override_cls`` is read through the instance's class, it can be
    set on this class without affecting the other remote caches.
    """
195
196
class RemoteFxGraphCache(RedisRemoteCache):
    """
    ``RedisRemoteCache`` subclass with its own class identity.

    Judging by the name, it is intended for FX graph cache entries. Because
    ``backend_override_cls`` is read through the instance's class, it can be
    set on this class without affecting the other remote caches.
    """
199