xref: /aosp_15_r20/external/pytorch/torch/_library/fake_class_registry.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import logging
3from typing import Any, Dict, Optional, Protocol, Tuple, Union
4
5import torch
6from torch._library.utils import parse_namespace
7
8
# Module logger. Used for the registry's override/deregister warnings and for
# the "doesn't implement method" warnings emitted by maybe_to_fake_obj.
log = logging.getLogger(__name__)
10
11
class FakeScriptObject:
    """Fake stand-in for a torchbind ``torch.ScriptObject`` used while tracing.

    Attributes:
        wrapped_obj: the fake object built from the real one (in this module,
            the result of the registered fake class's ``__obj_unflatten__``).
        script_class_name: fully qualified name of the original script
            object's class.
        real_obj: the original (real) script object.
    """

    def __init__(self, wrapped_obj: Any, script_class_name: str, x: torch.ScriptObject):
        # Keep the real object next to its fake counterpart so callers can get
        # back to it.
        self.real_obj = x
        self.script_class_name = script_class_name
        self.wrapped_obj = wrapped_obj
19
20
class FakeScriptMethod:
    """Callable stand-in for one method of a :class:`FakeScriptObject`.

    Calling an instance forwards the owning fake object, the method name and
    all call arguments to ``call_torchbind``.

    Attributes:
        self_fake_obj: the fake object this method is bound to.
        method_name: name of the method on the original script object.
        schema: the real method's ``FunctionSchema``, or ``None`` when it is
            not available.
    """

    def __init__(
        self,
        self_fake_obj: FakeScriptObject,
        method_name: str,
        schema: Optional[torch.FunctionSchema],
    ):
        self.schema = schema
        self.method_name = method_name
        self.self_fake_obj = self_fake_obj

    def __call__(self, *args, **kwargs):
        # Deferred import; presumably to avoid an import cycle at module load
        # time — TODO confirm.
        from torch._higher_order_ops.torchbind import call_torchbind as _dispatch

        owner, method = self.self_fake_obj, self.method_name
        return _dispatch(owner, method, *args, **kwargs)
36
37
class HasStaticMethodFromReal(Protocol):
    """Structural type for a fake class exposing a ``from_real`` classmethod.

    NOTE(review): despite the name, the member is a classmethod, and
    ``register_fake_class`` actually validates ``__obj_unflatten__``, not
    ``from_real``.
    """

    @classmethod
    def from_real(cls, real_obj: torch.ScriptObject):
        ...
42
43
class FakeClassRegistry:
    """Dictionary-backed registry from a class name to its fake class.

    Within this module the keys are fully qualified torchbind class names of
    the form ``"__torch__.torch.classes.<ns>.<ClassName>"``.
    """

    def __init__(self) -> None:
        self._registered_class: Dict[str, Any] = {}

    def has_impl(self, full_qualname: str) -> bool:
        """Return True if a fake class is registered under ``full_qualname``."""
        return full_qualname in self._registered_class

    def get_impl(self, full_qualname: str) -> Any:
        """Return the fake class registered under ``full_qualname``.

        Raises:
            RuntimeError: if nothing is registered under ``full_qualname``.
        """
        self._check_registered(full_qualname)
        return self._registered_class[full_qualname]

    def register(self, full_qualname: str, fake_class=None) -> None:
        """Register ``fake_class`` under ``full_qualname``.

        An existing registration is replaced; a warning is logged when that
        happens.
        """
        if self.has_impl(full_qualname):
            log.warning(
                "%s is already registered. Previous fake class is overridden with %s.",
                full_qualname,
                fake_class,
            )
        self._registered_class[full_qualname] = fake_class

    def deregister(self, full_qualname: str) -> Any:
        """Remove and return the fake class registered under ``full_qualname``.

        Returns:
            The removed fake class, or ``None`` (after logging a warning) when
            nothing was registered under that name.
        """
        if not self.has_impl(full_qualname):
            log.warning(
                "Cannot deregister %s. Please use register_fake_class to register it first."
                " Or do you deregister it twice?",
                full_qualname,
            )
            return None
        return self._registered_class.pop(full_qualname)

    def clear(self) -> None:
        """Remove every registration."""
        self._registered_class.clear()

    def _check_registered(self, full_qualname: str) -> None:
        """Raise RuntimeError unless ``full_qualname`` is registered."""
        if full_qualname not in self._registered_class:
            raise RuntimeError(
                f"{full_qualname} is not registered. Please use register_fake_class to register it first."
            )
82
83
# Module-level registry instance shared by register_fake_class,
# deregister_fake_class, has_fake_class and find_fake_class.
global_fake_class_registry = FakeClassRegistry()
85
86
87# TODO: add this check at compile time for __obj_flatten__.
88def _check_valid_flat_script_obj(flat_x):
89    if not isinstance(flat_x, tuple):
90        raise RuntimeError("Expect flat x to be a tuple.")
91
92    for tp in flat_x:
93        if not isinstance(tp, tuple):
94            raise RuntimeError("Expect flat x to be a tuple of tuples.")
95
96        if not len(tp) == 2 or not isinstance(tp[0], str):
97            raise RuntimeError(
98                "Expect element of flat x to be a tuple of two elements with first element being a string"
99            )
100
101
def tracing_with_real(x: torch.ScriptObject) -> bool:
    """Return whether ``x`` asks to be traced as a real (non-fake) object.

    Objects without a ``tracing_mode`` method return ``False``.

    Raises:
        AssertionError: if ``x.tracing_mode()`` returns something other than
            ``"real"`` or ``"fake"``.
    """
    if not hasattr(x, "tracing_mode"):
        return False

    # Query the mode once instead of re-invoking the (possibly C++-backed)
    # method for the check, the message, and the result.
    mode = x.tracing_mode()
    # An explicit raise rather than ``assert``, so the check still runs under
    # ``python -O``. AssertionError is kept so existing callers see the same
    # exception type.
    if mode not in ("real", "fake"):
        raise AssertionError(
            f"tracing_mode can be either real or fake but got {mode}"
        )
    return mode == "real"
111
112
def maybe_to_fake_obj(
    fake_mode, x: torch.ScriptObject
) -> Union[FakeScriptObject, torch.ScriptObject]:
    """Return a fake counterpart of torchbind object ``x`` for tracing.

    If ``x`` reports ``tracing_mode() == "real"`` it is returned unchanged.
    Otherwise:

    1. ``x.__obj_flatten__()`` is called with all current dispatch modes
       disabled.
    2. Every tensor in the result is converted with ``fake_mode.from_tensor``.
    3. The registered fake class's ``__obj_unflatten__`` builds a fake object
       from that result.

    The fake object is wrapped in a :class:`FakeScriptObject`. Each method name
    of ``x`` that the fake object also provides is attached to the wrapper as
    a :class:`FakeScriptMethod`.

    Args:
        fake_mode: object providing ``from_tensor`` (presumably a
            FakeTensorMode — TODO confirm).
        x: the real script object.

    Returns:
        ``x`` itself when tracing with real objects, otherwise a
        ``FakeScriptObject``.

    Raises:
        RuntimeError: if the flattened contents are malformed, no fake class is
            registered for ``x``, or the fake object has a non-callable
            attribute under one of ``x``'s method names.
    """
    import torch.utils._pytree as pytree
    from torch.utils._python_dispatch import _disable_current_modes

    # When tracing with real mode, people should implement meta kernels that can
    # handle the case of real script object + fake tensor inputs.
    if tracing_with_real(x):
        return x

    # x.__obj_flatten__() could run tensor operations internally, and those
    # must not go through the surrounding dispatch modes. Otherwise a fake
    # tensor mode may error out (depending on its allow_non_fake_inputs
    # setting) when the script object's tensors run ops such as clone.
    with _disable_current_modes():
        flat_x = x.__obj_flatten__()  # type: ignore[attr-defined]

    _check_valid_flat_script_obj(flat_x)

    # Fakify every tensor leaf; non-tensor leaves are passed through unchanged.
    fake_flattened = pytree.tree_map_only(torch.Tensor, fake_mode.from_tensor, flat_x)

    fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened)

    fake_x_wrapped = FakeScriptObject(fake_x, x._type().qualified_name(), x)  # type: ignore[attr-defined]

    # Method names a fake class may omit without a warning. ``def_pickle``
    # registers the pickle hooks as "__getstate__" / "__setstate__". The
    # underscore-separated spellings are kept so nothing that was skipped
    # before starts warning now.
    override_skip_list = {
        "__obj_flatten__",
        "__getstate__",
        "__setstate__",
        "__get_state__",
        "__set_state__",
    }

    for name in x._method_names():  # type: ignore[attr-defined]
        attr = getattr(fake_x, name, None)
        if not attr:
            if name not in override_skip_list:
                log.warning("fake object of %s doesn't implement method %s.", x, name)
            continue

        if not callable(attr):
            raise RuntimeError(f"Expect {name} to be a callable but got {attr}.")

        real_attr = getattr(x, name)  # type: ignore[attr-defined]

        # The real attr is sometimes not a torch.ScriptMethod and therefore has
        # no schema, e.g. __init__ or __eq__.
        method_schema: Optional[torch.FunctionSchema] = None
        if isinstance(real_attr, torch.ScriptMethod):
            method_schema = real_attr.schema  # type: ignore[attr-defined]

        setattr(
            fake_x_wrapped,
            name,
            FakeScriptMethod(fake_x_wrapped, name, method_schema),
        )
    return fake_x_wrapped
166
167
def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal] = None):
    r"""Register a fake implementation for this class.

    It's in the same spirit as registering a fake implementation for an
    operator. The difference is that it associates a fake class with the
    original torchbind class (registered with torch::class_). This lets
    torch.compile handle such objects properly in components such as Dynamo
    and AOTAutograd.

    This API may be used as a decorator (see example). The fake class must
    provide an ``__obj_unflatten__`` classmethod, defined on the class itself
    or inherited from a base class. It receives the flattened contents of a
    real object (the output of the real object's ``__obj_flatten__``) and
    returns an instance of the fake class. When converting a real object,
    ``maybe_to_fake_obj`` fakifies the tensors in the flattened contents
    before passing them to ``__obj_unflatten__``.

    Args:
        qualname (str): name of the torchbind class in
            ``"<namespace>::<ClassName>"`` form.
        fake_class: the fake class to register. If ``None``, a decorator that
            performs the registration is returned instead.

    Returns:
        The registered fake class, or (when ``fake_class`` is ``None``) a
        decorator that registers and returns its argument.

    Raises:
        RuntimeError: if the fake class does not define ``__obj_unflatten__``,
            or defines it as something other than a classmethod.

    Examples:
        # For a custom class Foo defined in test_custom_class_registration.cpp:

        TORCH_LIBRARY(_TorchScriptTesting, m) {
          m.class_<TensorQueue>("_TensorQueue")
            .def(torch::init<at::Tensor>())
            .def("push", &TensorQueue::push)
            .def("pop", &TensorQueue::pop)
            .def("top", &TensorQueue::top)
            .def("size", &TensorQueue::size)
            .def("clone_queue", &TensorQueue::clone_queue)
            .def("__obj_flatten__", &TensorQueue::__obj_flatten__)
            .def_pickle(
                // __getstate__
                [](const c10::intrusive_ptr<TensorQueue>& self)
                    -> c10::Dict<std::string, at::Tensor> {
                  return self->serialize();
                },
                // __setstate__
                [](c10::Dict<std::string, at::Tensor> data)
                    -> c10::intrusive_ptr<TensorQueue> {
                  return c10::make_intrusive<TensorQueue>(std::move(data));
                });
            };
        # We could register a fake class FakeTensorQueue in Python as follows:
        import torch

        @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
        class FakeTensorQueue:
            def __init__(self, queue):
                self.queue = queue

            @classmethod
            def __obj_unflatten__(cls, flattened_ctx):
                return cls(**dict(flattened_ctx))

            def push(self, x):
                self.queue.append(x)

            def pop(self):
                return self.queue.pop(0)

            def size(self):
                return len(self.queue)

    In this example, the original TensorQueue needs to add an __obj_flatten__
    method. Its flattened result is passed into FakeTensorQueue's
    __obj_unflatten__ to create the fake object. This protocol allows PyTorch
    to look at the contents of the script object and handle them properly in
    subsystems like Dynamo, AOTAutograd and others.
    """
    import inspect

    def inner(fake_class: HasStaticMethodFromReal):
        ns, name = parse_namespace(qualname)

        # Called for its side effect: it checks that the referred
        # torch::class_ exists.
        torch._C._get_custom_class_python_wrapper(ns, name)

        from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
        if not from_method:
            raise RuntimeError(
                f"{fake_class} doesn't define a classmethod {_CONVERT_FROM_REAL_NAME}."
            )

        # Fetch the raw attribute along the MRO without invoking the descriptor
        # protocol, so an inherited classmethod is recognized. Indexing
        # fake_class.__dict__ raised KeyError when it was defined on a base.
        raw_from_method = inspect.getattr_static(fake_class, _CONVERT_FROM_REAL_NAME)
        if not isinstance(raw_from_method, classmethod):
            raise RuntimeError(
                f"{_CONVERT_FROM_REAL_NAME} method is not a classmethod."
            )

        global_fake_class_registry.register(_full_qual_class_name(qualname), fake_class)
        return fake_class

    if fake_class is None:
        return inner
    return inner(fake_class)
258
259
def deregister_fake_class(qualname):
    """Unregister the fake class for ``qualname`` (``"<ns>::<ClassName>"``).

    Returns:
        The removed fake class, or ``None`` (after a logged warning) when none
        was registered.
    """
    key = _full_qual_class_name(qualname)
    removed = global_fake_class_registry.deregister(key)
    return removed
262
263
def has_fake_class(full_qualname) -> bool:
    """Return True if a fake class is registered under ``full_qualname``.

    ``full_qualname`` is the fully qualified registry key (for example
    ``"__torch__.torch.classes.<ns>.<ClassName>"``), not the
    ``"<ns>::<ClassName>"`` form accepted by ``register_fake_class``.
    """
    registry = global_fake_class_registry
    return registry.has_impl(full_qualname)
266
267
def find_fake_class(full_qualname) -> Optional[Any]:
    """Return the fake class registered under ``full_qualname``, or None."""
    return (
        global_fake_class_registry.get_impl(full_qualname)
        if has_fake_class(full_qualname)
        else None
    )
272
273
274def _full_qual_class_name(qualname: str) -> str:
275    ns, name = parse_namespace(qualname)
276    return "__torch__.torch.classes." + ns + "." + name
277
278
279# Return the namespace and class name from fully qualified name.
280def _ns_and_class_name(full_qualname: str) -> Tuple[str, str]:
281    splits = full_qualname.split(".")
282    assert len(splits) == 5
283    _torch, torch_ns, classes, ns, class_name = splits
284    return ns, class_name
285
286
def _find_fake_class_for_script_object(x: torch.ScriptObject) -> Any:
    """Return the fake class registered for script object ``x``.

    Raises:
        RuntimeError: with instructions for ``register_fake_class`` when no
            fake class is registered for ``x``'s class.
    """
    full_qualname = x._type().qualified_name()  # type: ignore[attr-defined]
    # Split the name before the lookup, so malformed names fail the same way
    # whether or not a fake class is registered.
    ns, class_name = _ns_and_class_name(full_qualname)

    fake_class = find_fake_class(full_qualname)
    if fake_class is not None:
        return fake_class

    message = (
        f" ScriptObject's {full_qualname} haven't registered a fake class."
        f" Please use register_fake_class({ns}::{class_name}) to annotate a fake class for the script obj."
        " Specifically, create a python class that implements a fake version for all the methods"
        " that're used in the program and put annotated class in the program e.g. after loading the library."
        " The fake methods can be written in the same way as a meta kernel for an operator but need to additionally"
        f" simulate the object's states. Be sure to add a {_CONVERT_FROM_REAL_NAME} classmethod"
        " to enable creating a fake obj from a real one."
    )
    raise RuntimeError(message)
302
303
# Name of the classmethod a fake class must define. register_fake_class checks
# that it exists and is a classmethod. maybe_to_fake_obj calls a method of this
# name with the fakified output of the real object's __obj_flatten__().
# Functions above reference this constant; that is fine because they only read
# it at call time, after the module has finished loading.
_CONVERT_FROM_REAL_NAME = "__obj_unflatten__"
305
306
def _fake_obj_from_real(fake_mode, x) -> Any:
    """Build a fake object for script object ``x`` via the fake class's ``from_real``.

    ``from_real`` runs with a ``FakeImplCtx`` bound to ``fake_mode`` installed
    as the current fake-impl context, so it can fakify the object's tensor
    state.

    Args:
        fake_mode: the fake mode handed to ``FakeImplCtx``.
        x: the real script object.

    Returns:
        Whatever the registered fake class's ``from_real`` returns.

    Raises:
        RuntimeError: if no fake class is registered for ``x``, or the fake
            class does not define ``from_real``.
    """
    fake_class = _find_fake_class_for_script_object(x)

    # Validate the method that is actually invoked below. __obj_unflatten__
    # (_CONVERT_FROM_REAL_NAME) takes flattened contents, not a real object, so
    # checking for it here let classes without ``from_real`` pass the check and
    # then fail with AttributeError.
    from_real_name = "from_real"
    from_real_method = getattr(fake_class, from_real_name, None)
    if not from_real_method:
        raise RuntimeError(
            f"{fake_class} must define a classmethod {from_real_name}"
            f" that converts the real object to the fake object."
        )

    # from_real defined by the user needs the ctx to fakify the tensor states.
    ctx = torch._library.fake_impl.FakeImplCtx(fake_mode, None)
    with torch._library.fake_impl.set_ctx_getter(lambda: ctx):
        return from_real_method(x)
321