"""Adapted from https://github.com/arogozhnikov/einops/blob/36c7bb16e57d6e57f8f3050f9e07abdf3f00469f/einops/parsing.py.

MIT License

Copyright (c) 2018 Alex Rogozhnikov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from __future__ import annotations

import keyword
import warnings
from typing import Collection, List, Mapping, Optional, Set, Tuple, Union


_ellipsis: str = "\u2026"  # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated


class AnonymousAxis:
    """Used by `ParsedExpression` to represent an axis with a size (> 1), but no associated identifier.

    Note: Different instances of this class are not equal to each other, even if they have the same value
    (no ``__eq__`` is defined, so comparison falls back to identity).
    """

    def __init__(self, value: str) -> None:
        """Store the axis size parsed from its decimal string form.

        Args:
            value (str): decimal string giving the axis size

        Raises:
            ValueError: if the size is smaller than 1
        """
        self.value = int(value)
        if self.value < 1:
            raise ValueError(
                f"Anonymous axis should have positive length, not {self.value}"
            )

    def __repr__(self) -> str:
        return f"{self.value}-axis"


class ParsedExpression:
    """Structure containing information about one side of an `einops`-style pattern (e.g. 'b c (h w)').

    Attributes:
        has_ellipsis: whether the expression contained ``...``.
        has_ellipsis_parenthesized: ``None`` until an ellipsis is parsed, then whether it was inside brackets.
        identifiers: every named axis, every non-unit ``AnonymousAxis`` and, if present, the ellipsis marker.
        has_non_unitary_anonymous_axes: whether any numeric axis other than ``1`` appeared.
        composition: one entry per top-level group, in order. A name, number or bracket group becomes a list
            of axes (a bare top-level ``1`` becomes an empty list); a top-level ellipsis is stored as the
            ellipsis marker string itself.
    """

    def __init__(
        self,
        expression: str,
        *,
        allow_underscore: bool = False,
        allow_duplicates: bool = False,
    ) -> None:
        """Parse the expression and store relevant metadata.

        Args:
            expression (str): the `einops`-pattern to parse
            allow_underscore (bool): whether to accept the bare identifier ``_`` (which may then also repeat);
                other names starting or ending with an underscore are still rejected
            allow_duplicates (bool): whether to allow an identifier to appear more than once in the expression

        Raises:
            ValueError: on dots outside a single ellipsis, more than one ellipsis, nested or unbalanced
                brackets, unknown characters, invalid axis names, a zero-length anonymous axis, or
                (unless allowed) duplicate identifiers
        """
        self.has_ellipsis: bool = False
        self.has_ellipsis_parenthesized: Optional[bool] = None
        self.identifiers: Set[Union[str, AnonymousAxis]] = set()
        # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
        self.has_non_unitary_anonymous_axes: bool = False
        # composition keeps structure of composite axes, see how different corner cases are handled in tests
        self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = []
        if "." in expression:
            if "..." not in expression:
                raise ValueError(
                    "Expression may contain dots only inside ellipsis (...)"
                )
            if str.count(expression, "...") != 1 or str.count(expression, ".") != 3:
                raise ValueError(
                    "Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor "
                )
            # Collapse "..." into one character so the per-character scanner below sees it as a single token.
            expression = expression.replace("...", _ellipsis)
            self.has_ellipsis = True

        # The bracket group currently being filled, or None when scanning at the top level.
        bracket_group: Optional[List[Union[str, AnonymousAxis]]] = None

        def add_axis_name(x: str) -> None:
            # Register one complete token (a name, a decimal number, or the ellipsis marker).
            if x in self.identifiers:
                # The bare "_" may repeat when underscores are allowed; anything else needs allow_duplicates.
                if not (allow_underscore and x == "_") and not allow_duplicates:
                    raise ValueError(
                        f"Indexing expression contains duplicate dimension '{x}'"
                    )
            if x == _ellipsis:
                self.identifiers.add(_ellipsis)
                if bracket_group is None:
                    self.composition.append(_ellipsis)
                    self.has_ellipsis_parenthesized = False
                else:
                    bracket_group.append(_ellipsis)
                    self.has_ellipsis_parenthesized = True
            else:
                is_number = str.isdecimal(x)
                if is_number and int(x) == 1:
                    # handling the case of anonymous axis of length 1
                    if bracket_group is None:
                        self.composition.append([])
                    else:
                        pass  # no need to think about 1s inside parenthesis
                    return
                is_axis_name, reason = self.check_axis_name_return_reason(
                    x, allow_underscore=allow_underscore
                )
                if not (is_number or is_axis_name):
                    raise ValueError(f"Invalid axis identifier: {x}\n{reason}")
                axis_name: Union[str, AnonymousAxis] = (
                    AnonymousAxis(x) if is_number else x
                )
                self.identifiers.add(axis_name)
                if is_number:
                    self.has_non_unitary_anonymous_axes = True
                if bracket_group is None:
                    self.composition.append([axis_name])
                else:
                    bracket_group.append(axis_name)

        # Scan character by character: "(", ")" and " " end the current token; other allowed
        # characters extend it.
        # NOTE(review): a literal "\u2026" typed directly into `expression` is also accepted as the
        # ellipsis token here without setting `has_ellipsis`; confirm whether that should be rejected.
        current_identifier = None
        for char in expression:
            if char in "() ":
                if current_identifier is not None:
                    add_axis_name(current_identifier)
                current_identifier = None
                if char == "(":
                    if bracket_group is not None:
                        raise ValueError(
                            "Axis composition is one-level (brackets inside brackets not allowed)"
                        )
                    bracket_group = []
                elif char == ")":
                    if bracket_group is None:
                        raise ValueError("Brackets are not balanced")
                    self.composition.append(bracket_group)
                    bracket_group = None
            elif str.isalnum(char) or char in ["_", _ellipsis]:
                if current_identifier is None:
                    current_identifier = char
                else:
                    current_identifier += char
            else:
                raise ValueError(f"Unknown character '{char}'")

        if bracket_group is not None:
            raise ValueError(f"Imbalanced parentheses in expression: '{expression}'")
        # Flush a trailing token that was not followed by a delimiter.
        if current_identifier is not None:
            add_axis_name(current_identifier)

    @staticmethod
    def check_axis_name_return_reason(
        name: str, allow_underscore: bool = False
    ) -> Tuple[bool, str]:
        """Check if the given axis name is valid, and a message explaining why if not.

        Valid axis names are python identifiers that do not start or end with an underscore; the bare
        name ``_`` is additionally accepted when ``allow_underscore`` is set. Python keywords are accepted
        but emit a ``RuntimeWarning``, and the name ``axis`` is accepted but emits a ``FutureWarning``.

        Args:
            name (str): the axis name to check
            allow_underscore (bool): whether the bare name ``_`` is accepted

        Returns:
            Tuple[bool, str]: whether the axis name is valid, and a message explaining why if not
            (the message is empty when the name is valid)
        """
        if not str.isidentifier(name):
            return False, "not a valid python identifier"
        elif name[0] == "_" or name[-1] == "_":
            if name == "_" and allow_underscore:
                return True, ""
            return False, "axis name should not start or end with underscore"
        else:
            if keyword.iskeyword(name):
                warnings.warn(
                    f"It is discouraged to use axes names that are keywords: {name}",
                    RuntimeWarning,
                )
            if name == "axis":
                warnings.warn(
                    "It is discouraged to use 'axis' as an axis name and will raise an error in future",
                    FutureWarning,
                )
            return True, ""

    @staticmethod
    def check_axis_name(name: str) -> bool:
        """Check if the name is a valid axis name (the bare ``_`` is not accepted here).

        Args:
            name (str): the axis name to check

        Returns:
            bool: whether the axis name is valid
        """
        is_valid, _ = ParsedExpression.check_axis_name_return_reason(name)
        return is_valid


def parse_pattern(
    pattern: str, axes_lengths: Mapping[str, int]
) -> Tuple[ParsedExpression, ParsedExpression]:
    """Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object.

    Args:
        pattern (str): the `einops`-style rearrangement pattern
        axes_lengths (Mapping[str, int]): any additional length specifications for dimensions

    Returns:
        Tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions

    Raises:
        ValueError: if the pattern does not contain exactly one '->', if the ellipsis marker is used as a key
            of `axes_lengths`, if only the right side has an ellipsis, if the left side's ellipsis is inside
            brackets, or if either side fails to parse
    """
    # adapted from einops.einops._prepare_transformation_recipe
    # https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/einops/einops.py
    try:
        left_str, right_str = pattern.split("->")
    except ValueError:
        # Unpacking fails when there are zero or several separators.
        raise ValueError("Pattern must contain a single '->' separator") from None

    if _ellipsis in axes_lengths:
        raise ValueError(f"'{_ellipsis}' is not an allowed axis identifier")

    left = ParsedExpression(left_str)
    right = ParsedExpression(right_str)

    if not left.has_ellipsis and right.has_ellipsis:
        raise ValueError(
            f"Ellipsis found in right side, but not left side of a pattern {pattern}"
        )
    if left.has_ellipsis and left.has_ellipsis_parenthesized:
        raise ValueError(
            f"Ellipsis inside parenthesis in the left side is not allowed: {pattern}"
        )

    return left, right


def validate_rearrange_expressions(
    left: ParsedExpression, right: ParsedExpression, axes_lengths: Mapping[str, int]
) -> None:
    """Perform expression validations that are specific to the `rearrange` operation.

    Args:
        left (ParsedExpression): left-hand side expression
        right (ParsedExpression): right-hand side expression
        axes_lengths (Mapping[str, int]): any additional length specifications for dimensions

    Raises:
        TypeError: if any axis length is not exactly of type ``int`` (e.g. ``bool`` is rejected)
        ValueError: if a numeric axis other than 1 is used, if an identifier appears on only one side,
            or if `axes_lengths` names an axis missing from the left side
    """
    for length in axes_lengths.values():
        # Exact type check on purpose: isinstance would also accept bool.
        if (length_type := type(length)) is not int:
            raise TypeError(
                f"rearrange axis lengths must be integers, got: {length_type}"
            )

    if left.has_non_unitary_anonymous_axes or right.has_non_unitary_anonymous_axes:
        raise ValueError("rearrange only supports unnamed axes of size 1")

    difference = set.symmetric_difference(left.identifiers, right.identifiers)
    if len(difference) > 0:
        raise ValueError(
            f"Identifiers only on one side of rearrange expression (should be on both): {difference}"
        )

    unmatched_axes = axes_lengths.keys() - left.identifiers
    if len(unmatched_axes) > 0:
        raise ValueError(
            f"Identifiers not found in rearrange expression: {unmatched_axes}"
        )


def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str:
    """Convert a collection of strings representing first class dims into a comma-separated string.

    Nested collections are rendered in parentheses; a single-element nested collection gets a trailing
    comma so the result reads like a Python tuple literal.

    Args:
        collection (Collection[Union[str, Collection[str]]]): the collection of strings to convert

    Returns:
        str: the comma-separated string

    Examples:
        >>> comma_separate(('d0',))
        'd0'

        >>> comma_separate(('d0', 'd1', 'd2', 'd3'))
        'd0, d1, d2, d3'

        >>> comma_separate([('d1', 'd4')])
        '(d1, d4)'

        >>> comma_separate([('d0',), (), ('d1',), ('d2',), ('d3', 'd4')])
        '(d0,), (), (d1,), (d2,), (d3, d4)'
    """
    return ", ".join(
        item
        if isinstance(item, str)
        else f"({comma_separate(item)}{',' if len(item) == 1 else ''})"
        for item in collection
    )