xref: /aosp_15_r20/external/pytorch/functorch/einops/_parsing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""Adapted from https://github.com/arogozhnikov/einops/blob/36c7bb16e57d6e57f8f3050f9e07abdf3f00469f/einops/parsing.py.
2
3MIT License
4
5Copyright (c) 2018 Alex Rogozhnikov
6
7Permission is hereby granted, free of charge, to any person obtaining a copy
8of this software and associated documentation files (the "Software"), to deal
9in the Software without restriction, including without limitation the rights
10to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11copies of the Software, and to permit persons to whom the Software is
12furnished to do so, subject to the following conditions:
13
14The above copyright notice and this permission notice shall be included in all
15copies or substantial portions of the Software.
16
17THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23SOFTWARE.
24"""
25from __future__ import annotations
26
27import keyword
28import warnings
29from typing import Collection, List, Mapping, Optional, Set, Tuple, Union
30
31
# Single-character stand-in for the three-dot ellipsis: ParsedExpression replaces "..." in a pattern with
# this symbol so that its character-by-character scan can treat the ellipsis as one token.
_ellipsis: str = "\u2026"  # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated
33
34
class AnonymousAxis:
    """An axis that is described only by its length and has no identifier.

    The length is given as a string, converted with `int`, and must be at least 1.

    Note: the class defines neither `__eq__` nor `__hash__`, so two instances are never equal to each other,
    even when they carry the same length.
    """

    def __init__(self, value: str) -> None:
        length = int(value)
        # The attribute is stored before validation, exactly as callers may have come to expect.
        self.value = length
        if length >= 1:
            return
        raise ValueError(f"Anonymous axis should have positive length, not {length}")

    def __repr__(self) -> str:
        return "{}-axis".format(self.value)
50
51
class ParsedExpression:
    """Structure containing information about one side of an `einops`-style pattern (e.g. 'b c (h w)')."""

    def __init__(
        self,
        expression: str,
        *,
        allow_underscore: bool = False,
        allow_duplicates: bool = False,
    ) -> None:
        """Parse the expression and store relevant metadata.

        Args:
            expression (str): the `einops`-pattern to parse
            allow_underscore (bool): whether to accept the lone identifier `_` (which may then also be repeated).
                Any other name that starts or ends with an underscore is always rejected.
            allow_duplicates (bool): whether to allow an identifier to appear more than once in the expression

        Raises:
            ValueError: if the expression contains dots that do not form exactly one `...`, a duplicate
                identifier (when not allowed), an invalid identifier, a non-positive anonymous axis,
                nested or unbalanced parentheses, or an unsupported character.
        """
        # True as soon as an ellipsis has been seen, written either as "..." or as the single-character symbol
        self.has_ellipsis: bool = False
        # None until an ellipsis is registered; afterwards tells whether it appeared inside parentheses
        self.has_ellipsis_parenthesized: Optional[bool] = None
        # named axes, the ellipsis symbol, and one AnonymousAxis per numeric axis other than 1
        self.identifiers: Set[Union[str, AnonymousAxis]] = set()
        # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
        self.has_non_unitary_anonymous_axes: bool = False
        # composition keeps structure of composite axes, see how different corner cases are handled in tests
        self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []
        if "." in expression:
            if "..." not in expression:
                raise ValueError(
                    "Expression may contain dots only inside ellipsis (...)"
                )
            if str.count(expression, "...") != 1 or str.count(expression, ".") != 3:
                raise ValueError(
                    "Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor "
                )
            # collapse the three dots into one symbol so the scan below sees a single token
            expression = expression.replace("...", _ellipsis)
            self.has_ellipsis = True

        # list collecting the axes of the currently open "(...)" group, or None when outside parentheses
        bracket_group: Optional[List[Union[str, AnonymousAxis]]] = None

        def add_axis_name(x: str) -> None:
            """Register one whitespace/parenthesis-delimited token of the expression."""
            if x in self.identifiers:
                if not (allow_underscore and x == "_") and not allow_duplicates:
                    raise ValueError(
                        f"Indexing expression contains duplicate dimension '{x}'"
                    )
            if x == _ellipsis:
                # The pattern may already contain the single-character ellipsis (no "..." to replace above);
                # keep the flag in sync with `identifiers`/`composition` in that case as well.
                self.has_ellipsis = True
                self.identifiers.add(_ellipsis)
                if bracket_group is None:
                    self.composition.append(_ellipsis)
                    self.has_ellipsis_parenthesized = False
                else:
                    bracket_group.append(_ellipsis)
                    self.has_ellipsis_parenthesized = True
                return

            is_number = str.isdecimal(x)
            if is_number and int(x) == 1:
                # handling the case of anonymous axis of length 1: it becomes an empty composition at top
                # level and is dropped entirely inside parentheses
                if bracket_group is None:
                    self.composition.append([])
                return

            is_axis_name, reason = self.check_axis_name_return_reason(
                x, allow_underscore=allow_underscore
            )
            if not (is_number or is_axis_name):
                raise ValueError(f"Invalid axis identifier: {x}\n{reason}")
            axis_name: Union[str, AnonymousAxis] = AnonymousAxis(x) if is_number else x
            self.identifiers.add(axis_name)
            if is_number:
                self.has_non_unitary_anonymous_axes = True
            if bracket_group is None:
                self.composition.append([axis_name])
            else:
                bracket_group.append(axis_name)

        # characters of the token being accumulated, or None between tokens
        current_identifier: Optional[str] = None
        for char in expression:
            if char in "() ":
                # a delimiter terminates whatever token was being accumulated
                if current_identifier is not None:
                    add_axis_name(current_identifier)
                current_identifier = None
                if char == "(":
                    if bracket_group is not None:
                        raise ValueError(
                            "Axis composition is one-level (brackets inside brackets not allowed)"
                        )
                    bracket_group = []
                elif char == ")":
                    if bracket_group is None:
                        raise ValueError("Brackets are not balanced")
                    self.composition.append(bracket_group)
                    bracket_group = None
            elif str.isalnum(char) or char in ["_", _ellipsis]:
                if current_identifier is None:
                    current_identifier = char
                else:
                    current_identifier += char
            else:
                raise ValueError(f"Unknown character '{char}'")

        if bracket_group is not None:
            raise ValueError(f"Imbalanced parentheses in expression: '{expression}'")
        # flush a token that runs up to the end of the expression
        if current_identifier is not None:
            add_axis_name(current_identifier)

    @staticmethod
    def check_axis_name_return_reason(
        name: str, allow_underscore: bool = False
    ) -> Tuple[bool, str]:
        """Check if the given axis name is valid, and a message explaining why if not.

        Valid axes names are python identifiers that neither start nor end with an underscore. Python keywords
        and the name 'axis' are accepted but trigger a warning.

        Args:
            name (str): the axis name to check
            allow_underscore (bool): whether the lone name `_` is accepted; other names starting or ending
                with an underscore are rejected regardless

        Returns:
            Tuple[bool, str]: whether the axis name is valid, a message explaining why if not
        """
        if not str.isidentifier(name):
            return False, "not a valid python identifier"
        if name[0] == "_" or name[-1] == "_":
            if name == "_" and allow_underscore:
                return True, ""
            return False, "axis name should not start or end with underscore"
        if keyword.iskeyword(name):
            warnings.warn(
                f"It is discouraged to use axes names that are keywords: {name}",
                RuntimeWarning,
            )
        if name in ["axis"]:
            warnings.warn(
                "It is discouraged to use 'axis' as an axis name and will raise an error in future",
                FutureWarning,
            )
        return True, ""

    @staticmethod
    def check_axis_name(name: str) -> bool:
        """Check if the name is a valid axis name.

        Args:
            name (str): the axis name to check

        Returns:
            bool: whether the axis name is valid
        """
        is_valid, _ = ParsedExpression.check_axis_name_return_reason(name)
        return is_valid
205
206
def parse_pattern(
    pattern: str, axes_lengths: Mapping[str, int]
) -> Tuple[ParsedExpression, ParsedExpression]:
    """Split an `einops`-style pattern at '->' and parse each side into a `ParsedExpression`.

    Args:
        pattern (str): the `einops`-style rearrangement pattern
        axes_lengths (Mapping[str, int]): any additional length specifications for dimensions

    Returns:
       Tuple[ParsedExpression, ParsedExpression]: the parsed left-hand side and right-hand side, in that order

    Raises:
        ValueError: if the pattern does not split into exactly two sides, if the ellipsis symbol is used as a key
            of `axes_lengths`, if only the right side has an ellipsis, or if the left side's ellipsis is
            parenthesized.
    """
    # adapted from einops.einops._prepare_transformation_recipe
    # https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/einops/einops.py
    sides = pattern.split("->")
    if len(sides) != 2:
        # `from None` keeps the exception free of any surrounding context
        raise ValueError("Pattern must contain a single '->' separator") from None

    if _ellipsis in axes_lengths:
        raise ValueError(f"'{_ellipsis}' is not an allowed axis identifier")

    left = ParsedExpression(sides[0])
    right = ParsedExpression(sides[1])

    if right.has_ellipsis and not left.has_ellipsis:
        raise ValueError(
            f"Ellipsis found in right side, but not left side of a pattern {pattern}"
        )
    if left.has_ellipsis and left.has_ellipsis_parenthesized:
        raise ValueError(
            f"Ellipsis is parenthesis in the left side is not allowed: {pattern}"
        )

    return left, right
242
243
def validate_rearrange_expressions(
    left: ParsedExpression, right: ParsedExpression, axes_lengths: Mapping[str, int]
) -> None:
    """Run the checks that only apply to `rearrange` on an already-parsed pattern.

    Args:
        left (ParsedExpression): left-hand side expression
        right (ParsedExpression): right-hand side expression
        axes_lengths (Mapping[str, int]): any additional length specifications for dimensions

    Raises:
        TypeError: if a value of `axes_lengths` is not exactly of type `int`
        ValueError: if either side has an anonymous axis other than 1, if an identifier appears on only one
            side, or if a key of `axes_lengths` is not an identifier of the left side
    """
    for size in axes_lengths.values():
        size_type = type(size)
        # exact type comparison: subclasses of int (such as bool) are rejected too
        if size_type is not int:
            raise TypeError(
                f"rearrange axis lengths must be integers, got: {size_type}"
            )

    if any(side.has_non_unitary_anonymous_axes for side in (left, right)):
        raise ValueError("rearrange only supports unnamed axes of size 1")

    one_sided = left.identifiers.symmetric_difference(right.identifiers)
    if one_sided:
        raise ValueError(
            f"Identifiers only on one side of rearrange expression (should be on both): {one_sided}"
        )

    unknown_lengths = axes_lengths.keys() - left.identifiers
    if unknown_lengths:
        raise ValueError(
            f"Identifiers not found in rearrange expression: {unknown_lengths}"
        )
274
275
def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
    """Render a collection of first class dim names as a comma-separated string.

    Plain strings are emitted as-is. Any other item is rendered recursively and wrapped in parentheses, with a
    trailing comma added when it holds exactly one element (as in a one-element tuple literal).

    Args:
        collection (Collection[Union[str, Collection[str]]]): the collection of strings to convert

    Returns:
        str: the comma-separated string

    Examples:
        >>> comma_separate(('d0',))
        'd0'

        >>> comma_separate(('d0', 'd1', 'd2', 'd3'))
        'd0, d1, d2, d3'

        >>> comma_separate([('d1', 'd4')])
        '(d1, d4)'

        >>> comma_separate([('d0',), (), ('d1',), ('d2',), ('d3', 'd4')])
        '(d0,), (), (d1,), (d2,), (d3, d4)'
    """
    rendered = []
    for entry in collection:
        if isinstance(entry, str):
            rendered.append(entry)
            continue
        inner = comma_separate(entry)
        trailing = "," if len(entry) == 1 else ""
        rendered.append("(" + inner + trailing + ")")
    return ", ".join(rendered)
304