xref: /aosp_15_r20/external/pytorch/torch/_dynamo/source.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import collections
3import dataclasses
4import enum
5from typing import Any, Optional, Union
6
7from torch._guards import ChainedSource, GuardSource, Source
8
9from . import utils
10from .bytecode_transformation import create_call_function, create_instruction
11from .utils import enum_repr
12
13
14# It shouldn't be supported to construct an NNModuleVariable inside an FSDP module,
15# so those cases are omitted intentionally
16
# Lookup table: guard source of a base source -> guard source to report for
# nn.Modules tracked with NNModuleVariable (specialized is implicit in the
# variable name). Plain LOCAL/GLOBAL are promoted to their *_SPECIALIZED_NN_MODULE
# counterpart; already-specialized values map to themselves.
_GUARD_SOURCE_SPECIALIZED_NN_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
    GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
    GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
    # Just to ensure that guard_source() works: the unspecialized variants are
    # passed through unchanged instead of raising KeyError.
    GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
    GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
    GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
}
29
# Lookup table: guard source of a base source -> guard source to report for
# nn.Modules tracked with UnspecializedNNModuleVariable.
_GUARD_SOURCE_UNSPECIALIZED_NN_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
    GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
    GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
    # this happens for an UnspecializedNNModule submodule on a NNModuleVariable:
    # the specialized guard source of the parent is demoted to unspecialized.
    GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE,
    GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE,
    # Just to ensure that guard_source() works: builtin variants map to themselves.
    GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
}
43
# Lookup table: guard source of a base source -> guard source to report for
# nn.Modules tracked with UnspecializedBuiltinNNModuleVariable. Every local
# key maps to LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE and every global key to
# GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE.
_GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    # Just to ensure that guard_source() works: identity entries for values
    # that are already the builtin-unspecialized kind.
    GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
    GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
}
56
# Lookup table used by FSDPNNModuleSource.guard_source(): every local guard
# source listed here maps to LOCAL_FSDP_MODULE and every global one to
# GLOBAL_FSDP_MODULE.
_GUARD_SOURCE_FSDP_MODULE = {
    GuardSource.LOCAL: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL: GuardSource.GLOBAL_FSDP_MODULE,
    GuardSource.LOCAL_SPECIALIZED_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_SPECIALIZED_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
    # Already-FSDP guard sources map to themselves.
    GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
    GuardSource.LOCAL_UNSPECIALIZED_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_UNSPECIALIZED_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
    GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.LOCAL_FSDP_MODULE,
    GuardSource.GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE: GuardSource.GLOBAL_FSDP_MODULE,
}
69
70
def is_constant_source(source):
    """Return True when ``source`` denotes a constant.

    A source counts as constant if it is a ``ConstantSource`` instance or if
    its ``guard_source()`` equals ``GuardSource.CONSTANT``. A source whose
    ``guard_source()`` raises ``NotImplementedError`` is treated as
    non-constant.
    """
    if isinstance(source, ConstantSource):
        return True
    try:
        matches = bool(source.guard_source() == GuardSource.CONSTANT)
    except NotImplementedError:
        # No guard source available: cannot be classified as constant.
        matches = False
    return matches
81
82
def reconstruct_getitem(
    source: Union["GetItemSource", "ODictGetItemSource"], codegen, index_is_slice
):
    """Emit the operands of a subscript: first the base, then the index.

    The base is always rebuilt through ``source.base.reconstruct``. The index
    is rebuilt through its own ``reconstruct`` when it is a ``Source``;
    otherwise it is loaded as a constant (for slices, the slice object is
    first rebuilt with ``source.unpack_slice()``). The subscript/call
    instruction itself is left to the caller.
    """
    source.base.reconstruct(codegen)

    key = source.index
    if isinstance(key, Source):
        key.reconstruct(codegen)
        return

    if index_is_slice:
        # Only GetItemSource stores slices (in reduced form) and can unpack them.
        assert isinstance(source, GetItemSource)
        constant = source.unpack_slice()
    else:
        constant = key
    codegen.append_output(codegen.create_load_const(constant))
95
96
97@dataclasses.dataclass(frozen=True)
98class LocalSource(Source):
99    local_name: str
100    cell_or_freevar: bool = False
101
102    def reconstruct(self, codegen):
103        codegen.append_output(codegen.create_load(self.local_name))
104
105    def guard_source(self):
106        return GuardSource.LOCAL
107
108    def name(self):
109        return f"L[{repr(self.local_name)}]"
110
111
112@dataclasses.dataclass(frozen=True)
113class SyntheticLocalSource(Source):
114    local_name: str
115
116    def reconstruct(self, codegen):
117        codegen.append_output(codegen.create_load(self.local_name))
118
119    def guard_source(self):
120        return GuardSource.SYNTHETIC_LOCAL
121
122    def name(self):
123        return f"SYNTHETIC_LOCAL[{self.local_name!r}]"
124
125
126@dataclasses.dataclass(frozen=True)
127class RandomValueSource(Source):
128    random_call_index: int
129
130    def guard_source(self):
131        return GuardSource.RANDOM_VALUE
132
133    def reconstruct(self, codegen):
134        codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var))
135        codegen.append_output(codegen.create_load_const(self.random_call_index))
136        codegen.append_output(create_instruction("BINARY_SUBSCR"))
137
138    def name(self):
139        return f"random_value_{self.random_call_index}"
140
141
142@dataclasses.dataclass(frozen=True)
143class GlobalSource(Source):
144    global_name: str
145
146    def reconstruct(self, codegen):
147        codegen.append_output(codegen.create_load_global(self.global_name, add=True))
148
149    def guard_source(self):
150        return GuardSource.GLOBAL
151
152    def name(self):
153        return f"G[{repr(self.global_name)}]"
154
155
156@dataclasses.dataclass(frozen=True)
157class GlobalWeakRefSource(Source):
158    global_name: str
159
160    def reconstruct(self, codegen):
161        codegen.add_push_null(
162            lambda: codegen.append_output(
163                codegen.create_load_global(self.global_name, add=True)
164            )
165        )
166        codegen.extend_output(create_call_function(0, False))
167
168    def guard_source(self):
169        return GuardSource.GLOBAL
170
171    def name(self):
172        return f"G[{repr(self.global_name)}]()"
173
174
175@dataclasses.dataclass(frozen=True)
176class WeakRefCallSource(ChainedSource):
177    def reconstruct(self, codegen):
178        codegen.add_push_null(lambda: self.base.reconstruct(codegen))
179        codegen.extend_output(create_call_function(0, False))
180
181    def guard_source(self):
182        return self.base.guard_source()
183
184    def name(self):
185        return f"{self.base.name()}()"
186
187
188@dataclasses.dataclass(frozen=True)
189class AttrSource(ChainedSource):
190    member: str
191
192    def __post_init__(self):
193        assert self.base, "Can't construct an AttrSource without a valid base source"
194        if "." in self.member:
195            member_parts = self.member.split(".")
196            object.__setattr__(
197                self, "base", AttrSource(self.base, ".".join(member_parts[:-1]))
198            )
199            object.__setattr__(self, "member", member_parts[-1])
200
201    def reconstruct(self, codegen):
202        self.base.reconstruct(codegen)
203        codegen.extend_output(codegen.create_load_attrs(self.member))
204
205    def guard_source(self):
206        return self.base.guard_source()
207
208    def name(self):
209        if not self.member.isidentifier():
210            return f"getattr({self.base.name()}, {self.member!r})"
211        return f"{self.base.name()}.{self.member}"
212
213
214# Represents tensor.grad source. It could be represented by AttrSource as well.
215# But, we could access grad field on tensor directly in C++ without going
216# through the Python bytecodes. Therefore, we use a separate source for grad
217# field.
218@dataclasses.dataclass(frozen=True)
219class GradSource(ChainedSource):
220    member: str = "grad"
221
222    def reconstruct(self, codegen):
223        self.base.reconstruct(codegen)
224        codegen.extend_output(codegen.create_load_attrs(self.member))
225
226    def guard_source(self):
227        return self.base.guard_source()
228
229    def name(self):
230        return f"{self.base.name()}.{self.member}"
231
232
@dataclasses.dataclass(frozen=True)
class ParamBufferSource(AttrSource):
    """AttrSource whose guard source is translated through
    ``_GUARD_SOURCE_SPECIALIZED_NN_MODULE``."""

    def guard_source(self):
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[base_kind]
237
238
# Special AttrSource to differentiate module._buffers or module._parameters
@dataclasses.dataclass(frozen=True)
class UnspecializedParamBufferSource(AttrSource):
    """Marker subclass of AttrSource; adds no fields and overrides nothing."""
243
244
245# This source is intended to be used in places where a source is needed but it is expected
246# that the symbol will be simplified out later on. Symbols with ephemeral sources are
247# prioritized to be simplified out when e.g. compared against a symbol without an ephemeral
248# source. Guarding on this source is an error.
249#
250# Example: During subclass view fake-ification, any close-over ViewFunc state should be
251# symbolicized / fake-ified to avoid invalid specialization during view replay. This source
252# is useful for symbols utilized in the middle of the view chain that are not expected to be
253# present within the final view shape metadata.
254@dataclasses.dataclass(frozen=True)
255class EphemeralSource(Source):
256    desc: Optional[str] = None
257
258    def guard_source(self):
259        return GuardSource.EPHEMERAL
260
261    def name(self):
262        return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>"
263
264    def make_guard(self):
265        raise NotImplementedError
266
267    def is_ephemeral(self):
268        return True
269
270
class TensorProperty(enum.Enum):
    """Tensor metadata kinds that a TensorPropertySource can refer to."""

    SIZE = 0
    STRIDE = 1
    STORAGE_OFFSET = 2

    def method_name(self):
        """Return the tensor method name corresponding to this member."""
        names_by_member = {
            TensorProperty.SIZE: "size",
            TensorProperty.STRIDE: "stride",
            TensorProperty.STORAGE_OFFSET: "storage_offset",
        }
        return names_by_member.get(self)
283
284
@dataclasses.dataclass(frozen=True)
class TensorPropertySource(ChainedSource):
    """Source for one size/stride entry, or the storage offset, of the base.

    ``idx`` selects the dimension for SIZE and STRIDE and must be ``None``
    for STORAGE_OFFSET.
    """

    prop: TensorProperty
    idx: Optional[int] = None  # None for STORAGE_OFFSET

    def __post_init__(self):
        assert self.base is not None
        # idx must be absent exactly when the property is the storage offset.
        expects_no_index = self.prop is TensorProperty.STORAGE_OFFSET
        assert (self.idx is None) == expects_no_index

    def reconstruct(self, codegen):
        """Emit a call to the base's property method, passing ``idx`` if set."""
        has_index = self.idx is not None

        def load_bound_method():
            self.base.reconstruct(codegen)
            codegen.append_output(codegen.create_load_attr(self.prop.method_name()))

        codegen.add_push_null(load_bound_method)
        if has_index:
            codegen.append_output(codegen.create_load_const(self.idx))
        codegen.extend_output(create_call_function(1 if has_index else 0, False))

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        """Return the expression text, e.g. ``<base>.size()[<idx>]``."""
        prop = self.prop
        if prop is TensorProperty.SIZE:
            return f"{self.base.name()}.size()[{self.idx}]"
        if prop is TensorProperty.STRIDE:
            return f"{self.base.name()}.stride()[{self.idx}]"
        if prop is TensorProperty.STORAGE_OFFSET:
            assert self.idx is None
            return f"{self.base.name()}.storage_offset()"
        raise AssertionError(f"unhandled {self.prop}")
322
323
324@dataclasses.dataclass(frozen=True)
325class NegateSource(ChainedSource):
326    def __post_init__(self):
327        assert self.base is not None
328
329    def reconstruct(self, codegen):
330        raise NotImplementedError
331
332    def guard_source(self):
333        return self.base.guard_source()
334
335    def name(self):
336        # NB: use method call so that function stripping regexes work
337        return f"{self.base.name()}.__neg__()"
338
339
340@dataclasses.dataclass(frozen=True)
341class ConvertIntSource(ChainedSource):
342    def __post_init__(self):
343        assert self.base is not None
344
345    def reconstruct(self, codegen):
346        self.base.reconstruct(codegen)
347
348    def guard_source(self):
349        return self.base.guard_source()
350
351    def name(self):
352        return f"cast_symbool_to_symint_guardless({self.base.name()})"
353
354
355@dataclasses.dataclass(frozen=True)
356class FlattenScriptObjectSource(ChainedSource):
357    def __post_init__(self):
358        assert self.base is not None
359
360    def reconstruct(self, codegen):
361        self.base.reconstruct(codegen)
362
363    def guard_source(self):
364        return self.base.guard_source()
365
366    def name(self):
367        return f"{self.base.name()}.__obj_flatten__()"
368
369
370@dataclasses.dataclass(frozen=True)
371class ScriptObjectQualifiedNameSource(ChainedSource):
372    def __post_init__(self):
373        assert self.base is not None
374
375    def reconstruct(self, codegen):
376        self.base.reconstruct(codegen)
377
378    def guard_source(self):
379        return self.base.guard_source()
380
381    def name(self):
382        return f"{self.base.name()}._type().qualified_name()"
383
384
385class AttrProxySource(ChainedSource):
386    def reconstruct(self, codegen):
387        self.base.reconstruct(codegen)
388
389    def guard_source(self):
390        return self.base.guard_source()
391
392    def name(self):
393        return f"{self.base.name()}.get_base()"
394
395
396@dataclasses.dataclass(frozen=True)
397class DefaultsSource(ChainedSource):
398    idx_key: Union[int, str]
399    is_kw: bool = False
400    field: str = dataclasses.field(init=False, repr=False, compare=False)
401    _name: str = dataclasses.field(init=False, repr=False, compare=False)
402
403    def __post_init__(self):
404        assert (
405            self.base
406        ), "Base must be a valid source in order to properly track and guard this Defaults to its origin."
407        if self.is_kw:
408            assert isinstance(self.idx_key, str)
409            object.__setattr__(self, "field", "__kwdefaults__")
410            object.__setattr__(
411                self, "_name", f"{self.base.name()}.{self.field}['{self.idx_key}']"
412            )
413        else:
414            assert isinstance(self.idx_key, int)
415            object.__setattr__(self, "field", "__defaults__")
416            object.__setattr__(
417                self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]"
418            )
419
420    def reconstruct(self, codegen):
421        self.base.reconstruct(codegen)
422        codegen.extend_output(codegen.create_load_attrs(self.field))
423        codegen.append_output(codegen.create_load_const(self.idx_key))
424        codegen.append_output(create_instruction("BINARY_SUBSCR"))
425
426    def guard_source(self):
427        return self.base.guard_source()
428
429    def name(self):
430        return self._name
431
432
433@dataclasses.dataclass(frozen=True)
434class GetItemSource(ChainedSource):
435    index: Any
436    index_is_slice: bool = False
437
438    def __post_init__(self):
439        assert self.base is not None
440        if isinstance(self.index, slice):
441            # store the hashable version of the slice so the whole GetItemSource is hashable
442            super().__setattr__("index", self.index.__reduce__())
443            super().__setattr__("index_is_slice", True)
444
445    def reconstruct(self, codegen):
446        reconstruct_getitem(self, codegen, index_is_slice=self.index_is_slice)
447        codegen.append_output(create_instruction("BINARY_SUBSCR"))
448
449    def guard_source(self):
450        return self.base.guard_source()
451
452    def unpack_slice(self):
453        assert self.index_is_slice
454        slice_class, slice_args = self.index
455        return slice_class(*slice_args)
456
457    def name(self):
458        # Index can be of following types
459        # 1) ConstDictKeySource
460        # 2) enum.Enum
461        # 3) index is a slice - example 1:4
462        # 4) index is a constant - example string, integer
463        if isinstance(self.index, Source):
464            if not isinstance(self.index, ConstDictKeySource):
465                raise ValueError(
466                    "GetItemSource index must be a constant, enum or ConstDictKeySource"
467                )
468            return f"{self.base.name()}[{self.index.name()}]"
469        elif self.index_is_slice:
470            return f"{self.base.name()}[{self.unpack_slice()!r}]"
471        elif isinstance(self.index, enum.Enum):
472            return f"{self.base.name()}[{enum_repr(self.index, self.guard_source().is_local())}]"
473        else:
474            return f"{self.base.name()}[{self.index!r}]"
475
476
@dataclasses.dataclass(frozen=True)
class ConstDictKeySource(GetItemSource):
    """Source for the key at position ``index`` in the base's ``keys()``."""

    def is_dict_key(self):
        return True

    def name(self):
        # The list creation will be CSE'd by PyExprCSEPass
        mapping = self.base.name()
        return f"list({mapping}.keys())[{self.index!r}]"

    def reconstruct(self, codegen):
        """Emit a call to ``dict_keys_getitem(base, index)`` imported from utils."""

        def load_helper():
            return codegen.load_import_from(utils.__name__, "dict_keys_getitem")

        codegen.add_push_null(load_helper)
        self.base.reconstruct(codegen)
        codegen.append_output(codegen.create_load_const(self.index))
        codegen.extend_output(create_call_function(2, False))
493
494
@dataclasses.dataclass(frozen=True)
class TupleIteratorGetItemSource(GetItemSource):
    """Source for element ``index`` of a base that is read via
    ``tuple_iterator_getitem``."""

    def name(self):
        iterator = self.base.name()
        return f"___tuple_iterator_getitem({iterator}, {self.index!r})"

    def reconstruct(self, codegen):
        """Emit a call to ``tuple_iterator_getitem(base, index)`` imported from utils."""

        def load_helper():
            return codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")

        codegen.add_push_null(load_helper)
        self.base.reconstruct(codegen)
        codegen.append_output(codegen.create_load_const(self.index))
        codegen.extend_output(create_call_function(2, False))
507
508
509@dataclasses.dataclass(frozen=True)
510class TypeSource(ChainedSource):
511    def __post_init__(self):
512        assert self.base is not None
513
514    def reconstruct(self, codegen):
515        codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type"))
516        self.base.reconstruct(codegen)
517        codegen.extend_output(create_call_function(1, False))
518
519    def guard_source(self):
520        return self.base.guard_source()
521
522    def name(self):
523        return f"type({self.base.name()})"
524
525
526@dataclasses.dataclass(frozen=True)
527class ODictGetItemSource(ChainedSource):
528    index: Any
529
530    def __post_init__(self):
531        assert self.base is not None
532
533    def reconstruct(self, codegen):
534        codegen.add_push_null(
535            lambda: codegen.append_output(
536                codegen._create_load_const(collections.OrderedDict.__getitem__)
537            )
538        )
539        reconstruct_getitem(self, codegen, index_is_slice=False)
540        codegen.extend_output(create_call_function(2, False))
541
542    def guard_source(self):
543        return self.base.guard_source()
544
545    def name(self):
546        if isinstance(self.index, type):
547            rep = f'__load_module("{self.index.__module__}").{self.index.__qualname__}'
548            return f"___odict_getitem({self.base.name()}, {rep})"
549        elif isinstance(self.index, Source):
550            return f"___odict_getitem({self.base.name()}, {self.index.name()})"
551        else:
552            return f"___odict_getitem({self.base.name()}, {self.index!r})"
553
554
555@dataclasses.dataclass(frozen=True)
556class OptimizerSource(ChainedSource):
557    def reconstruct(self, codegen):
558        self.base.reconstruct(codegen)
559
560    def guard_source(self):
561        return self.base.guard_source()
562
563    def name(self):
564        return self.base.name()
565
566
567@dataclasses.dataclass(frozen=True)
568class NNModuleSource(ChainedSource):
569    def reconstruct(self, codegen):
570        self.base.reconstruct(codegen)
571
572    def guard_source(self):
573        return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]
574
575    def name(self):
576        return self.base.name()
577
578
@dataclasses.dataclass(frozen=True)
class UnspecializedNNModuleSource(NNModuleSource):
    """NNModuleSource variant whose guard source is translated through
    ``_GUARD_SOURCE_UNSPECIALIZED_NN_MODULE``."""

    def guard_source(self):
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_UNSPECIALIZED_NN_MODULE[base_kind]
583
584
@dataclasses.dataclass(frozen=True)
class UnspecializedBuiltinNNModuleSource(UnspecializedNNModuleSource):
    """UnspecializedNNModuleSource variant whose guard source is translated
    through ``_GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE``."""

    def guard_source(self):
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE[base_kind]
589
590
@dataclasses.dataclass(frozen=True)
class FSDPNNModuleSource(NNModuleSource):
    """NNModuleSource variant whose guard source is translated through
    ``_GUARD_SOURCE_FSDP_MODULE``."""

    def guard_source(self):
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_FSDP_MODULE[base_kind]
595
596
597@dataclasses.dataclass(frozen=True)
598class GlobalStateSource(Source):
599    def name(self):
600        return ""
601
602    def guard_source(self):
603        return GuardSource.GLOBAL
604
605
606@dataclasses.dataclass(frozen=True)
607class TorchFunctionModeStackSource(Source):
608    ind: int
609
610    def name(self):
611        return ""
612
613    def _get_index(self):
614        from .variables.torch_function import TorchFunctionModeStackVariable
615
616        return TorchFunctionModeStackVariable.get_mode_index(self.ind)
617
618    def reconstruct(self, codegen):
619        codegen.add_push_null(
620            lambda: codegen.load_import_from(
621                utils.__name__, "get_torch_function_mode_stack_at"
622            )
623        )
624        codegen.extend_output([codegen.create_load_const(self._get_index())])
625        codegen.extend_output(create_call_function(1, False))
626
627    def guard_source(self):
628        return GuardSource.GLOBAL
629
630
631@dataclasses.dataclass(frozen=True)
632class ConstantSource(Source):
633    source_name: str
634
635    def reconstruct(self, codegen):
636        codegen.append_output(codegen.create_load_global(self.source_name, add=False))
637
638    def guard_source(self):
639        return GuardSource.CONSTANT
640
641    def name(self):
642        return self.source_name
643
644    def make_guard(self, fn):
645        raise NotImplementedError
646
647
648@dataclasses.dataclass(frozen=True)
649class NumpyTensorSource(ChainedSource):
650    def name(self) -> str:
651        return f"___from_numpy({self.base.name()})"
652
653    def guard_source(self):
654        return self.base.guard_source()
655
656    def reconstruct(self, codegen):
657        codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor"))
658        self.base.reconstruct(codegen)
659        codegen.extend_output(create_call_function(1, False))
660
661
662@dataclasses.dataclass(frozen=True)
663class SubclassAttrListSource(ChainedSource):
664    def name(self) -> str:
665        return f"{self.base.name()}.__tensor_flatten__()[0]"
666
667    def guard_source(self):
668        return self.base.guard_source()
669
670
671# NB: We don't expect you to actually ever generate guards against this
672# source, it is ephemeral
673@dataclasses.dataclass(frozen=True)
674class FloatTensorSource(ChainedSource):
675    def name(self) -> str:
676        return f"___as_tensor({self.base.name()})"
677
678    def guard_source(self):
679        return self.base.guard_source()
680
681
682@dataclasses.dataclass(frozen=True)
683class CallMethodItemSource(ChainedSource):
684    def name(self) -> str:
685        return f"{self.base.name()}.item()"
686
687    def guard_source(self):
688        return self.base.guard_source()
689
690
691# This is a synthetic source that is associated with the singleton
692# shape env guard we always register for all frames.  We get the actual
693# guard contents from the ambient ShapeEnv
694@dataclasses.dataclass(frozen=True)
695class ShapeEnvSource(Source):
696    def name(self):
697        return ""
698
699    def guard_source(self):
700        return GuardSource.SHAPE_ENV
701
702
703@dataclasses.dataclass(frozen=True)
704class BackwardStateSource(Source):
705    def name(self):
706        return ""
707
708    def guard_source(self):
709        return GuardSource.BACKWARD_STATE
710
711
def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
    """Return True when the root of ``source``'s chain is a ``LocalSource``.

    ``ChainedSource`` wrappers are peeled off via ``.base`` until a
    non-chained source is reached. When ``allow_cell_or_freevar`` is false, a
    root whose ``cell_or_freevar`` flag is set is rejected.
    """
    root = source
    while isinstance(root, ChainedSource):
        root = root.base
    if not isinstance(root, LocalSource):
        return False
    return bool(allow_cell_or_freevar or not root.cell_or_freevar)
722
723
def is_from_unspecialized_param_buffer_source(source: Source):
    """Return True when ``source`` or any base in its chain is an
    ``UnspecializedParamBufferSource``."""
    node = source
    while not isinstance(node, UnspecializedParamBufferSource):
        if not isinstance(node, ChainedSource):
            # Reached the root of the chain without a match.
            return False
        node = node.base
    return True
730
731
def is_from_flatten_script_object_source(source: Source):
    """Return True when ``source`` or any base in its chain is a
    ``FlattenScriptObjectSource``."""
    node = source
    while not isinstance(node, FlattenScriptObjectSource):
        if not isinstance(node, ChainedSource):
            # Reached the root of the chain without a match.
            return False
        node = node.base
    return True
738
739
def is_from_optimizer_source(source: Source):
    """Return True when ``source`` or any base in its chain is an
    ``OptimizerSource``."""
    node = source
    while not isinstance(node, OptimizerSource):
        if not isinstance(node, ChainedSource):
            # Reached the root of the chain without a match.
            return False
        node = node.base
    return True
746
747
# TODO: can probably write a generic "test this on everything in the chain"
# helper
def is_from_defaults(source: Source):
    """Return True when ``source`` or any base in its chain is a
    ``DefaultsSource``."""
    node = source
    while not isinstance(node, DefaultsSource):
        if not isinstance(node, ChainedSource):
            # Reached the root of the chain without a match.
            return False
        node = node.base
    return True
756
757
def is_cell_contents(source: Source):
    """Return True when ``source`` is an ``AttrSource`` whose member is
    ``cell_contents``."""
    if not isinstance(source, AttrSource):
        return False
    return source.member == "cell_contents"
760