# mypy: ignore-errors

import torch
from copy import deepcopy
from torch.utils._pytree import tree_map
import torch.utils._pytree as pytree


# TODO: Move LoggingTensor here.
from torch.testing._internal.logging_tensor import LoggingTensor


# Base class for wrapper-style tensors.
class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, *args, **kwargs):
        t, kwargs = cls.get_wrapper_properties(*args, **kwargs)
        if "size" not in kwargs:
            size = t.size()
        else:
            size = kwargs["size"]
            del kwargs["size"]
        if "dtype" not in kwargs:
            kwargs["dtype"] = t.dtype
        if "layout" not in kwargs:
            kwargs["layout"] = t.layout
        if "device" not in kwargs:
            kwargs["device"] = t.device
        if "requires_grad" not in kwargs:
            kwargs["requires_grad"] = False
        # Ignore memory_format and pin memory for now as I don't know how to
        # safely access them on a Tensor (if possible??)

        wrapper = torch.Tensor._make_wrapper_subclass(cls, size, **kwargs)
        wrapper._validate_methods()
        return wrapper

    @classmethod
    def get_wrapper_properties(cls, *args, **kwargs):
        # Should return both an example Tensor and a dictionary of kwargs
        # to override any of that example Tensor's properties.
        # This is very similar to the `t.new_*(args)` API
        raise NotImplementedError("You need to implement get_wrapper_properties")

    def _validate_methods(self):
        # Skip this if not in debug mode?
        # Changing these on the python side is wrong as it would not be properly reflected
        # on the c++ side
        # This doesn't catch attributes set in the __init__
        forbidden_overrides = ["size", "stride", "dtype", "layout", "device", "requires_grad"]
        for el in forbidden_overrides:
            if getattr(self.__class__, el) is not getattr(torch.Tensor, el):
                raise RuntimeError(f"Subclass {self.__class__.__name__} is overwriting the "
                                   f"property {el} but this is not allowed as such change would "
                                   "not be reflected to c++ callers.")


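# A minimal sketch (hypothetical `_example_*` helper, for illustration only): a
# WrapperTensor subclass that overrides one of the forbidden metadata properties
# trips _validate_methods at construction time.
def _example_forbidden_override():
    class BadWrapper(WrapperTensor):
        @classmethod
        def get_wrapper_properties(cls, t):
            return t, {}

        @property
        def device(self):  # Python-side override invisible to C++ callers
            return torch.device("meta")

    try:
        BadWrapper(torch.randn(3))
    except RuntimeError:
        return True  # _validate_methods rejected the override
    return False

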
class DiagTensorBelow(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, diag, requires_grad=False):
        assert diag.ndim == 1
        return diag, {"size": diag.size() + diag.size(), "requires_grad": requires_grad}

    def __init__(self, diag, requires_grad=False):
        self.diag = diag

    handled_ops = {}

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        # For everything else, call the handler:
        fn = cls.handled_ops.get(func.__name__, None)
        if fn:
            return fn(*args, **(kwargs or {}))
        else:
            # Note that here, because we don't need to provide the autograd formulas
            # we can have a default "fallback" that creates a plain Tensor based
            # on the diag elements and calls the func again.

            def unwrap(e):
                return e.diag.diag() if isinstance(e, DiagTensorBelow) else e

            def wrap(e):
                if isinstance(e, torch.Tensor) and e.ndim == 1:
                    return DiagTensorBelow(e)
                if isinstance(e, torch.Tensor) and e.ndim == 2 and e.count_nonzero() == e.diag().count_nonzero():
                    return DiagTensorBelow(e.diag())
                return e

            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
            return rs

    def __repr__(self):
        return super().__repr__(tensor_contents=f"diag={self.diag}")


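# A minimal usage sketch (hypothetical `_example_*` helper): the default fallback in
# DiagTensorBelow.__torch_dispatch__ densifies inputs, runs the op on plain Tensors,
# and re-wraps diagonal-shaped results, so simple ops stay closed under the subclass.
def _example_diag_tensor_below():
    d = DiagTensorBelow(torch.arange(1., 4.))  # logically a 3x3 diagonal matrix
    out = d + d  # no handled_ops entry, so the dense fallback runs and re-wraps
    assert isinstance(out, DiagTensorBelow)
    assert torch.equal(out.diag, torch.tensor([2., 4., 6.]))
    return out

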
class SparseTensor(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, size, values, indices, requires_grad=False):
        assert values.device == indices.device
        return values, {"size": size, "requires_grad": requires_grad}

    def __init__(self, size, values, indices, requires_grad=False):
        self.values = values
        self.indices = indices

    def __repr__(self):
        return super().__repr__(tensor_contents=f"values={self.values}, indices={self.indices}")

    def sparse_to_dense(self):
        res = torch.zeros(self.size(), dtype=self.values.dtype)
        res[self.indices.unbind(1)] = self.values
        return res

    @staticmethod
    def from_dense(t):
        indices = t.nonzero()
        values = t[indices.unbind(1)]
        return SparseTensor(t.size(), values, indices)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        func_name = f"{func.__module__}.{func.__name__}"

        res = cls._try_call_special_impl(func_name, args, kwargs)
        if res is not NotImplemented:
            return res

        # Otherwise, use a default implementation that constructs dense
        # tensors and uses them to compute the values
        def unwrap(e):
            return e.sparse_to_dense() if isinstance(e, SparseTensor) else e

        # Wrap back all Tensors into our custom class
        def wrap(e):
            # Check for zeros and use that to get indices
            return SparseTensor.from_dense(e) if isinstance(e, torch.Tensor) else e

        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
        return rs

    # To show how things happen later
    def __rmul__(self, other):
        return super().__rmul__(other)

    _SPECIAL_IMPLS = {}

    @classmethod
    def _try_call_special_impl(cls, func, args, kwargs):
        if func not in cls._SPECIAL_IMPLS:
            return NotImplemented
        return cls._SPECIAL_IMPLS[func](args, kwargs)


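# A minimal usage sketch (hypothetical `_example_*` helper): SparseTensor keeps a
# COO-style (values, indices) pair; from_dense/sparse_to_dense round-trip a dense
# Tensor, and ops without a _SPECIAL_IMPLS entry fall back to the dense path.
def _example_sparse_tensor():
    dense = torch.tensor([[0., 1.], [2., 0.]])
    s = SparseTensor.from_dense(dense)
    assert torch.equal(s.sparse_to_dense(), dense)
    out = s * 2  # densify, multiply, then re-sparsify via from_dense
    assert isinstance(out, SparseTensor)
    return out

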
# Example non-wrapper subclass that stores extra state.
class NonWrapperTensor(torch.Tensor):
    def __new__(cls, data):
        t = torch.Tensor._make_subclass(cls, data)
        t.extra_state = {
            'last_func_called': None
        }
        return t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        result = super().__torch_function__(func, types, args, kwargs)

        if isinstance(result, cls):
            # Do something with the extra state. For the example here, just store the name of the
            # last function called (skip for deepcopy so the copy has the same extra state).
            if func is torch.Tensor.__deepcopy__:
                result.extra_state = deepcopy(args[0].extra_state)
            else:
                result.extra_state = {
                    'last_func_called': func.__name__,
                }

        return result

    # new_empty() must be defined for deepcopy to work
    def new_empty(self, shape):
        return type(self)(torch.empty(shape))


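# A minimal usage sketch (hypothetical `_example_*` helper): NonWrapperTensor records
# the name of the last torch function applied to it, and deepcopy carries that extra
# state over to the copy.
def _example_non_wrapper_tensor():
    t = NonWrapperTensor(torch.randn(2))
    out = t.add(1)
    assert out.extra_state['last_func_called'] == 'add'
    copied = deepcopy(out)
    assert copied.extra_state == out.extra_state
    return copied

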
# Class used to store info about subclass tensors used in testing.
class SubclassInfo:

    __slots__ = ['name', 'create_fn', 'closed_under_ops']

    def __init__(self, name, create_fn, closed_under_ops=True):
        self.name = name
        self.create_fn = create_fn  # create_fn(shape) -> tensor instance
        self.closed_under_ops = closed_under_ops


subclass_db = {
    torch.Tensor: SubclassInfo(
        'base_tensor', create_fn=torch.randn
    ),
    NonWrapperTensor: SubclassInfo(
        'non_wrapper_tensor',
        create_fn=lambda shape: NonWrapperTensor(torch.randn(shape))
    ),
    LoggingTensor: SubclassInfo(
        'logging_tensor',
        create_fn=lambda shape: LoggingTensor(torch.randn(shape))
    ),
    SparseTensor: SubclassInfo(
        'sparse_tensor',
        create_fn=lambda shape: SparseTensor.from_dense(torch.randn(shape).relu())
    ),
    DiagTensorBelow: SubclassInfo(
        'diag_tensor_below',
        create_fn=lambda shape: DiagTensorBelow(torch.randn(shape)),
        closed_under_ops=False  # sparse semantics
    ),
}

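# A minimal usage sketch (hypothetical `_example_*` helper): tests can walk
# subclass_db and build one instance of every registered subclass for a given
# shape through its create_fn.
def _example_create_all_subclasses(shape=(3,)):
    return {info.name: info.create_fn(shape) for info in subclass_db.values()}

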
class SubclassWithTensorFactory(torch.Tensor):
    @staticmethod
    def __new__(cls, src):
        shape = src.shape
        kwargs = {}
        kwargs["strides"] = src.stride()
        kwargs["storage_offset"] = src.storage_offset()
        kwargs["device"] = src.device
        kwargs["layout"] = src.layout
        kwargs["requires_grad"] = src.requires_grad
        kwargs["dtype"] = src.dtype
        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
        return out

    def __init__(self, src):
        self.src = src

    def __repr__(self):
        return f"{self.__class__.__name__}"

    def __tensor_flatten__(self):
        return ["src"], None

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, meta, outer_size, outer_stride):
        src = inner_tensors["src"]
        return cls(src)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}

        def _fn(x):
            return x.src * torch.ones(x.src.shape) if x.src.dtype == torch.float32 else x.src

        _args = pytree.tree_map_only(cls, _fn, args)
        _kwargs = pytree.tree_map_only(cls, _fn, kwargs)

        _out = func(*_args, **_kwargs)

        _out_flat, _out_spec = pytree.tree_flatten(_out)

        out_flat = [cls(o) if isinstance(o, torch.Tensor) else o for o in _out_flat]
        return pytree.tree_unflatten(out_flat, _out_spec)

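# A minimal usage sketch (hypothetical `_example_*` helper): __tensor_flatten__ /
# __tensor_unflatten__ expose the inner `src` tensor so the wrapper can be rebuilt,
# and __torch_dispatch__ unwraps inputs, runs the op, and re-wraps tensor outputs.
def _example_subclass_with_tensor_factory():
    wrapped = SubclassWithTensorFactory(torch.ones(2, 2))
    attrs, ctx = wrapped.__tensor_flatten__()
    inner = {name: getattr(wrapped, name) for name in attrs}
    rebuilt = SubclassWithTensorFactory.__tensor_unflatten__(
        inner, ctx, wrapped.size(), wrapped.stride()
    )
    out = rebuilt + 1  # unwrap to a plain Tensor, add, then wrap the result again
    assert isinstance(out, SubclassWithTensorFactory)
    return out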