# xref: /aosp_15_r20/external/pytorch/torch/nn/cpp.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
"""Functionality for Python <-> C++ frontend inter-op."""

from torch import nn

class OrderedDictWrapper:
    """A live, read-only view over a C++ ``OrderedDict`` held by a bound C++ module.

    The underlying dict is re-fetched from the C++ module on every access, so
    mutations made on the C++ side are always visible through this wrapper.
    Reading ``cpp_module._parameters`` just once would instead freeze a copy of
    the contents at the moment of access. ``torch.nn.Module`` looks up
    ``_parameters`` et al. through ``self.__dict__``, which rules out plain
    properties on the Python module class.
    """

    def __init__(self, cpp_module, attr):
        self.cpp_module = cpp_module
        self.attr = attr

    @property
    def cpp_dict(self):
        # Re-resolve on every access so the view never goes stale.
        return getattr(self.cpp_module, self.attr)

    # Dunder lookups happen on the type and bypass ``getattr``, so each
    # protocol method is forwarded explicitly to the freshly fetched dict.

    def items(self):
        return self.cpp_dict.items()

    def keys(self):
        return self.cpp_dict.keys()

    def values(self):
        return self.cpp_dict.values()

    def __iter__(self):
        return iter(self.cpp_dict)

    def __len__(self):
        return len(self.cpp_dict)

    def __contains__(self, key):
        return key in self.cpp_dict

    def __getitem__(self, key):
        return self.cpp_dict[key]


class ModuleWrapper(nn.Module):
    """A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access."""

    def __init__(self, cpp_module):
        # Assign before the super class constructor so ``self.training`` can be
        # assigned to in the super class constructor.
        # (``nn.Module.__init__`` sets ``self.training``; our ``training``
        # property setter below forwards that to ``self.cpp_module``, which
        # must therefore already exist at that point.)
        self.cpp_module = cpp_module
        super().__init__()
        # Replace the plain dicts installed by ``nn.Module.__init__`` with
        # live views onto the C++ module's containers, so C++-side changes
        # are always visible from Python.
        self._parameters = OrderedDictWrapper(cpp_module, "_parameters")  # type: ignore[assignment]
        self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers")  # type: ignore[assignment]
        self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules")  # type: ignore[assignment]
        # Mirror the C++ module's public attributes (methods like ``forward``
        # included) onto this Python wrapper.
        for attr in dir(cpp_module):
            # Skip magic methods and the three attributes above.
            if not attr.startswith("_"):
                setattr(self, attr, getattr(self.cpp_module, attr))

    def _apply(self, fn, recurse=True):
        """Apply ``fn`` in place to the data of every parameter, gradient, and buffer.

        NOTE(review): ``recurse`` is accepted for signature compatibility but
        is not consulted in this body; ``self.parameters()`` / ``self.buffers()``
        presumably already yield nested entries — confirm against
        ``nn.Module`` semantics.
        """
        for param in self.parameters():
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)

        for buf in self.buffers():
            buf.data = fn(buf.data)

        return self

    # nn.Module defines training as a boolean
    @property  # type: ignore[override]
    def training(self):
        # Read the flag straight from the C++ module so the two sides can
        # never disagree.
        return self.cpp_module.training

    @training.setter
    def training(self, mode):
        # Delegate to the C++ ``train()`` method rather than storing a
        # separate Python-side flag.
        self.cpp_module.train(mode)

    def __repr__(self):
        # Delegate formatting to the C++ module's own repr.
        return self.cpp_module.__repr__()
90