xref: /aosp_15_r20/external/pytorch/torch/_dynamo/variables/iter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import itertools
4import operator
5import sys
6from typing import Dict, List, Optional, TYPE_CHECKING, Union
7
8from .. import polyfills, variables
9from ..bytecode_transformation import create_call_function, create_instruction
10from ..exc import (
11    handle_observed_exception,
12    ObservedUserStopIteration,
13    raise_observed_exception,
14    unimplemented,
15    UserError,
16)
17from .base import MutableLocal, VariableTracker
18from .constant import ConstantVariable
19
20
21if TYPE_CHECKING:
22    from torch._dynamo.symbolic_convert import InstructionTranslator
23
24
# Upper bound on how many items an iterator variable may buffer while tracing.
# CycleIteratorVariable.next_variable compares its saved-item count against
# this and bails out via `unimplemented` once it is exceeded.
MAX_ITERATOR_LIMIT = 100 * 1024  # 100k
26
27
class ItertoolsVariable(VariableTracker):
    """Tracks a callable taken from the ``itertools`` module.

    ``value`` holds the itertools callable itself (for example
    ``itertools.product``).  ``call_function`` dispatches on its identity:
    some callables are evaluated eagerly into a ``ListIteratorVariable``,
    some are modelled by a dedicated iterator variable, and some are handed
    to a polyfill.  Anything else is delegated to the base class.
    """

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    def __repr__(self) -> str:
        return f"ItertoolsVariable({self.value})"

    def as_python_constant(self):
        return self.value

    @staticmethod
    def _reject_kwargs(name, kwargs) -> None:
        """Bail out via ``unimplemented`` if any keyword arguments were given.

        Used by the branches that only forward positional arguments, so that a
        keyword such as ``times=`` or ``start=`` is never silently dropped.
        """
        if kwargs:
            unimplemented(
                f"Unsupported kwargs for itertools.{name}: "
                f"{','.join(sorted(kwargs.keys()))}"
            )

    def _call_accumulate(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Eagerly evaluate ``itertools.accumulate`` over an unpackable sequence."""
        from .builtin import BuiltinVariable

        unsupported = set(kwargs.keys()) - {"initial", "func"}
        if unsupported:
            unimplemented(
                "Unsupported kwargs for itertools.accumulate: "
                f"{','.join(sorted(unsupported))}"
            )

        if len(args) not in (1, 2) or not args[0].has_unpack_var_sequence(tx):
            unimplemented("Unsupported arguments for itertools.accumulate")

        if len(args) == 2 and "func" in kwargs:
            # The function was supplied both positionally and by keyword.
            unimplemented(
                "itertools.accumulate can only accept one of: `func` kwarg, pos 2 arg"
            )

        seq = args[0].unpack_var_sequence(tx)

        if len(args) == 2:
            func = args[1].call_function
        elif "func" in kwargs:
            func = kwargs["func"].call_function
        else:
            # Default to operator.add
            func = BuiltinVariable(operator.add).call_function

        acc = kwargs.get("initial")
        # `initial=None` is how itertools.accumulate spells "no initial value",
        # so a constant None must not be emitted as the first item.
        if (
            acc is not None
            and acc.is_python_constant()
            and acc.as_python_constant() is None
        ):
            acc = None

        items = []
        if acc is not None:
            items.append(acc)
        for item in seq:
            if acc is None:
                # First element seeds the running total when there is no initial.
                acc = item
            else:
                try:
                    acc = func(tx, [acc, item], {})
                except Exception as e:
                    unimplemented(
                        f"Unexpected failure in invoking function during accumulate. Failed running func {func}({acc}, {item})",
                        from_exc=e,
                    )
            items.append(acc)

        return variables.ListIteratorVariable(items, mutable_local=MutableLocal())

    def _call_groupby(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Eagerly evaluate ``itertools.groupby`` over an unpackable sequence."""
        unsupported = set(kwargs.keys()) - {"key"}
        if unsupported:
            unimplemented(
                "Unsupported kwargs for itertools.groupby: "
                f"{','.join(sorted(unsupported))}"
            )

        def retrieve_const_key(key):
            # Grouping happens in plain Python, so each key must be reducible
            # to a concrete value.
            if isinstance(key, variables.SymNodeVariable):
                return key.evaluate_expr()
            elif isinstance(key, variables.ConstantVariable):
                return key.as_python_constant()
            else:
                unimplemented(
                    "Unsupported key type for itertools.groupby: " + str(type(key))
                )

        if len(args) != 1 or not args[0].has_unpack_var_sequence(tx):
            unimplemented("Unsupported arguments for itertools.groupby")

        seq = args[0].unpack_var_sequence(tx)
        if "key" in kwargs:
            key_var = kwargs["key"]

            def keyfunc(x):
                return retrieve_const_key(key_var.call_function(tx, [x], {}))

        else:
            keyfunc = None

        result = []
        try:
            for k, v in itertools.groupby(seq, key=keyfunc):
                result.append(
                    variables.TupleVariable(
                        [
                            variables.ConstantVariable.create(k)
                            if variables.ConstantVariable.is_literal(k)
                            else k,
                            variables.ListIteratorVariable(
                                list(v), mutable_local=MutableLocal()
                            ),
                        ],
                        mutable_local=MutableLocal(),
                    )
                )
        except Exception as e:
            unimplemented(
                "Unexpected failure when calling itertools.groupby",
                from_exc=e,
            )
        return variables.ListIteratorVariable(result, mutable_local=MutableLocal())

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Trace a call to the wrapped itertools callable.

        Returns a variable modelling the resulting iterator.  Unsupported
        argument combinations bail out through ``unimplemented``; callables
        with no special handling fall through to the base implementation.
        """
        if (
            self.value is itertools.product
            and not kwargs
            and all(arg.has_unpack_var_sequence(tx) for arg in args)
        ):
            seqs = [arg.unpack_var_sequence(tx) for arg in args]
            items = [
                variables.TupleVariable(list(item))
                for item in itertools.product(*seqs)
            ]
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif self.value is itertools.accumulate:
            return self._call_accumulate(tx, args, kwargs)
        elif (
            self.value is itertools.combinations
            and not kwargs
            and len(args) == 2
            and args[0].has_unpack_var_sequence(tx)
            and args[1].is_python_constant()
        ):
            iterable = args[0].unpack_var_sequence(tx)
            r = args[1].as_python_constant()

            items = [
                variables.TupleVariable(list(item))
                for item in itertools.combinations(iterable, r)
            ]
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif self.value is itertools.groupby:
            return self._call_groupby(tx, args, kwargs)
        elif self.value is itertools.repeat:
            if len(args) < 2:
                # This form builds an unbounded repeat from positional
                # arguments only; a keyword (e.g. `times=`) would otherwise be
                # ignored and turn a finite repeat into an infinite one.
                self._reject_kwargs("repeat", kwargs)
                return variables.RepeatIteratorVariable(
                    *args, mutable_local=MutableLocal()
                )

            from .builder import SourcelessBuilder

            return tx.inline_user_function_return(
                SourcelessBuilder.create(tx, polyfills.repeat), args, kwargs
            )
        elif self.value is itertools.count:
            # Only positional arguments are forwarded below.
            self._reject_kwargs("count", kwargs)
            return variables.CountIteratorVariable(*args, mutable_local=MutableLocal())
        elif self.value is itertools.cycle:
            # Only positional arguments are forwarded below.
            self._reject_kwargs("cycle", kwargs)
            return variables.CycleIteratorVariable(*args, mutable_local=MutableLocal())
        elif self.value is itertools.dropwhile:
            return variables.UserFunctionVariable(polyfills.dropwhile).call_function(
                tx, args, kwargs
            )
        elif self.value is itertools.zip_longest:
            return variables.UserFunctionVariable(polyfills.zip_longest).call_function(
                tx, args, kwargs
            )
        else:
            return super().call_function(tx, args, kwargs)
194
195
class IteratorVariable(VariableTracker):
    """Common base for variables that model a Python iterator.

    Subclasses provide ``next_variable``.  ``force_unpack_var_sequence``
    treats an ``ObservedUserStopIteration`` raised by it as the end of
    iteration.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def next_variable(self, tx):
        """Produce the next element; the base version is a placeholder."""
        unimplemented("abstract method, must implement")

    # NOTE: call this only where draining the iterator eagerly is safe.
    # Iterators are normally consumed lazily.
    #   safe eager unpacking:   list(map(f, seq))
    #   unsafe eager unpacking: list(islice(map(f, seq), 5))
    def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
        """Drain the iterator and return everything it yielded, in order."""
        drained = []
        try:
            while True:
                drained.append(self.next_variable(tx))
        except ObservedUserStopIteration:
            # Exhaustion is the normal way out of the loop.
            handle_observed_exception(tx)
        return drained

    # This is a pure query: it must not call force_unpack_var_sequence,
    # because that can mutate IteratorVariable state.
    def has_force_unpack_var_sequence(self, tx) -> bool:
        return True
221
222
class RepeatIteratorVariable(IteratorVariable):
    """Iterator variable that yields the same ``item`` on every step."""

    def __init__(self, item: VariableTracker, **kwargs) -> None:
        super().__init__(**kwargs)
        self.item = item

    def next_variable(self, tx):
        # Nothing advances here: every step hands back the stored item, so no
        # mutation is recorded.
        return self.item

    def reconstruct(self, codegen):
        """Emit bytecode equivalent to ``itertools.repeat(<item>)``."""

        def load_repeat():
            return codegen.extend_output(
                [
                    codegen.create_load_python_module(itertools),
                    codegen.create_load_attr("repeat"),
                ]
            )

        codegen.add_push_null(load_repeat)
        codegen(self.item)
        codegen.extend_output(create_call_function(1, False))
243
244
class CountIteratorVariable(IteratorVariable):
    """Iterator variable holding a current ``item`` and a ``step``.

    Both may be given as plain values or as VariableTrackers; plain values
    are wrapped with ``ConstantVariable.create``.
    """

    def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None:
        super().__init__(**kwargs)
        self.item = (
            item if isinstance(item, VariableTracker) else ConstantVariable.create(item)
        )
        self.step = (
            step if isinstance(step, VariableTracker) else ConstantVariable.create(step)
        )

    def next_variable(self, tx):
        """Return the current item and advance the stored one by ``step``."""
        assert self.mutable_local
        current = self.item
        # Record the state change before replacing the stored item.
        tx.output.side_effects.mutation(self)
        self.item = current.call_method(tx, "__add__", [self.step], {})
        return current

    def reconstruct(self, codegen):
        """Emit bytecode equivalent to ``itertools.count(<item>, <step>)``."""

        def load_count():
            return codegen.extend_output(
                [
                    codegen.create_load_python_module(itertools),
                    codegen.create_load_attr("count"),
                ]
            )

        codegen.add_push_null(load_count)
        codegen(self.item)
        codegen(self.step)
        codegen.extend_output(create_call_function(2, False))
274
275
class CycleIteratorVariable(IteratorVariable):
    """Iterator variable for ``itertools.cycle``.

    While ``iterator`` is still live, each item it produces is returned and
    also appended to ``saved``.  Once it is exhausted, ``iterator`` is set to
    ``None`` and the saved items are replayed in order, indefinitely.

    Attributes:
        iterator: the wrapped iterator variable, or ``None`` once exhausted.
        saved: items seen so far, in order.
        saved_index: position in ``saved`` of the next item to replay.
        item: the most recently returned item.
    """

    def __init__(
        self,
        iterator: IteratorVariable,
        saved: Optional[List[VariableTracker]] = None,
        saved_index: int = 0,
        item: Optional[VariableTracker] = None,
        **kwargs,
    ) -> None:
        if saved is None:
            saved = []
        super().__init__(**kwargs)
        self.iterator = iterator
        self.saved = saved
        self.saved_index = saved_index
        self.item = item

    def next_variable(self, tx):
        """Return the next item of the cycle.

        Raises an observed ``StopIteration`` only when the wrapped iterator
        was empty, so there is nothing to replay.
        """
        assert self.mutable_local

        if self.iterator is not None:
            try:
                new_item = self.iterator.next_variable(tx)
                if len(self.saved) > MAX_ITERATOR_LIMIT:
                    unimplemented(
                        "input iterator to itertools.cycle has too many items"
                    )
                tx.output.side_effects.mutation(self)
                self.saved.append(new_item)
                self.item = new_item
                if self.item is None:
                    return self.next_variable(tx)
                return self.item
            except ObservedUserStopIteration:
                # First pass finished: switch to replaying the saved items.
                handle_observed_exception(tx)
                self.iterator = None
                return self.next_variable(tx)
        elif len(self.saved) > 0:
            tx.output.side_effects.mutation(self)
            # Replay from `saved`, reading the entry at the current position
            # and then advancing with wrap-around.  Returning the stale
            # `self.item` here would repeat the last element forever.
            self.item = self.saved[self.saved_index]
            self.saved_index = (self.saved_index + 1) % len(self.saved)
            return self.item
        else:
            raise_observed_exception(StopIteration, tx, self)
319
320
class ZipVariable(IteratorVariable):
    """
    Represents zip(*iterables)

    Each entry of ``iterables`` is either a plain Python list of
    VariableTrackers, which is read by position using ``self.index``, or a
    VariableTracker that implements ``next_variable``.
    """

    _nonvar_fields = {
        "index",
        "strict",
        *IteratorVariable._nonvar_fields,
    }

    def __init__(
        self,
        iterables: List[Union[List[VariableTracker], VariableTracker]],
        strict: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        assert isinstance(iterables, list)
        # can be list[Variable] or VariableTracker (with next_variable implemented)
        self.iterables = iterables
        # Number of tuples already produced; offset into list-backed iterables.
        self.index = 0
        self.strict = strict

    def python_type(self):
        return zip

    def has_unpack_var_sequence(self, tx) -> bool:
        return all(
            isinstance(it, list) or it.has_unpack_var_sequence(tx)
            for it in self.iterables
        )

    def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
        """Return the remaining zipped tuples as a list of TupleVariables."""
        assert self.has_unpack_var_sequence(tx)
        iterables = []
        for it in self.iterables:
            if isinstance(it, list):
                # Skip the items already consumed through next_variable.
                iterables.append(it[self.index :])
            else:
                iterables.append(it.unpack_var_sequence(tx))
        # Only pass `strict` when it is set, so the keyword is never handed to
        # a zip() that does not accept it.
        kwargs = {"strict": self.strict} if self.strict else {}
        zipped = zip(*iterables, **kwargs)
        return [variables.TupleVariable(list(var)) for var in zipped]

    def next_variable(self, tx):
        """Return the next zipped tuple as a TupleVariable.

        Raises an observed ``StopIteration`` when any iterable is exhausted
        (or when there are no iterables at all).  In strict mode, a length
        mismatch raises ``UserError(ValueError, ...)`` instead.
        """
        assert self.mutable_local

        if not self.iterables:
            # zip() with no iterables is an empty iterator; without this guard
            # the loop below never runs and an empty tuple is returned forever.
            raise_observed_exception(StopIteration, tx, self)

        old_index = self.index
        args = []

        def get_item(it):
            if isinstance(it, list):
                if old_index >= len(it):
                    raise_observed_exception(StopIteration, tx, self)
                return it[old_index]
            else:
                return it.next_variable(tx)

        try:
            for idx, it in enumerate(self.iterables):
                args.append(get_item(it))
        except ObservedUserStopIteration:
            if self.strict:
                if idx == 0:
                    # all other iterables should be exhausted
                    for it in self.iterables:
                        try:
                            get_item(it)
                        except ObservedUserStopIteration:
                            handle_observed_exception(tx)
                            continue
                        # no ObservedUserStopIteration - fall through to UserError
                        break
                    else:
                        # all iterables exhausted, raise original error
                        raise
                handle_observed_exception(tx)
                raise UserError(
                    ValueError,
                    "zip() has one argument of len differing from others",
                ) from None
            raise

        tx.output.side_effects.mutation(self)
        self.index += 1
        return variables.TupleVariable(args)

    def reconstruct_items(self, codegen):
        """Emit one value per iterable, skipping already-consumed list items."""
        for it in self.iterables:
            if isinstance(it, list):
                remaining_items = it[self.index :]
                codegen.foreach(remaining_items)
                codegen.append_output(
                    create_instruction("BUILD_TUPLE", arg=len(remaining_items))
                )
            else:
                codegen(it)

    def reconstruct(self, codegen):
        """Emit bytecode that rebuilds this zip via ``CALL_FUNCTION_EX``."""
        codegen.add_push_null(
            lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True
        )
        self.reconstruct_items(codegen)
        codegen.append_output(
            create_instruction("BUILD_TUPLE", arg=len(self.iterables))
        )
        if sys.version_info >= (3, 10):
            # Pass `strict` as a keyword through a one-entry kwargs map.
            codegen.extend_output(
                [
                    codegen.create_load_const("strict"),
                    codegen.create_load_const(self.strict),
                    create_instruction("BUILD_MAP", arg=1),
                    create_instruction("CALL_FUNCTION_EX", arg=1),
                ]
            )
        else:
            codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0))
438
439
class MapVariable(ZipVariable):
    """
    Represents map(fn, *iterables)

    Reuses ZipVariable to step the iterables together, then applies ``fn``
    to each resulting group of items.
    """

    def __init__(
        self,
        fn: VariableTracker,
        iterables: List[Union[List[VariableTracker], VariableTracker]],
        **kwargs,
    ) -> None:
        super().__init__(iterables, **kwargs)
        self.fn = fn

    def python_type(self):
        return map

    def has_unpack_var_sequence(self, tx) -> bool:
        # Unlike its parent, this variable never reports itself as unpackable.
        return False

    def next_variable(self, tx):
        """Advance the underlying zip and call ``fn`` on the items it produced."""
        zipped = super().next_variable(tx)
        return self.fn.call_function(tx, zipped.items, {})

    def reconstruct(self, codegen):
        """Emit bytecode that rebuilds this map via ``CALL_FUNCTION_EX``."""

        def load_map():
            return codegen.load_import_from("builtins", "map")

        codegen.add_push_null(load_map, call_function_ex=True)
        codegen(self.fn)
        self.reconstruct_items(codegen)
        # The argument tuple is fn followed by one entry per iterable.
        positional_count = len(self.iterables) + 1
        codegen.extend_output(
            [
                create_instruction("BUILD_TUPLE", arg=positional_count),
                create_instruction("CALL_FUNCTION_EX", arg=0),
            ]
        )
476