# mypy: allow-untyped-defs
"""Functionality for Python <-> C++ frontend inter-op."""

from torch import nn


class OrderedDictWrapper:
    """A wrapper around a C++ OrderedDict.

    It dynamically evaluates the OrderedDict getter on a bound C++ module, such
    that new changes on the C++ side are picked up. Otherwise accessing e.g.
    ``cpp_module._parameters`` just once would get a frozen copy of the
    parameters at the time of access. ``torch.nn.Module`` accesses
    ``_parameters`` et al. via ``self.__dict__``, so using properties does not
    work.

    Args:
        cpp_module: the bound C++ module to read from.
        attr: name of the attribute on ``cpp_module`` that yields the dict
            (e.g. ``"_parameters"``). It is re-read on every access.
    """

    def __init__(self, cpp_module, attr):
        self.cpp_module = cpp_module
        self.attr = attr

    @property
    def cpp_dict(self):
        """Return a fresh ``getattr(self.cpp_module, self.attr)`` result."""
        return getattr(self.cpp_module, self.attr)

    # Magic methods cannot be assigned dynamically and bypass ``getattr``, so we
    # must manually override them. Every method below re-fetches ``cpp_dict``
    # so it always reflects the current state of the C++ module.

    def items(self):
        """Delegate to ``cpp_dict.items()``."""
        return self.cpp_dict.items()

    def keys(self):
        """Delegate to ``cpp_dict.keys()``."""
        return self.cpp_dict.keys()

    def values(self):
        """Delegate to ``cpp_dict.values()``."""
        return self.cpp_dict.values()

    def __iter__(self):
        return self.cpp_dict.__iter__()

    def __len__(self):
        return self.cpp_dict.__len__()

    def __contains__(self, key):
        return self.cpp_dict.__contains__(key)

    def __getitem__(self, key):
        return self.cpp_dict.__getitem__(key)


class ModuleWrapper(nn.Module):
    """A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access.

    ``_parameters``, ``_buffers`` and ``_modules`` are replaced with live
    :class:`OrderedDictWrapper` views onto the C++ module. Every public
    (non-underscore) attribute of the C++ module is copied onto this instance
    with ``setattr``.

    Args:
        cpp_module: the bound C++ frontend module to wrap.
    """

    def __init__(self, cpp_module):
        # Assign before the super class constructor so ``self.training`` can be
        # assigned to in the super class constructor: the ``training`` property
        # below forwards to ``self.cpp_module``.
        self.cpp_module = cpp_module
        super().__init__()
        self._parameters = OrderedDictWrapper(cpp_module, "_parameters")  # type: ignore[assignment]
        self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers")  # type: ignore[assignment]
        self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules")  # type: ignore[assignment]
        for attr in dir(cpp_module):
            # Skip magic methods and the three attributes above.
            if not attr.startswith("_"):
                setattr(self, attr, getattr(self.cpp_module, attr))

    def _apply(self, fn, recurse=True):
        """Apply ``fn`` in place to every parameter, its gradient, and every buffer.

        Overrides ``nn.Module._apply``.

        Args:
            fn: callable mapping a tensor to a tensor. It is applied to
                ``.data`` and the result is written back to ``.data``.
            recurse: if ``True`` (default), tensors reached through
                ``self.parameters(True)`` / ``self.buffers(True)`` are
                converted, which includes submodules. If ``False``, only those
                returned with ``recurse=False`` are converted, i.e. this
                module's own tensors.

        Returns:
            ``self``.
        """
        # ``recurse`` is passed positionally because ``parameters`` / ``buffers``
        # may be the C++ module's bound methods copied onto ``self`` in
        # ``__init__`` rather than ``nn.Module``'s.
        for param in self.parameters(recurse):
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)

        for buf in self.buffers(recurse):
            buf.data = fn(buf.data)

        return self

    # nn.Module defines training as a boolean; here it is a property that
    # reads from and writes through to the wrapped C++ module.
    @property  # type: ignore[override]
    def training(self):
        """Return ``cpp_module.training``."""
        return self.cpp_module.training

    @training.setter
    def training(self, mode):
        # Route assignment through ``cpp_module.train(mode)``.
        self.cpp_module.train(mode)

    def __repr__(self):
        # Delegate to the wrapped C++ module's own representation.
        return self.cpp_module.__repr__()