# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import dis
import inspect
from dataclasses import dataclass
from typing import Union

# The vmap nesting helpers used below are provided by functorch's C extension.
from torch._C._functorch import (
    _vmap_decrement_nesting,
    _vmap_increment_nesting,
    current_level,
)

from . import DimList


_vmap_levels = []


@dataclass
class LevelInfo:
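    """Book-keeping for one entry on the vmap nesting stack."""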
    level: int
    alive: bool = True


class Dim:
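    """A first-class dimension object.

    A Dim starts unbound; assigning `size` for the first time binds it and
    pushes a new vmap nesting level for it. Rebinding to a different size
    raises DimensionBindError, and the level is popped once the Dim is
    garbage-collected.
    """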
    def __init__(self, name: str, size: Union[None, int] = None):
        self.name = name
        self._size = None
        self._vmap_level = None
        if size is not None:
            self.size = size

    def __del__(self):
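        # When the Dim is collected, mark its level dead and unwind any run of
        # dead levels sitting at the top of the vmap nesting stack.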
        if self._vmap_level is not None:
            _vmap_levels[self._vmap_stack].alive = False
            while (
                not _vmap_levels[-1].alive
                and current_level() == _vmap_levels[-1].level
            ):
                _vmap_decrement_nesting()
                _vmap_levels.pop()

    @property
    def size(self):
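        """The size this Dim is bound to; reading it before binding is an error."""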
        assert self.is_bound
        return self._size

    @size.setter
    def size(self, size: int):
        from . import DimensionBindError

        if self._size is None:
            # First binding: enter a new vmap level sized like this dimension.
            self._size = size
            self._vmap_level = _vmap_increment_nesting(size, "same")
            self._vmap_stack = len(_vmap_levels)
            _vmap_levels.append(LevelInfo(self._vmap_level))

        elif self._size != size:
            raise DimensionBindError(
                f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
            )

    @property
    def is_bound(self):
        return self._size is not None

    def __repr__(self):
        return self.name


def extract_name(inst):
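    """Return the variable name targeted by a STORE_FAST/STORE_NAME instruction."""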
    assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME"
    return inst.argval


# Cached constructors keyed by (code object, instruction offset) so repeated
# executions of the same `dims()` call site skip the bytecode inspection.
_cache = {}


def dims(lists=0):
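    """Create Dim objects named after the variables they are assigned to.

    The caller's bytecode is inspected at the instruction following the call
    to recover the assignment targets; the last `lists` targets become
    DimLists instead of Dims.

    Illustrative usage (the names come from the assignment targets):

        batch, channel = dims()      # two unbound Dims
        i, j, ks = dims(lists=1)     # two Dims and one DimList
    """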
    frame = inspect.currentframe()
    assert frame is not None
    calling_frame = frame.f_back
    assert calling_frame is not None
    code, lasti = calling_frame.f_code, calling_frame.f_lasti
    key = (code, lasti)
    if key not in _cache:
        # f_lasti is a byte offset and instructions are two bytes wide, so
        # `first` indexes the instruction immediately after this call.
        first = lasti // 2 + 1
        instructions = list(dis.get_instructions(calling_frame.f_code))
        unpack = instructions[first]

        if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME":
            # just a single dim, not a list
            name = unpack.argval
            ctor = Dim if lists == 0 else DimList
            _cache[key] = lambda: ctor(name=name)
        else:
            # a tuple of targets: the trailing `lists` names become DimLists
            assert unpack.opname == "UNPACK_SEQUENCE"
            ndims = unpack.argval
            names = tuple(
                extract_name(instructions[first + 1 + i]) for i in range(ndims)
            )
            first_list = len(names) - lists
            _cache[key] = lambda: tuple(
                Dim(n) if i < first_list else DimList(name=n)
                for i, n in enumerate(names)
            )
    return _cache[key]()


def _dim_set(positional, arg):
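    """Normalize a dims argument into a tuple of Dim objects.

    `arg` may be None (meaning all of `positional`), a single Dim or int, or
    an iterable of them; ints are treated as indices into `positional`.
    """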
    def convert(a):
        if isinstance(a, Dim):
            return a
        else:
            assert isinstance(a, int)
            return positional[a]

    if arg is None:
        return positional
    elif not isinstance(arg, (Dim, int)):
        return tuple(convert(a) for a in arg)
    else:
        return (convert(arg),)