xref: /aosp_15_r20/external/pytorch/functorch/dim/reference.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# reference python implementations for C ops
import torch
from functorch._C import dim as _C

from . import op_properties
from .batch_tensor import _enable_layers
from .tree_map import tree_flatten, tree_map


DimList = _C.DimList
import operator
from functools import reduce


# ops that are pointwise over their tensor arguments
# (the C++ implementation uses a dict to avoid writing bindings for set)
pointwise = set(op_properties.pointwise)


def prod(x):
    return reduce(operator.mul, x, 1)


def _wrap_dim(d, N, keepdim):
    from . import Dim

    if isinstance(d, Dim):
        assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
        return d
    elif d >= 0:
        return d - N
    else:
        return d


def _dims(d, N, keepdim, single_dim):
    from . import Dim

    if isinstance(d, (Dim, int)):
        return ltuple((_wrap_dim(d, N, keepdim),))
    assert not single_dim, f"expected a single dimension or int but found: {d}"
    return ltuple(_wrap_dim(x, N, keepdim) for x in d)

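# Worked examples for _wrap_dim/_dims (the values follow directly from the code
# above; keyword names are only for readability):
#   _wrap_dim(1, 4, keepdim=False)  -> -3   positive dims are rebased to negative
#   _wrap_dim(-1, 4, keepdim=False) -> -1   negative dims pass through unchanged
#   _dims(1, 4, False, single_dim=True) -> ltuple((-3,))
# First-class Dim objects are returned as-is, and keepdim=True is rejected for them.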

def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
    from . import DimensionMismatchError

    not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
    if len(not_bound) == 1:
        idx, d = not_bound[0]
        rhs_so_far = prod(r.size for r in rhs if r.is_bound)
        if lhs_size % rhs_so_far != 0:
            rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
            raise DimensionMismatchError(
                f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}"
            )
        new_size = lhs_size // rhs_so_far
        d.size = new_size
    elif len(not_bound) > 1:
        rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs)
        raise DimensionMismatchError(
            f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}"
        )
    else:
        rhs_size = prod(r.size for r in rhs)
        if lhs_size != rhs_size:
            raise DimensionMismatchError(
                f"Dimension sizes do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}"
            )

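# For example, _bind_dims_to_size(12, (a, b), "lhs") with a already bound to 3
# infers b.size == 4; if more than one dim is unbound, or the bound sizes do not
# evenly divide or match 12, a DimensionMismatchError is raised instead.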

def _tensor_levels(inp):
    from . import _Tensor

    if isinstance(inp, _Tensor):
        return inp._tensor, llist(inp._levels), inp._has_device
    else:
        return inp, llist(range(-inp.ndim, 0)), True


def _match_levels(v, from_levels, to_levels):
    view = []
    permute = []
    requires_view = False
    size = v.size()
    for t in to_levels:
        try:
            idx = from_levels.index(t)
            permute.append(idx)
            view.append(size[idx])
        except ValueError:
            view.append(1)
            requires_view = True
    if permute != list(range(len(permute))):
        v = v.permute(*permute)
    if requires_view:
        v = v.view(*view)
    return v

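# For example, _match_levels on a value of shape (3, 5) with levels [i, -1] and
# target levels [-2, i, -1] permutes nothing but inserts a size-1 axis for the
# missing -2 level, returning a view of shape (1, 3, 5): levels present in
# from_levels are permuted into to_levels order, missing ones broadcast as 1.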

# make a single dimension positional but do not permute it;
# used for multi-tensor operators where the dim being acted on
# should not physically move if possible (see the example after this function)
def _positional_no_permute(self, dim, expand_dim=False):
    from . import Tensor

    ptensor, levels = self._tensor, llist(self._levels)
    try:
        idx = levels.index(dim)
    except ValueError:
        if not expand_dim:
            raise
        idx = 0
        ptensor = ptensor.expand(dim.size, *ptensor.size())
        levels.insert(0, 0)
    idx_batched = 0
    for i in range(idx):
        if isinstance(levels[i], int):
            levels[i] -= 1
            idx_batched += 1
    levels[idx] = -idx_batched - 1
    return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched

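# Example of _positional_no_permute's bookkeeping: if self._levels is
# [i, -2, j, -1] and dim is j, one positional level (-2) precedes j, so j is
# relabeled to the positional level -2 (the old -2 shifts to -3), no data is
# moved, and the returned index is 1: j is now positional dimension 1 among the
# batched dimensions.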

def seq(a, b):
    from . import Dim

    if isinstance(a, Dim) != isinstance(b, Dim):
        return False
    if isinstance(a, Dim):
        return a is b
    else:
        return a == b


class isin:
    def __contains__(self, item):
        for x in self:
            if seq(item, x):
                return True
        return False

    def index(self, item):
        for i, x in enumerate(self):
            if seq(item, x):
                return i
        raise ValueError


class llist(isin, list):
    pass


class ltuple(isin, tuple):
    pass


empty_dict = {}


@classmethod
def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
    from . import _Tensor, Tensor, TensorLike
    from .delayed_mul_tensor import DelayedMulTensor

    if orig is torch.Tensor.__mul__:
        lhs, rhs = args
        if (
            isinstance(lhs, _Tensor)
            and isinstance(rhs, _Tensor)
            and lhs.ndim == 0
            and rhs.ndim == 0
        ):
            return DelayedMulTensor(lhs, rhs)
    all_dims = llist()
    flat_args, unflatten = tree_flatten((args, kwargs))
    device_holding_tensor = None
    for f in flat_args:
        if isinstance(f, _Tensor):
            if f._has_device:
                device_holding_tensor = f._batchtensor
            for d in f.dims:
                if d not in all_dims:
                    all_dims.append(d)

    def unwrap(t):
        if isinstance(t, _Tensor):
            r = t._batchtensor
            if device_holding_tensor is not None and not t._has_device:
                r = r.to(device=device_holding_tensor.device)
            return r
        return t

    if orig in pointwise:
        result_levels = llist()
        arg_levels = llist()
        to_expand = []
        for i, f in enumerate(flat_args):
            if isinstance(f, TensorLike):
                ptensor, levels, _ = _tensor_levels(f)
                if (
                    isinstance(f, _Tensor)
                    and not f._has_device
                    and device_holding_tensor is not None
                ):
                    ptensor = ptensor.to(device=device_holding_tensor.device)
                flat_args[i] = ptensor
                for l in levels:
                    if l not in result_levels:
                        result_levels.append(l)
                to_expand.append((i, levels))

        for i, levels in to_expand:
            flat_args[i] = _match_levels(flat_args[i], levels, result_levels)
        args, kwargs = unflatten(flat_args)
        result = orig(*args, **kwargs)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(
                    t, result_levels, device_holding_tensor is not None
                )
            return t

        return tree_map(wrap, result)
    else:

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_batched(t, device_holding_tensor is not None)
            return t

        with _enable_layers(all_dims):
            print(f"batch_tensor for {orig}")
            args, kwargs = unflatten(unwrap(f) for f in flat_args)
            result = orig(*args, **kwargs)
            # print("END", orig)
            return tree_map(wrap, result)

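# A sketch of how the pointwise branch of __torch_function__ plays out,
# assuming the public functorch.dim API (`dims` and indexing); the names are
# illustrative only:
#
#   i, j = dims(2)
#   a = torch.rand(3, 4)[i, j]   # levels [i, j]
#   b = torch.rand(4)[j]         # levels [j]
#   c = a + b                    # levels are unioned to [i, j], b is aligned
#                                # with _match_levels, and the plain result of
#                                # the original op is rewrapped with those levels
#
# Non-pointwise ops instead run on the batch tensors under _enable_layers.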

def positional(self, *dims):
    from . import Dim, DimensionBindError, Tensor

    ptensor, levels = self._tensor, llist(self._levels)
    flat_dims = llist()
    view = []
    needs_view = False
    ndim = self.ndim
    for d in dims:
        if isinstance(d, DimList):
            flat_dims.extend(d)
            view.extend(e.size for e in d)
        elif isinstance(d, Dim):
            flat_dims.append(d)
            view.append(d.size)
        elif isinstance(d, int):
            d = _wrap_dim(d, ndim, False)
            flat_dims.append(d)
            view.append(ptensor.size(d))
        else:
            flat_dims.extend(d)
            view.append(prod(e.size for e in d))
            needs_view = True

    permute = list(range(len(levels)))
    nflat = len(flat_dims)
    for i, d in enumerate(flat_dims):
        try:
            idx = levels.index(d)
        except ValueError as e:
            raise DimensionBindError(
                f"tensor of dimensions {self.dims} does not contain dim {d}"
            ) from e
        p = permute[idx]
        del levels[idx]
        del permute[idx]
        levels.insert(i, 0)
        permute.insert(i, p)
    ptensor = ptensor.permute(*permute)
    seen = 0
    for i in range(len(levels) - 1, -1, -1):
        if isinstance(levels[i], int):
            seen += 1
            levels[i] = -seen
    result = Tensor.from_positional(ptensor, levels, self._has_device)
    if needs_view:
        result = result.reshape(*view, *result.size()[len(flat_dims) :])
    return result

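# For example, for a value with levels [i, j, -1]:
#   positional(i)       -> levels [-2, j, -1]; i becomes a real leading dimension
#   positional((i, j))  -> i and j are flattened into a single positional
#                          dimension of size i.size * j.size (the final reshape)
# Ints address existing positional dimensions, and a DimList contributes each of
# its member dims.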

def _contains_dim(input):
    from . import Dim

    for i in input:
        if isinstance(i, Dim):
            return True
    return False


def expand(self, *sizes):
    if not _contains_dim(sizes):
        return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
    dims = sizes
    sizes = [d.size for d in dims] + [-1] * self.ndim
    self = self.expand(*sizes)
    return self[dims]

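# For example, `t.expand(i, j)` with first-class dims i and j prepends two new
# dimensions of sizes i.size and j.size (the trailing -1s keep t's own sizes)
# and then binds them to i and j by indexing; a call with only ints falls
# through to the ordinary torch.Tensor.expand via __torch_function__.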

_not_present = object()


def _getarg(name, offset, args, kwargs, default):
    if len(args) > offset:
        return args[offset]
    return kwargs.get(name, default)


def _patcharg(name, offset, args, kwargs, value):
    if len(args) > offset:
        args[offset] = value
    else:
        kwargs[name] = value


def _wrap(
    orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True
):
    from . import Dim, Tensor, TensorLike

    def fn(self, *args, **kwargs):
        dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
        if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
            with _enable_layers(self.dims):
                print(f"dim fallback batch_tensor for {orig}")
                return Tensor.from_batched(
                    orig(self._batchtensor, *args, **kwargs), self._has_device
                )
        keepdim = (
            _getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False
        )
        t, levels = self._tensor, llist(self._levels)
        dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
        dim_indices = tuple(levels.index(d) for d in dims)
        if reduce and not keepdim:
            new_levels = [l for i, l in enumerate(levels) if i not in dim_indices]
        else:
            new_levels = levels

        if len(dim_indices) == 1:
            # so that ops that only take a single dim argument work...
            dim_indices = dim_indices[0]
        args = list(args)
        _patcharg(dim_name, dim_offset, args, kwargs, dim_indices)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(t, new_levels, self._has_device)
            return t

        with _enable_layers(new_levels):
            print(f"dim used batch_tensor for {orig}")
            r = orig(t, *args, **kwargs)
            return tree_map(wrap, r)

    return fn


def _def(name, *args, **kwargs):
    from . import _Tensor

    orig = getattr(torch.Tensor, name)
    setattr(_Tensor, name, _wrap(orig, *args, **kwargs))

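# _wrap/_def give reduction-style methods first-class-dim support. A sketch,
# assuming the public `dims` API and a registration such as `_def("sum")` (the
# actual registrations live elsewhere in the package):
#
#   i, j = dims(2)
#   x = torch.rand(3, 4)[i, j]
#   x.sum(j)     # j is translated to its position on the underlying tensor,
#                # torch.Tensor.sum runs there, and the result keeps level i only
#
# Calls that do not pass a dim at all fall back to the batch_tensor path.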

no_slice = slice(None)

_orig_getitem = torch.Tensor.__getitem__


class dim_tracker:
    def __init__(self) -> None:
        self.dims = llist()
        self.count = []

    def record(self, d):
        # count repeated uses so later checks can tell single-use dims apart
        if d not in self.dims:
            self.dims.append(d)
            self.count.append(1)
        else:
            self.count[self.dims.index(d)] += 1

    def __getitem__(self, d):
        return self.count[self.dims.index(d)]


def t__getitem__(self, input):
    from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike

    # * bail to the original __getitem__ if we have a single non-Dim tensor, or a non-tensor
    # * locate ... or an unbound dim list, determine its size, and bind the dim list
    #   (remember that None does not count toward the total dim count)
    # * bind simple dims and dim-packs to their sizes, count the number of uses of each dim,
    #   produce the re-view if needed
    # * for each single-use dim index, replace with no_slice and mark that it will be added
    #   (keep track of whether we have to call super)
    # * call super if needed
    # * if we have dims to bind, bind them (it will help if we eliminated ... and None before)
    # the simple fallback below also handles bool indexing, as well as some other simple cases.

    is_simple = (
        not isinstance(input, Dim)
        and not isinstance(input, (tuple, list))
        and
        # WAR for functorch bug where zero-dim tensors in getitem are not handled correctly.
        not (isinstance(input, TensorLike) and input.ndim == 0)
    )

    if is_simple:
        if isinstance(self, _Tensor):
            return _Tensor.__torch_function__(_orig_getitem, None, (self, input))
        else:
            return _orig_getitem(self, input)

    # can further optimize this case
    if not isinstance(input, tuple):
        input = [input]
    else:
        input = list(input)

    dims_indexed = 0
    expanding_object = None
    dimlists = []
    for i, s in enumerate(input):
        if s is ... or (isinstance(s, DimList) and not s.is_bound):
            if expanding_object is not None:
                msg = (
                    "at most one ... or unbound dimension list can exist in indexing list but"
                    f" found 2 at offsets {i} and {expanding_object}"
                )
                raise DimensionBindError(msg)
            expanding_object = i

        if isinstance(s, DimList):
            dims_indexed += len(s) if s.is_bound else 0
            dimlists.append(i)
        elif s is not None and s is not ...:
            dims_indexed += 1

    ndim = self.ndim
    if dims_indexed > ndim:
        raise IndexError(
            f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions."
        )
    if expanding_object is not None:
        expanding_ndims = ndim - dims_indexed
        obj = input[expanding_object]
        if obj is ...:
            input[expanding_object : expanding_object + 1] = [
                no_slice
            ] * expanding_ndims
        else:
            obj.bind_len(expanding_ndims)
    # flatten the dim lists into the indexing
    for i in reversed(dimlists):
        input[i : i + 1] = input[i]
    dims_indexed = 0
    requires_view = False
    size = self.size()
    view_sizes = []
    dims_seen = dim_tracker()

    def add_dims(t):
        if not isinstance(t, _Tensor):
            return
        for d in t.dims:
            dims_seen.record(d)

    add_dims(self)
    dim_packs = []
    for i, idx in enumerate(input):
        if idx is None:
            input[i] = no_slice
            view_sizes.append(1)
            requires_view = True
        else:
            sz = size[dims_indexed]
            if isinstance(idx, Dim):
                idx.size = sz
                dims_seen.record(idx)
                view_sizes.append(sz)
            elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim):
                for d in idx:
                    dims_seen.record(d)
                _bind_dims_to_size(sz, idx, f"offset {i}")
                view_sizes.extend(d.size for d in idx)
                requires_view = True
                dim_packs.append(i)
            else:
                add_dims(idx)
                view_sizes.append(sz)
            dims_indexed += 1
    if requires_view:
        self = self.view(*view_sizes)
    for i in reversed(dim_packs):
        input[i : i + 1] = input[i]

    # currently:
    # input is flat, containing either Dim, or Tensor, or something valid for standard indexing
    # self may have first-class dims as well.

    # to index:
    # drop the first-class dims from self; they just become direct indices of their positions

    # figure out the dimensions of the indexing tensors: the union of all the dims in the tensors in the index.
    # these dimensions will appear and need to be bound at the first place a tensor occurs

    if isinstance(self, _Tensor):
        ptensor_self, levels = self._tensor, list(self._levels)
        # indices to ptensor rather than self which has first-class dimensions
        input_it = iter(input)
        flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels]
        has_device = self._has_device
        to_pad = 0
    else:
        ptensor_self, flat_inputs = self, input
        to_pad = ptensor_self.ndim - len(flat_inputs)
        has_device = True

    result_levels = []
    index_levels = []
    tensor_insert_point = None
    to_expand = {}
    requires_getindex = False
    for i, inp in enumerate(flat_inputs):
        if isinstance(inp, Dim) and dims_seen[inp] == 1:
            flat_inputs[i] = no_slice
            result_levels.append(inp)
        elif isinstance(inp, TensorLike):
            requires_getindex = True
            if tensor_insert_point is None:
                tensor_insert_point = len(result_levels)
            ptensor, levels, _ = _tensor_levels(inp)
            to_expand[i] = levels
            flat_inputs[i] = ptensor
            for l in levels:
                if l not in index_levels:
                    index_levels.append(l)
        else:
            requires_getindex = True
            result_levels.append(0)

    if tensor_insert_point is not None:
        result_levels[tensor_insert_point:tensor_insert_point] = index_levels

    for i, levels in to_expand.items():
        flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels)

    if requires_getindex:
        result = _orig_getitem(ptensor_self, flat_inputs)
    else:
        result = ptensor_self

    next_positional = -1
    if to_pad > 0:
        result_levels.extend([0] * to_pad)
    for i, r in enumerate(reversed(result_levels)):
        if isinstance(r, int):
            result_levels[-1 - i] = next_positional
            next_positional -= 1

    return Tensor.from_positional(result, result_levels, has_device)

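# A sketch of t__getitem__'s fast path, assuming the public functorch.dim API;
# the names are illustrative:
#
#   i, j = dims(2)
#   x = torch.rand(3, 4)
#   x[i, j]   # every index is a single-use Dim: no real indexing happens, the
#             # dims are simply bound to sizes 3 and 4 and become the levels
#   x[i]      # binds only the first dimension; the second stays positional
#             # (level -1), again without calling the original __getitem__
#
# Tensor indices, repeated dims, dim packs, None, and ... take the general path.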

# XXX - dim is optional and can be the outer-most dimension...
def stack(tensors, new_dim, dim=0, out=None):
    if isinstance(dim, int):
        return torch.stack(tensors, dim, out=out).index(dim, new_dim)
    index = None
    if out is not None:
        out, index = _positional_no_permute(out, dim, expand_dim=True)
    ptensors = []
    for t in tensors:
        pt, pi = _positional_no_permute(t, dim, expand_dim=True)
        if index is not None and pi != index:
            pt = pt.move_dim(pi, index)
        else:
            index = pi
        ptensors.append(pt)
    pr = torch.stack(ptensors, index, out=out)
    return pr.index((index, index + 1), (new_dim, dim))

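# For example, given a fresh Dim k and tensors t1 and t2 that both carry a dim
# i, `stack([t1, t2], k, i)` binds k to the new stacking dimension (size 2) and
# keeps i first-class, moving data only as much as _positional_no_permute
# requires; with an integer dim it defers to torch.stack and then binds the
# newly created positional dimension to new_dim.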

_orig_split = torch.Tensor.split


def split(self, split_size_or_sections, dim=0):
    from . import _Tensor, Dim

    if isinstance(split_size_or_sections, int) or any(
        isinstance(t, int) for t in split_size_or_sections
    ):
        if isinstance(dim, Dim):
            raise ValueError(
                "when dim is specified as a Dim object, split sizes must also be dimensions."
            )
        return _orig_split(self, split_size_or_sections, dim=dim)

    if isinstance(dim, Dim):
        assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}"
        self, dim = _positional_no_permute(self, dim)

    size = self.size(dim)
    total_bound_size = 0
    unbound = []
    sizes = []
    for i, d in enumerate(split_size_or_sections):
        if d.is_bound:
            sizes.append(d.size)
            total_bound_size += d.size
        else:
            sizes.append(0)
            unbound.append(i)

    if unbound:
        assert (
            total_bound_size <= size
        ), f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
        remaining_size = size - total_bound_size
        chunk_size = -(-remaining_size // len(unbound))  # ceil division
        for u in unbound:
            sz = min(chunk_size, remaining_size)
            split_size_or_sections[u].size = sz
            sizes[u] = sz
            remaining_size -= sz
    else:
        assert (
            total_bound_size == size
        ), f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
    return tuple(
        t.index(dim, d)
        for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim))
    )

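# For example, with unbound dims a and b, `split(x, (a, b), dim=0)` on a tensor
# whose dim-0 size is 10 binds a and b to the chunk sizes (5 and 5 here; an odd
# size is split as evenly as possible, larger chunks first) and returns one
# tensor per dim, each chunk indexed by its dim; integer split sizes fall
# through to the original torch.Tensor.split.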