# mypy: allow-untyped-defs
import logging
from typing import Any, Dict, Optional, Protocol, Tuple, Union

import torch
from torch._library.utils import parse_namespace


log = logging.getLogger(__name__)


class FakeScriptObject:
    def __init__(self, wrapped_obj: Any, script_class_name: str, x: torch.ScriptObject):
        self.wrapped_obj = wrapped_obj

        # The fully qualified name of the class of the original script object.
        self.script_class_name = script_class_name
        self.real_obj = x


class FakeScriptMethod:
    def __init__(
        self,
        self_fake_obj: FakeScriptObject,
        method_name: str,
        schema: Optional[torch.FunctionSchema],
    ):
        self.self_fake_obj = self_fake_obj
        self.method_name = method_name
        self.schema = schema

    def __call__(self, *args, **kwargs):
        from torch._higher_order_ops.torchbind import call_torchbind

        return call_torchbind(self.self_fake_obj, self.method_name, *args, **kwargs)


class HasStaticMethodFromReal(Protocol):
    @classmethod
    def from_real(cls, real_obj: torch.ScriptObject):
        pass


class FakeClassRegistry:
    def __init__(self) -> None:
        self._registered_class: Dict[str, Any] = {}

    def has_impl(self, full_qualname: str) -> bool:
        return full_qualname in self._registered_class

    def get_impl(self, full_qualname: str) -> Any:
        self._check_registered(full_qualname)
        return self._registered_class[full_qualname]

    def register(self, full_qualname: str, fake_class=None) -> None:
        if self.has_impl(full_qualname):
            log.warning(
                "%s is already registered. Previous fake class is overridden with %s.",
                full_qualname,
                fake_class,
            )
        self._registered_class[full_qualname] = fake_class

    def deregister(self, full_qualname: str) -> Any:
        if not self.has_impl(full_qualname):
            log.warning(
                "Cannot deregister %s. Please use register_fake_class to register it first."
                " Or did you deregister it twice?",
                full_qualname,
            )
        else:
            return self._registered_class.pop(full_qualname)

    def clear(self) -> None:
        self._registered_class.clear()

    def _check_registered(self, full_qualname: str) -> None:
        if full_qualname not in self._registered_class:
            raise RuntimeError(
                f"{full_qualname} is not registered. Please use register_fake_class to register it first."
            )


global_fake_class_registry = FakeClassRegistry()
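

# A minimal sketch of the flattened format that _check_valid_flat_script_obj
# (below) validates: __obj_flatten__ is expected to return a tuple of
# (attribute_name, value) 2-tuples. For example, a hypothetical TensorQueue
# holding two tensors might flatten to:
#
#     flat_x = (("queue", [torch.ones(2), torch.zeros(2)]),)
#
# The attribute names and values here are illustrative only; the actual
# contents are whatever the class's __obj_flatten__/__obj_unflatten__ pair
# agrees on.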


# TODO: add this check at compile time for __obj_flatten__.
def _check_valid_flat_script_obj(flat_x):
    if not isinstance(flat_x, tuple):
        raise RuntimeError("Expect flat x to be a tuple.")

    for tp in flat_x:
        if not isinstance(tp, tuple):
            raise RuntimeError("Expect flat x to be a tuple of tuples.")

        if not len(tp) == 2 or not isinstance(tp[0], str):
            raise RuntimeError(
                "Expect element of flat x to be a tuple of two elements with the first element being a string."
            )


def tracing_with_real(x: torch.ScriptObject) -> bool:
    if not hasattr(x, "tracing_mode"):
        return False

    assert x.tracing_mode() in [
        "real",
        "fake",
    ], f"tracing_mode can be either real or fake but got {x.tracing_mode()}"
    return x.tracing_mode() == "real"


def maybe_to_fake_obj(
    fake_mode, x: torch.ScriptObject
) -> Union[FakeScriptObject, torch.ScriptObject]:
    import torch.utils._pytree as pytree
    from torch.utils._python_dispatch import _disable_current_modes

    # When tracing with real mode, users should implement meta kernels that can
    # handle the case of real script object + fake tensor inputs.
    if tracing_with_real(x):
        return x

    # x.__obj_flatten__() may call tensor operations internally, but we don't
    # want those ops to run under the surrounding dispatch modes. Otherwise,
    # for example, fake tensor mode would error out when tensors inside the
    # script object execute operations like clone if the allow_non_fake_input
    # flag is set.
    with _disable_current_modes():
        flat_x = x.__obj_flatten__()  # type: ignore[attr-defined]

    _check_valid_flat_script_obj(flat_x)

    fake_flattened = pytree.tree_map_only(
        torch.Tensor,
        lambda t: fake_mode.from_tensor(t),
        flat_x,
    )

    fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened)

    fake_x_wrapped = FakeScriptObject(fake_x, x._type().qualified_name(), x)  # type: ignore[attr-defined]

    for name in x._method_names():  # type: ignore[attr-defined]
        attr = getattr(fake_x, name, None)
        if attr:
            if not callable(attr):
                raise RuntimeError(f"Expect {name} to be a callable but got {attr}.")

            real_attr = getattr(x, name)  # type: ignore[attr-defined]

            # The real attribute is sometimes not a torch.ScriptMethod and thus
            # has no schema, e.g. __init__ or __eq__.
            method_schema: Optional[torch.FunctionSchema] = None
            if isinstance(real_attr, torch.ScriptMethod):
                method_schema = real_attr.schema  # type: ignore[attr-defined]

            setattr(
                fake_x_wrapped,
                name,
                FakeScriptMethod(fake_x_wrapped, name, method_schema),
            )
        else:
            override_skip_list = {"__obj_flatten__", "__get_state__", "__set_state__"}
            if name not in override_skip_list:
                log.warning("fake object of %s doesn't implement method %s.", x, name)
    return fake_x_wrapped
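

# A hedged usage sketch of maybe_to_fake_obj (illustrative only; it assumes the
# _TorchScriptTesting test library shown in the register_fake_class docstring
# below has been loaded and its fake class registered):
#
#     from torch._subclasses.fake_tensor import FakeTensorMode
#
#     tq = torch.classes._TorchScriptTesting._TensorQueue(torch.empty(0).fill_(-1))
#     tq.push(torch.ones(2))
#     fake_mode = FakeTensorMode()
#     fake_tq = maybe_to_fake_obj(fake_mode, tq)
#     # fake_tq is a FakeScriptObject; its methods (push, pop, ...) dispatch
#     # through call_torchbind and operate on fake tensors.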


def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal] = None):
    r"""Register a fake implementation for this class.

    It's in the same spirit as registering a fake implementation for
    an operator, but with the difference that it
    associates a fake class with the original torchbind class (registered
    with torch::class_). In this way, torch.compile can handle them properly
    in components such as Dynamo and AOTAutograd.

    This API may be used as a decorator (see example). For the fake class, users
    are required to provide an __obj_unflatten__ classmethod that takes the flattened
    representation of the real object (as returned by its __obj_flatten__ method) and
    returns an instance of the fake class. If the fake class also defines a from_real
    classmethod that takes a real object and returns a fake one, all tensors in the
    fake object should be properly fakified with to_fake_tensor() inside from_real.

    Examples:
        # For a custom class TensorQueue defined in test_custom_class_registration.cpp:

        TORCH_LIBRARY(_TorchScriptTesting, m) {
          m.class_<TensorQueue>("_TensorQueue")
            .def(torch::init<at::Tensor>())
            .def("push", &TensorQueue::push)
            .def("pop", &TensorQueue::pop)
            .def("top", &TensorQueue::top)
            .def("size", &TensorQueue::size)
            .def("clone_queue", &TensorQueue::clone_queue)
            .def("__obj_flatten__", &TensorQueue::__obj_flatten__)
            .def_pickle(
                // __getstate__
                [](const c10::intrusive_ptr<TensorQueue>& self)
                    -> c10::Dict<std::string, at::Tensor> {
                  return self->serialize();
                },
                // __setstate__
                [](c10::Dict<std::string, at::Tensor> data)
                    -> c10::intrusive_ptr<TensorQueue> {
                  return c10::make_intrusive<TensorQueue>(std::move(data));
                });
        };

        # We could register a fake class FakeTensorQueue in Python as follows:
        import torch

        @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
        class FakeTensorQueue:
            def __init__(self, queue):
                self.queue = queue

            @classmethod
            def __obj_unflatten__(cls, flattened_ctx):
                return cls(**dict(flattened_ctx))

            def push(self, x):
                self.queue.append(x)

            def pop(self):
                return self.queue.pop(0)

            def size(self):
                return len(self.queue)

    In this example, the original TensorQueue needs to add an __obj_flatten__ method,
    whose flattened result is passed into FakeTensorQueue's __obj_unflatten__ as input
    to create a fake instance. This protocol allows PyTorch to look at the contents of
    the script object and handle them properly in subsystems such as Dynamo and AOTAutograd.
    """

    def inner(fake_class: HasStaticMethodFromReal):
        ns, name = parse_namespace(qualname)

        # This also checks whether the referred torch::class_ exists.
        torchbind_class = torch._C._get_custom_class_python_wrapper(ns, name)

        from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
        if not from_method:
            raise RuntimeError(
                f"{fake_class} doesn't define a classmethod {_CONVERT_FROM_REAL_NAME}."
            )

        if not isinstance(fake_class.__dict__[_CONVERT_FROM_REAL_NAME], classmethod):
            raise RuntimeError(
                f"{_CONVERT_FROM_REAL_NAME} method is not a classmethod."
            )

        global_fake_class_registry.register(_full_qual_class_name(qualname), fake_class)
        return fake_class

    if fake_class is None:
        return inner
    return inner(fake_class)


def deregister_fake_class(qualname):
    return global_fake_class_registry.deregister(_full_qual_class_name(qualname))


def has_fake_class(full_qualname) -> bool:
    return global_fake_class_registry.has_impl(full_qualname)


def find_fake_class(full_qualname) -> Optional[Any]:
    if not has_fake_class(full_qualname):
        return None
    return global_fake_class_registry.get_impl(full_qualname)


def _full_qual_class_name(qualname: str) -> str:
    ns, name = parse_namespace(qualname)
    return "__torch__.torch.classes." + ns + "." + name
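

# For illustration, the qualified-name mapping performed above:
#
#     _full_qual_class_name("_TorchScriptTesting::_TensorQueue")
#     # -> "__torch__.torch.classes._TorchScriptTesting._TensorQueue"
#
# and _ns_and_class_name (below) recovers ("_TorchScriptTesting", "_TensorQueue")
# from that fully qualified name.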


# Return the namespace and class name from a fully qualified name.
def _ns_and_class_name(full_qualname: str) -> Tuple[str, str]:
    splits = full_qualname.split(".")
    assert len(splits) == 5
    _torch, torch_ns, classes, ns, class_name = splits
    return ns, class_name


def _find_fake_class_for_script_object(x: torch.ScriptObject) -> Any:
    full_qualname = x._type().qualified_name()  # type: ignore[attr-defined]
    ns, class_name = _ns_and_class_name(full_qualname)
    fake_class = find_fake_class(full_qualname)
    if fake_class is None:
        raise RuntimeError(
            f"ScriptObject's class {full_qualname} has not registered a fake class."
            f" Please use register_fake_class('{ns}::{class_name}') to annotate a fake class for the script object."
            f" Specifically, create a python class that implements a fake version of all the methods"
            f" that are used in the program and register it in the program, e.g. after loading the library."
            f" The fake methods can be written in the same way as a meta kernel for an operator but need to additionally"
            f" simulate the object's states. Be sure to add a {_CONVERT_FROM_REAL_NAME} classmethod"
            f" to enable creating a fake obj from a real one."
        )
    return fake_class


_CONVERT_FROM_REAL_NAME = "__obj_unflatten__"


def _fake_obj_from_real(fake_mode, x) -> Any:
    fake_class = _find_fake_class_for_script_object(x)

    from_real_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None)
    if not from_real_method:
        raise RuntimeError(
            f"{fake_class} must define a classmethod {_CONVERT_FROM_REAL_NAME}"
            f" that converts the real object to the fake object."
        )

    # The user-defined from_real needs the ctx to fakify the tensor states.
    ctx = torch._library.fake_impl.FakeImplCtx(fake_mode, None)
    with torch._library.fake_impl.set_ctx_getter(lambda: ctx):
        return fake_class.from_real(x)
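

# A hedged sketch of the from_real path exercised by _fake_obj_from_real above.
# FakeTensorQueue, get_raw_queue(), and the attribute names are hypothetical;
# ctx.to_fake_tensor() is the helper referred to in the register_fake_class
# docstring for fakifying tensor state:
#
#     class FakeTensorQueue:
#         def __init__(self, queue):
#             self.queue = queue
#
#         @classmethod
#         def from_real(cls, real_tq):
#             ctx = torch.library.get_ctx()
#             return cls([ctx.to_fake_tensor(t) for t in real_tq.get_raw_queue()])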