xref: /aosp_15_r20/external/pytorch/torch/fx/passes/annotate_getitem_nodes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import operator
import typing

import torch
4
5
def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
    """
    Annotate the type of getitem nodes, inferred from the type of the sequence node.

    If the sequence node is not annotated with a type, or the result type cannot be
    inferred, the getitem node is left unchanged.
    Currently supports getitem nodes whose sequence node is annotated as a Tuple
    (``typing.Tuple[...]`` or ``tuple[...]``), a List (``typing.List[...]`` or
    ``list[...]``), or a ``typing.NamedTuple`` class.

    This is helpful since annotations on local names within a function are lost during
    FX transforms. Adding back known type annotations for getitem nodes improves jit
    scriptability.

    Args:
        graph (Graph): The graph to be annotated. Node ``type`` fields are updated in place.

    Raises:
        AssertionError: (when assertions are enabled) if a constant integer index is out
            of range for a fixed-length Tuple annotation.
        IndexError: if a constant integer index is out of range for a NamedTuple.
    """
    for node in graph.nodes:
        if node.target != operator.getitem:
            continue
        sequence_node, index = node.args
        # Only FX Nodes carry a `.type`; the container argument may also be a literal.
        if not isinstance(sequence_node, torch.fx.Node) or not sequence_node.type:
            continue
        result_type = _getitem_result_type(sequence_node.type, index)
        if result_type is not None:
            node.type = result_type


def _getitem_result_type(sequence_type: typing.Any, index: typing.Any) -> typing.Any:
    """Return the static type of ``sequence_type[index]``, or None if it cannot be inferred.

    ``index`` is the raw getitem argument: a constant int, a slice, or an FX Node whose
    value is only known at runtime.
    """
    # get_origin/get_args handle both typing aliases (Tuple[int]) and PEP 585
    # builtins (tuple[int]); they return None/() for anything else, so plain
    # classes and special forms such as Any fall through without raising.
    origin = typing.get_origin(sequence_type)
    params = typing.get_args(sequence_type)

    if origin is tuple:
        if not params:
            # Bare `Tuple` / `tuple` carries no element information.
            return None
        if len(params) == 2 and params[1] is Ellipsis:
            # Homogeneous `Tuple[T, ...]`: every element is T regardless of the index,
            # and any slice is again a `Tuple[T, ...]`.
            return sequence_type if isinstance(index, slice) else params[0]
        if isinstance(index, int):
            # Fixed-length tuple: the element type depends on the constant index.
            # Check both bounds so negative indices are validated too.
            assert -len(params) <= index < len(params), (
                f"index {index} out of range for {sequence_type}"
            )
            return params[index]
        # A runtime (Node) index or a slice of a heterogeneous tuple has no single
        # element type we can name here.
        return None

    if origin is list:
        if len(params) != 1:
            return None
        # Every element has the same type; slicing a list yields the same list type.
        return sequence_type if isinstance(index, slice) else params[0]

    # NamedTuple classes expose their field names via `_fields`; typed ones
    # (typing.NamedTuple) also record each field's type in `__annotations__`.
    # Untyped collections.namedtuple classes have no annotations, so they yield None.
    fields = getattr(sequence_type, "_fields", None)
    if isinstance(fields, tuple) and isinstance(index, int):
        return getattr(sequence_type, "__annotations__", {}).get(fields[index])
    return None
45