xref: /aosp_15_r20/external/executorch/backends/arm/_passes/arm_pass_utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# Copyright 2024 Arm Limited and/or its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8# pyre-unsafe
9
10from typing import Optional
11
12import torch
13import torch.fx
14
15from executorch.exir import ExportedProgram
16from executorch.exir.dialects._ops import ops as exir_ops
17
18from torch._export.utils import (
19    get_buffer,
20    get_lifted_tensor_constant,
21    get_param,
22    is_buffer,
23    is_lifted_tensor_constant,
24    is_param,
25)
26from torch._ops import OpOverload
27from torch._subclasses.fake_tensor import FakeTensor
28
29
def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Return True when 'node' is a torch.fx.Node whose op is "get_attr", i.e. a
    node that fetches an attribute (such as a model tensor) from its module.

    Values that are not fx Nodes at all (e.g. None or literal arguments) are
    reported as False rather than raising.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"
35
36
def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Return True if 'node' refers to constant model data.

    That is either an unlifted get_attr node, or a node that the
    torch._export.utils predicates report as a parameter, buffer or lifted
    tensor constant of 'exp_prog'. Checks are evaluated in that order and stop
    at the first match.
    """
    if is_get_attr_node(node):
        return True
    lifted_checks = (is_param, is_buffer, is_lifted_tensor_constant)
    return any(check(exp_prog, node) for check in lifted_checks)
44
45
46def get_param_tensor(
47    exp_prog: ExportedProgram, node: torch.fx.Node
48) -> Optional[torch.Tensor]:
49    if node is None:
50        return None
51    elif is_param(exp_prog, node):
52        return get_param(exp_prog, node)
53    elif is_buffer(exp_prog, node):
54        return get_buffer(exp_prog, node)
55    elif is_lifted_tensor_constant(exp_prog, node):
56        return get_lifted_tensor_constant(exp_prog, node)
57    elif is_get_attr_node(node):
58        # This is a hack to support both lifted and unlifted graph
59        try:
60            return getattr(node.graph.owning_module, node.target)
61        except AttributeError:
62            return getattr(exp_prog.graph_module, node.target)
63    raise RuntimeError(f"unsupported param type, {node.op}.")
64
65
def create_node(
    graph: torch.fx.Graph,
    op_target: OpOverload,
    args: tuple = (),
    kwargs: Optional[dict] = None,
    quantize: bool = False,
    q_params: Optional[tuple] = None,
):
    """
    Create a call_function node for 'op_target' in 'graph'.

    The node is placed at the graph's current insertion point, so callers
    should wrap this call in graph.inserting_before()/inserting_after().

    Args:
        graph: graph to add the node to.
        op_target: operator the new node calls.
        args: positional arguments of the new node.
        kwargs: keyword arguments of the new node; a falsy value means none.
        quantize: request a quantize/dequantize pair after the new node.
        q_params: quantization arguments for that pair; the pair is only
            inserted when 'quantize' is true and 'q_params' is truthy.

    Returns:
        The new node, or the dequantize node of the q/dq pair that
        insert_q_dq_pair() places after it.
    """
    new_node = graph.create_node(
        "call_function",
        op_target,
        args=args,
        kwargs=kwargs if kwargs else {},
    )
    if not (quantize and q_params):
        return new_node
    return insert_q_dq_pair(graph, new_node, q_params)
88
89
def insert_q_dq_pair(
    graph: torch.fx.Graph,
    anchor: torch.fx.Node,
    q_params: tuple,
):
    """
    Insert a quantize -> dequantize node pair directly after 'anchor'.

    Every existing user of 'anchor' is rewired to consume the dequantize node,
    giving the chain anchor -> q -> dq -> (former users of anchor).

    Args:
        graph: graph that owns 'anchor'.
        anchor: node whose output gets the q/dq pair.
        q_params: arguments appended after the input for both
            quantize_per_tensor and dequantize_per_tensor; presumably
            (scale, zero_point, qmin, qmax, dtype) — TODO confirm with callers.

    Returns:
        The new dequantize node.
    """

    with graph.inserting_after(anchor):
        q = create_node(
            graph=graph,
            op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            args=(),  # The real arguments are attached last, see below.
        )
        # Give each new node its own shallow copy of the metadata. Assigning
        # anchor.meta directly would alias a single dict across anchor, q and
        # dq, so updating one node's meta would silently change the others.
        q.meta = dict(anchor.meta)
    with graph.inserting_after(q):
        dq = create_node(
            graph=graph,
            op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            args=(q,) + q_params,
        )
        dq.meta = dict(q.meta)
    anchor.replace_all_uses_with(dq)
    # q's arguments are only set now: if q already consumed 'anchor',
    # replace_all_uses_with above would also rewire q's input to dq.
    q.args = (anchor,) + q_params
    return dq
118
119
def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
    """
    Return the FakeTensor stored in node.meta["val"].

    If meta["val"] holds a sequence of values (a tuple or a list, including
    fx's immutable_list), the first element is returned.

    Raises:
        KeyError: if 'node' has no "val" entry in its meta.
        IndexError: if meta["val"] is an empty tuple/list.
        AssertionError: if the selected value is not a FakeTensor.
    """
    val = node.meta["val"]
    # immutable_list subclasses list, so checking for list covers it as well as
    # plain lists that may be stored for multi-output nodes.
    fake_tensor = val[0] if isinstance(val, (tuple, list)) else val

    assert isinstance(
        fake_tensor, FakeTensor
    ), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.'
    return fake_tensor
136