# mypy: ignore-errors

import itertools
import operator
import sys
from typing import Dict, List, Optional, TYPE_CHECKING, Union

from .. import polyfills, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import (
    handle_observed_exception,
    ObservedUserStopIteration,
    raise_observed_exception,
    unimplemented,
    UserError,
)
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


# Upper bound on how many items CycleIteratorVariable buffers from its input.
MAX_ITERATOR_LIMIT = 100 * 1024  # 100k


def _is_const_none(var) -> bool:
    """Return True if ``var`` is a ConstantVariable wrapping ``None``."""
    return isinstance(var, ConstantVariable) and var.value is None


class ItertoolsVariable(VariableTracker):
    """Tracks a callable from the ``itertools`` module, stored in ``value``.

    ``call_function`` handles a fixed set of itertools functions itself.
    Some are materialized eagerly into a ``ListIteratorVariable``; others
    return dedicated iterator trackers or are inlined from ``polyfills``.
    Anything else falls back to ``VariableTracker.call_function``.
    """

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    def __repr__(self) -> str:
        return f"ItertoolsVariable({self.value})"

    def as_python_constant(self):
        return self.value

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model the call ``self.value(*args, **kwargs)`` during tracing.

        Handling per function:

        - ``product``, ``accumulate``, ``combinations`` and ``groupby`` run
          eagerly on the unpacked inputs.
        - ``repeat`` (without ``times``), ``count`` and ``cycle`` return
          lazy iterator trackers.
        - ``repeat`` with ``times``, ``dropwhile`` and ``zip_longest`` are
          delegated to polyfills.

        Unsupported argument combinations call ``unimplemented``.
        """
        if (
            self.value is itertools.product
            and not kwargs
            and all(arg.has_unpack_var_sequence(tx) for arg in args)
        ):
            seqs = [arg.unpack_var_sequence(tx) for arg in args]
            items = [
                variables.TupleVariable(list(item))
                for item in itertools.product(*seqs)
            ]
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif self.value is itertools.accumulate:
            from .builtin import BuiltinVariable

            if any(key not in ["initial", "func"] for key in kwargs.keys()):
                unimplemented(
                    "Unsupported kwargs for itertools.accumulate: "
                    f"{','.join(set(kwargs.keys()) - {'initial', 'func'})}"
                )

            acc = kwargs.get("initial")
            # `initial=None` means "no initial value", exactly like omitting it.
            if _is_const_none(acc):
                acc = None

            if len(args) in [1, 2] and args[0].has_unpack_var_sequence(tx):
                seq = args[0].unpack_var_sequence(tx)

                if "func" in kwargs and len(args) == 2:
                    unimplemented(
                        "itertools.accumulate can only accept one of: `func` kwarg, pos 2 arg"
                    )
                func_var = args[1] if len(args) == 2 else kwargs.get("func")
                if func_var is None or _is_const_none(func_var):
                    # Default to operator.add (also what `func=None` means).
                    func = BuiltinVariable(operator.add).call_function
                else:
                    func = func_var.call_function
            else:
                unimplemented("Unsupported arguments for itertools.accumulate")

            items = []
            if acc is not None:
                items.append(acc)
            for item in seq:
                if acc is None:
                    # No initial value: the first element seeds the accumulator.
                    acc = item
                else:
                    try:
                        acc = func(tx, [acc, item], {})
                    except Exception as e:
                        unimplemented(
                            "Unexpected failure in invoking function during accumulate. "
                            f"Failed running func {func}({acc}, {item})",
                            from_exc=e,
                        )
                items.append(acc)

            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif (
            self.value is itertools.combinations
            and not kwargs
            and len(args) == 2
            and args[0].has_unpack_var_sequence(tx)
            and args[1].is_python_constant()
        ):
            iterable = args[0].unpack_var_sequence(tx)
            r = args[1].as_python_constant()

            items = [
                variables.TupleVariable(list(item))
                for item in itertools.combinations(iterable, r)
            ]
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif self.value is itertools.groupby:
            if any(kw != "key" for kw in kwargs.keys()):
                unimplemented(
                    "Unsupported kwargs for itertools.groupby: "
                    f"{','.join(set(kwargs.keys()) - {'key'})}"
                )

            def retrieve_const_key(key):
                # Group keys must be concrete Python values for the real
                # itertools.groupby below to compare them.
                if isinstance(key, variables.SymNodeVariable):
                    return key.evaluate_expr()
                elif isinstance(key, variables.ConstantVariable):
                    return key.as_python_constant()
                else:
                    unimplemented(
                        "Unsupported key type for itertools.groupby: " + str(type(key))
                    )

            if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
                seq = args[0].unpack_var_sequence(tx)
                keyfunc = (
                    (
                        lambda x: (
                            retrieve_const_key(
                                kwargs.get("key").call_function(tx, [x], {})
                            )
                        )
                    )
                    if "key" in kwargs
                    else None
                )
            else:
                unimplemented("Unsupported arguments for itertools.groupby")

            result = []
            try:
                for k, v in itertools.groupby(seq, key=keyfunc):
                    result.append(
                        variables.TupleVariable(
                            [
                                variables.ConstantVariable.create(k)
                                if variables.ConstantVariable.is_literal(k)
                                else k,
                                variables.ListIteratorVariable(
                                    list(v), mutable_local=MutableLocal()
                                ),
                            ],
                            mutable_local=MutableLocal(),
                        )
                    )
            except Exception as e:
                unimplemented(
                    "Unexpected failure when calling itertools.groupby",
                    from_exc=e,
                )
            return variables.ListIteratorVariable(result, mutable_local=MutableLocal())
        elif self.value is itertools.repeat:
            if len(args) == 1 and "times" in kwargs:
                # Normalize repeat(obj, times=n) to the positional form.
                # Otherwise `times` would be dropped and the repeat traced
                # as infinite.
                kwargs = dict(kwargs)
                args = [*args, kwargs.pop("times")]
            if len(args) < 2:
                return variables.RepeatIteratorVariable(
                    *args, mutable_local=MutableLocal()
                )

            from .builder import SourcelessBuilder

            return tx.inline_user_function_return(
                SourcelessBuilder.create(tx, polyfills.repeat), args, kwargs
            )
        elif self.value is itertools.count:
            return variables.CountIteratorVariable(*args, mutable_local=MutableLocal())
        elif self.value is itertools.cycle:
            return variables.CycleIteratorVariable(*args, mutable_local=MutableLocal())
        elif self.value is itertools.dropwhile:
            return variables.UserFunctionVariable(polyfills.dropwhile).call_function(
                tx, args, kwargs
            )
        elif self.value is itertools.zip_longest:
            return variables.UserFunctionVariable(polyfills.zip_longest).call_function(
                tx, args, kwargs
            )
        else:
            return super().call_function(tx, args, kwargs)


class IteratorVariable(VariableTracker):
    """Base class for lazily-advanced iterator trackers.

    Subclasses must implement ``next_variable``. It returns the next item,
    or raises an observed ``StopIteration`` when the iterator is exhausted.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def next_variable(self, tx):
        unimplemented("abstract method, must implement")

    # NOTE: only call when unpacking this iterator safely done eagerly!
    # Normally, iterators are accessed lazily.
    # Example of safe eager unpacking: list(map(f, seq))
    # Example of unsafe eager unpacking: list(islice(map(f, seq), 5))
    def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
        """Drain the iterator via ``next_variable`` until it stops."""
        result = []
        while True:
            try:
                result.append(self.next_variable(tx))
            except ObservedUserStopIteration:
                handle_observed_exception(tx)
                break
        return result

    # don't call force_unpack_var_sequence since it can mutate
    # IteratorVariable state!
    def has_force_unpack_var_sequence(self, tx) -> bool:
        return True


class RepeatIteratorVariable(IteratorVariable):
    """Tracks ``itertools.repeat(item)``, which yields ``item`` forever."""

    def __init__(self, item: VariableTracker, **kwargs) -> None:
        super().__init__(**kwargs)
        self.item = item

    # Repeat has no state to advance, so no side-effect mutation is recorded.
    def next_variable(self, tx):
        return self.item

    def reconstruct(self, codegen):
        """Emit bytecode equivalent to ``itertools.repeat(item)``."""
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_python_module(itertools),
                    codegen.create_load_attr("repeat"),
                ]
            )
        )
        codegen(self.item)
        codegen.extend_output(create_call_function(1, False))


class CountIteratorVariable(IteratorVariable):
    """Tracks ``itertools.count(item, step)``.

    ``next_variable`` returns the current value, then advances it by
    ``step`` using ``__add__``.
    """

    def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None:
        super().__init__(**kwargs)
        # Raw Python numbers are wrapped so the rest of the class only deals
        # with VariableTrackers.
        if not isinstance(item, VariableTracker):
            item = ConstantVariable.create(item)
        if not isinstance(step, VariableTracker):
            step = ConstantVariable.create(step)
        self.item = item
        self.step = step

    def next_variable(self, tx):
        assert self.mutable_local
        old_item = self.item
        tx.output.side_effects.mutation(self)
        self.item = self.item.call_method(tx, "__add__", [self.step], {})
        return old_item

    def reconstruct(self, codegen):
        """Emit bytecode equivalent to ``itertools.count(item, step)``."""
        codegen.add_push_null(
            lambda: codegen.extend_output(
                [
                    codegen.create_load_python_module(itertools),
                    codegen.create_load_attr("count"),
                ]
            )
        )
        codegen(self.item)
        codegen(self.step)
        codegen.extend_output(create_call_function(2, False))


class CycleIteratorVariable(IteratorVariable):
    """Tracks ``itertools.cycle(iterator)``.

    While ``iterator`` is live, each item it produces is returned and also
    appended to ``saved``. Once it is exhausted (``iterator`` is set to
    None), the items in ``saved`` are replayed in order, with
    ``saved_index`` as the next position. The iterator stops only if
    nothing was ever saved.
    """

    def __init__(
        self,
        iterator: IteratorVariable,
        saved: Optional[List[VariableTracker]] = None,
        saved_index: int = 0,
        item: Optional[VariableTracker] = None,
        **kwargs,
    ) -> None:
        if saved is None:
            saved = []
        super().__init__(**kwargs)
        self.iterator = iterator
        self.saved = saved
        self.saved_index = saved_index
        self.item = item

    def next_variable(self, tx):
        assert self.mutable_local

        if self.iterator is not None:
            try:
                new_item = self.iterator.next_variable(tx)
                if len(self.saved) > MAX_ITERATOR_LIMIT:
                    unimplemented(
                        "input iterator to itertools.cycle has too many items"
                    )
                tx.output.side_effects.mutation(self)
                self.saved.append(new_item)
                self.item = new_item
                if self.item is None:
                    return self.next_variable(tx)
                return self.item
            except ObservedUserStopIteration:
                handle_observed_exception(tx)
                self.iterator = None
                return self.next_variable(tx)
        elif len(self.saved) > 0:
            tx.output.side_effects.mutation(self)
            # Replay the buffered items: return the one at saved_index, then
            # advance the index (wrapping around).
            self.item = self.saved[self.saved_index]
            self.saved_index = (self.saved_index + 1) % len(self.saved)
            return self.item
        else:
            raise_observed_exception(StopIteration, tx, self)


class ZipVariable(IteratorVariable):
    """
    Represents zip(*iterables)
    """

    _nonvar_fields = {
        "index",
        "strict",
        *IteratorVariable._nonvar_fields,
    }

    def __init__(
        self,
        iterables: List[Union[List[VariableTracker], VariableTracker]],
        strict: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        assert isinstance(iterables, list)
        # can be list[Variable] or VariableTracker (with next_variable implemented)
        self.iterables = iterables
        # Position already consumed from the plain-list iterables; tracker
        # iterables keep track of their own position.
        self.index = 0
        self.strict = strict

    def python_type(self):
        return zip

    def has_unpack_var_sequence(self, tx) -> bool:
        return all(
            isinstance(it, list) or it.has_unpack_var_sequence(tx)
            for it in self.iterables
        )

    def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
        assert self.has_unpack_var_sequence(tx)
        iterables = []
        for it in self.iterables:
            if isinstance(it, list):
                iterables.append(it[self.index :])
            else:
                iterables.append(it.unpack_var_sequence(tx))
        kwargs = {"strict": self.strict} if self.strict else {}
        zipped = zip(*iterables, **kwargs)
        return [variables.TupleVariable(list(var)) for var in zipped]

    def next_variable(self, tx):
        """Return a TupleVariable holding the next item of every iterable.

        Raises an observed ``StopIteration`` in two cases:

        - there are no iterables, matching ``zip()``;
        - any iterable is exhausted.

        With ``strict=True``, a length mismatch raises ``UserError`` instead.
        """
        assert self.mutable_local
        if not self.iterables:
            # zip() with no arguments is an empty iterator.
            raise_observed_exception(StopIteration, tx, self)
        old_index = self.index
        args = []

        def get_item(it):
            if isinstance(it, list):
                if old_index >= len(it):
                    raise_observed_exception(StopIteration, tx, self)
                return it[old_index]
            else:
                return it.next_variable(tx)

        try:
            for idx, it in enumerate(self.iterables):
                args.append(get_item(it))
        except ObservedUserStopIteration:
            if self.strict:
                if idx == 0:
                    # The first iterable is exhausted; strict mode requires all
                    # the *other* iterables to be exhausted too (as CPython's
                    # zip does). The first one is not polled again.
                    for it in self.iterables[1:]:
                        try:
                            get_item(it)
                        except ObservedUserStopIteration:
                            handle_observed_exception(tx)
                            continue
                        # no ObservedUserStopIteration - fall through to UserError
                        break
                    else:
                        # all iterables exhausted, raise original error
                        raise
                handle_observed_exception(tx)
                raise UserError(
                    ValueError,
                    "zip() has one argument of len differing from others",
                ) from None
            raise

        tx.output.side_effects.mutation(self)
        self.index += 1
        return variables.TupleVariable(args)

    def reconstruct_items(self, codegen):
        """Push every iterable onto the stack.

        Plain lists are rebuilt as tuples of their not-yet-consumed items.
        """
        for it in self.iterables:
            if isinstance(it, list):
                remaining_items = it[self.index :]
                codegen.foreach(remaining_items)
                codegen.append_output(
                    create_instruction("BUILD_TUPLE", arg=len(remaining_items))
                )
            else:
                codegen(it)

    def reconstruct(self, codegen):
        """Emit ``zip(*iterables)``.

        On Python >= 3.10 this becomes ``zip(*iterables, **{"strict": strict})``.
        """
        codegen.add_push_null(
            lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True
        )
        self.reconstruct_items(codegen)
        codegen.append_output(
            create_instruction("BUILD_TUPLE", arg=len(self.iterables))
        )
        if sys.version_info >= (3, 10):
            codegen.extend_output(
                [
                    codegen.create_load_const("strict"),
                    codegen.create_load_const(self.strict),
                    create_instruction("BUILD_MAP", arg=1),
                    create_instruction("CALL_FUNCTION_EX", arg=1),
                ]
            )
        else:
            codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0))


class MapVariable(ZipVariable):
    """
    Represents map(fn, *iterables)
    """

    def __init__(
        self,
        fn: VariableTracker,
        iterables: List[Union[List[VariableTracker], VariableTracker]],
        **kwargs,
    ) -> None:
        super().__init__(iterables, **kwargs)
        self.fn = fn

    def python_type(self):
        return map

    # Eager unpacking is disabled here: applying `fn` may have side effects,
    # so map is only advanced lazily through next_variable.
    def has_unpack_var_sequence(self, tx) -> bool:
        return False

    def next_variable(self, tx):
        """Zip the next items together, then call ``fn`` on them."""
        args = super().next_variable(tx)
        return self.fn.call_function(tx, args.items, {})

    def reconstruct(self, codegen):
        """Emit bytecode equivalent to ``map(fn, *iterables)``."""
        codegen.add_push_null(
            lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True
        )
        codegen(self.fn)
        self.reconstruct_items(codegen)
        codegen.extend_output(
            [
                create_instruction("BUILD_TUPLE", arg=len(self.iterables) + 1),
                create_instruction("CALL_FUNCTION_EX", arg=0),
            ]
        )