import operator

import torch


def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
    """
    Annotate the type of getitem nodes, inferred from the type of sequence node.
    If sequence node is not annotated with a type, do nothing.
    Currently support getitem nodes from Tuple, List, and NamedTuple sequence node.

    This is helpful since annotations on local names within function are lost during FX transforms.
    Adding back known type annotation for getitem nodes to improve jit scriptability.

    Indexing rules:
      * ``Tuple[T0, T1, ...]`` with a constant int index (negative allowed) -> ``Ti``.
      * ``Tuple[T, ...]`` / ``List[T]`` with a non-slice index -> ``T``;
        with a slice index -> the sequence type itself.
      * NamedTuple with a constant int index -> the annotated type of that field.

    Cases whose result type cannot be determined here are left unannotated:
    a slice or dynamic index into a heterogeneous tuple, unannotated NamedTuple
    fields, ordinary (non-NamedTuple) classes, and special forms without type
    parameters such as ``Any``.

    Args:
        graph (Graph): The graph to be annotated
    """
    for node in graph.nodes:
        if node.target != operator.getitem:
            continue
        sequence_node, index_node = node.args
        # The container argument may be a literal (e.g. a tuple of Nodes)
        # rather than a Node, in which case there is no ``.type`` to read.
        if not isinstance(sequence_node, torch.fx.Node) or not sequence_node.type:
            continue
        sequence_type = sequence_node.type
        is_slice = isinstance(index_node, slice)

        # container types
        if hasattr(sequence_type, "_name"):
            # Special forms such as ``Any`` (Python <= 3.10) or a bare ``List``
            # have ``_name`` but no ``__args__``; nothing can be inferred.
            parameterized_types = getattr(sequence_type, "__args__", None)
            if parameterized_types is None:
                continue
            if sequence_type._name == "Tuple":
                if len(parameterized_types) == 2 and isinstance(
                    parameterized_types[1], type(...)
                ):
                    # Homogeneous ``Tuple[T, ...]``: every element has type T and
                    # a slice is again ``Tuple[T, ...]``.
                    node.type = sequence_type if is_slice else parameterized_types[0]
                elif isinstance(index_node, int):
                    # Heterogeneous tuple: only a constant index selects a type.
                    num_types = len(parameterized_types)
                    assert -num_types <= index_node < num_types
                    node.type = parameterized_types[index_node]
            elif sequence_type._name == "List":
                assert len(parameterized_types) == 1
                # Slicing a list yields a list of the same type, not an element.
                node.type = sequence_type if is_slice else parameterized_types[0]
        # NamedTuple type
        elif hasattr(sequence_type, "__annotations__"):
            # torch.Tensor and other ordinary annotated classes are not
            # NamedTuples; only classes exposing ``_fields`` are handled.
            if sequence_type == torch.Tensor or not hasattr(sequence_type, "_fields"):
                continue
            if not isinstance(index_node, int):
                continue
            field_name = sequence_type._fields[index_node]
            # collections.namedtuple classes carry no per-field annotations.
            field_type = sequence_type.__annotations__.get(field_name)
            if field_type is not None:
                node.type = field_type