xref: /aosp_15_r20/external/executorch/backends/xnnpack/utils/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, cast, Optional, Tuple
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir as exir
10*523fa7a6SAndroid Build Coastguard Workerimport torch
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.xnnpack.utils.configs import (
13*523fa7a6SAndroid Build Coastguard Worker    get_transform_passes,
14*523fa7a6SAndroid Build Coastguard Worker    get_xnnpack_capture_config,
15*523fa7a6SAndroid Build Coastguard Worker    get_xnnpack_edge_compile_config,
16*523fa7a6SAndroid Build Coastguard Worker)
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir import ExportedProgram
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects._ops import ops as exir_ops
19*523fa7a6SAndroid Build Coastguard Worker
20*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.utils import (
21*523fa7a6SAndroid Build Coastguard Worker    get_buffer,
22*523fa7a6SAndroid Build Coastguard Worker    get_lifted_tensor_constant,
23*523fa7a6SAndroid Build Coastguard Worker    get_param,
24*523fa7a6SAndroid Build Coastguard Worker    is_buffer,
25*523fa7a6SAndroid Build Coastguard Worker    is_lifted_tensor_constant,
26*523fa7a6SAndroid Build Coastguard Worker    is_param,
27*523fa7a6SAndroid Build Coastguard Worker)
28*523fa7a6SAndroid Build Coastguard Worker
29*523fa7a6SAndroid Build Coastguard Worker
30*523fa7a6SAndroid Build Coastguard Worker### XNNPACK Capture ###
def capture_graph_for_xnnpack(
    module: torch.nn.Module,
    inputs: Tuple[torch.Tensor],
    enable_aot: Optional[bool] = None,
    unlift: Optional[bool] = None,
) -> exir.ExirExportedProgram:
    """
    Captures ``module`` with ``exir.capture`` and prepares it for XNNPACK.

    The pipeline runs in three steps:
      1. capture with the XNNPACK capture config (built from ``enable_aot``
         and ``unlift``),
      2. convert to edge dialect with the XNNPACK edge compile config,
      3. apply the XNNPACK transform passes.

    Args:
        module: the model to capture
        inputs: example inputs passed to ``exir.capture``
        enable_aot: forwarded unchanged to ``get_xnnpack_capture_config``
        unlift: forwarded unchanged to ``get_xnnpack_capture_config``

    Returns:
        The program produced by the final ``.transform(...)`` call.
    """
    capture_config = get_xnnpack_capture_config(enable_aot=enable_aot, unlift=unlift)
    captured_program = exir.capture(module, inputs, capture_config)
    edge_program = captured_program.to_edge(get_xnnpack_edge_compile_config())
    return edge_program.transform(*get_transform_passes())
46*523fa7a6SAndroid Build Coastguard Worker
47*523fa7a6SAndroid Build Coastguard Worker
48*523fa7a6SAndroid Build Coastguard Worker### XNNPACK Utils ###
# Axis order that permutes a rank-4 (N, C, H, W) tensor into (N, H, W, C).
PERM_NCHW_TO_NHWC = [0, 2, 3, 1]
# Inverse of the permutation above, i.e. (N, H, W, C) -> (N, C, H, W).
# Derived rather than hand-written so the pair cannot drift apart: entry i is
# the position at which original axis i appears in PERM_NCHW_TO_NHWC.
PERM_NHWC_TO_NCHW = [PERM_NCHW_TO_NHWC.index(axis) for axis in range(4)]
51*523fa7a6SAndroid Build Coastguard Worker
52*523fa7a6SAndroid Build Coastguard Worker
def check_or_raise(condition: bool, err: str) -> None:
    """
    Guard helper: raises a RuntimeError carrying ``err`` unless ``condition`` holds.

    Args:
        condition: value that must be truthy for the check to pass
        err: message attached to the RuntimeError when the check fails
    """
    if condition:
        return
    raise RuntimeError(err)
63*523fa7a6SAndroid Build Coastguard Worker
64*523fa7a6SAndroid Build Coastguard Worker
def is_node(node: Any) -> bool:
    """
    Reports whether ``node`` is an instance of ``torch.fx.Node``.
    """
    if isinstance(node, torch.fx.Node):
        return True
    return False
70*523fa7a6SAndroid Build Coastguard Worker
71*523fa7a6SAndroid Build Coastguard Worker
def is_getitem(node: torch.fx.Node) -> bool:
    """
    Returns true if ``node`` is a ``call_function`` node whose target's
    ``__name__`` is ``"getitem"`` (e.g. ``operator.getitem``), otherwise false.

    Targets that have no ``__name__`` attribute (for example
    ``functools.partial`` objects) are reported as non-getitem instead of
    raising AttributeError.
    """
    if node.op != "call_function":
        return False

    # Not every callable exposes __name__, so read it with a default.
    return getattr(node.target, "__name__", None) == "getitem"
77*523fa7a6SAndroid Build Coastguard Worker
78*523fa7a6SAndroid Build Coastguard Worker
def get_input_node(node: torch.fx.Node, input_index: int) -> torch.fx.Node:
    """
    Returns ``node.args[input_index]``.

    The ``cast`` only informs the type checker; no runtime check is made that
    the argument really is a ``torch.fx.Node``, and indexing errors from
    ``node.args`` propagate unchanged.
    """
    selected_arg = node.args[input_index]
    return cast(torch.fx.Node, selected_arg)
81*523fa7a6SAndroid Build Coastguard Worker
82*523fa7a6SAndroid Build Coastguard Worker
def get_relu_fused_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
    """
    Looks for a relu that can be fused into ``node``.

    Fusion is possible only when ``node`` has exactly one user and that user's
    target is ``exir_ops.edge.aten.relu.default``. In that case the relu user
    is returned; otherwise None.
    """
    consumers = list(node.users)
    if len(consumers) != 1:
        return None

    (sole_consumer,) = consumers
    if sole_consumer.target == exir_ops.edge.aten.relu.default:
        return sole_consumer
    return None
96*523fa7a6SAndroid Build Coastguard Worker
97*523fa7a6SAndroid Build Coastguard Worker
def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Returns true if ``node`` is a ``torch.fx.Node`` whose op is ``"get_attr"``.
    Anything that is not a ``torch.fx.Node`` yields false.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"
103*523fa7a6SAndroid Build Coastguard Worker
104*523fa7a6SAndroid Build Coastguard Worker
def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Returns true if ``node`` refers to constant data of ``exp_prog``.

    That is the case when it is a get_attr node, or when the
    ``torch._export.utils`` helpers classify it as a parameter, buffer, or
    lifted tensor constant of ``exp_prog``.
    """
    if is_get_attr_node(node):
        return True
    # any() evaluates lazily, so the checks run in this order and stop at
    # the first match.
    signature_checks = (is_param, is_buffer, is_lifted_tensor_constant)
    return any(check(exp_prog, node) for check in signature_checks)
112*523fa7a6SAndroid Build Coastguard Worker
113*523fa7a6SAndroid Build Coastguard Worker
114*523fa7a6SAndroid Build Coastguard Workerdef get_param_tensor(
115*523fa7a6SAndroid Build Coastguard Worker    exp_prog: ExportedProgram, node: torch.fx.Node
116*523fa7a6SAndroid Build Coastguard Worker) -> Optional[torch.Tensor]:
117*523fa7a6SAndroid Build Coastguard Worker    if node is None:
118*523fa7a6SAndroid Build Coastguard Worker        return None
119*523fa7a6SAndroid Build Coastguard Worker    elif is_param(exp_prog, node):
120*523fa7a6SAndroid Build Coastguard Worker        return get_param(exp_prog, node)
121*523fa7a6SAndroid Build Coastguard Worker    elif is_buffer(exp_prog, node):
122*523fa7a6SAndroid Build Coastguard Worker        return get_buffer(exp_prog, node)
123*523fa7a6SAndroid Build Coastguard Worker    elif is_lifted_tensor_constant(exp_prog, node):
124*523fa7a6SAndroid Build Coastguard Worker        return get_lifted_tensor_constant(exp_prog, node)
125*523fa7a6SAndroid Build Coastguard Worker    elif is_get_attr_node(node):
126*523fa7a6SAndroid Build Coastguard Worker        # This is a hack to support both lifted and unlifted graph
127*523fa7a6SAndroid Build Coastguard Worker        try:
128*523fa7a6SAndroid Build Coastguard Worker            return getattr(node.graph.owning_module, node.target)
129*523fa7a6SAndroid Build Coastguard Worker        except AttributeError:
130*523fa7a6SAndroid Build Coastguard Worker            return getattr(exp_prog.graph_module, node.target)
131*523fa7a6SAndroid Build Coastguard Worker    raise RuntimeError(f"unsupported param type, {node.op}.")
132*523fa7a6SAndroid Build Coastguard Worker
133*523fa7a6SAndroid Build Coastguard Worker
def get_source_fn(node: torch.fx.Node) -> Optional[Any]:
    """
    Returns the source fn of the given node, or None if it cannot be determined.

    The value is the second element of the last entry in the node's
    ``source_fn_stack`` metadata. None is returned when the node is not a
    ``call_function`` node, or when ``source_fn_stack`` is missing, None, or
    empty.
    """
    if node.op != "call_function":
        return None
    source_fn_st = node.meta.get("source_fn_stack", None)
    # A missing or empty stack carries no source information; indexing [-1]
    # on an empty stack would raise IndexError.
    if not source_fn_st:
        return None
    source_fn = source_fn_st[-1]
    return source_fn[1]
145