# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, cast, Optional, Tuple

import executorch.exir as exir
import torch

from executorch.backends.xnnpack.utils.configs import (
    get_transform_passes,
    get_xnnpack_capture_config,
    get_xnnpack_edge_compile_config,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops

from torch._export.utils import (
    get_buffer,
    get_lifted_tensor_constant,
    get_param,
    is_buffer,
    is_lifted_tensor_constant,
    is_param,
)


### XNNPACK Capture ###
def capture_graph_for_xnnpack(
    module: torch.nn.Module,
    inputs: Tuple[torch.Tensor, ...],
    enable_aot: Optional[bool] = None,
    unlift: Optional[bool] = None,
) -> exir.ExirExportedProgram:
    return (
        exir.capture(
            module,
            inputs,
            get_xnnpack_capture_config(enable_aot=enable_aot, unlift=unlift),
        )
        .to_edge(get_xnnpack_edge_compile_config())
        .transform(*get_transform_passes())
    )


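# Illustrative usage sketch of the helper above: a minimal module captured
# and lowered for XNNPACK. `_ExampleAddModule` and `_example_capture` are
# hypothetical names, added for illustration only.
class _ExampleAddModule(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


def _example_capture() -> exir.ExirExportedProgram:
    # Capture with default AOT/unlift settings, lower to the edge dialect,
    # and run the XNNPACK transform passes.
    sample_inputs = (torch.randn(2, 3), torch.randn(2, 3))
    return capture_graph_for_xnnpack(_ExampleAddModule(), sample_inputs)

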
### XNNPACK Utils ###
PERM_NCHW_TO_NHWC = [0, 2, 3, 1]
PERM_NHWC_TO_NCHW = [0, 3, 1, 2]


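# Illustrative sketch of how the permutation lists above can be applied with
# torch.Tensor.permute. `_example_nchw_to_nhwc` is a hypothetical helper,
# added for illustration only.
def _example_nchw_to_nhwc(tensor: torch.Tensor) -> torch.Tensor:
    # [N, C, H, W] -> [N, H, W, C]
    return tensor.permute(PERM_NCHW_TO_NHWC)

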
def check_or_raise(condition: bool, err: str) -> None:
    """
    Raises a RuntimeError with the given error message if the condition is false

    Args:
        condition: boolean condition to check
        err: error message to raise if the condition is not true
    """
    if not condition:
        raise RuntimeError(err)


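# Illustrative sketch of typical check_or_raise usage, e.g. validating a
# permutation argument. `_example_validate_perm` and its argument are
# hypothetical, added for illustration only.
def _example_validate_perm(perm: list) -> None:
    check_or_raise(len(perm) == 4, f"expected a 4D permutation, got {perm}")

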
def is_node(node: Any) -> bool:
    """
    Returns True if the given node is a torch.fx.Node, otherwise False
    """
    return isinstance(node, torch.fx.Node)


def is_getitem(node: torch.fx.Node) -> bool:
    if node.op != "call_function":
        return False

    return node.target.__name__ == "getitem"  # pyre-ignore


def get_input_node(node: torch.fx.Node, input_index: int) -> torch.fx.Node:
    return cast(torch.fx.Node, node.args[input_index])


def get_relu_fused_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
    """
    Checks whether the given node is consumed only by a relu node that can be
    fused with it. If so, returns that relu node; otherwise returns None
    """
    if (
        len(node.users) == 1
        and list(node.users.keys())[0].target == exir_ops.edge.aten.relu.default
    ):
        relu_node = list(node.users.keys())[0]
        return relu_node

    return None


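# Illustrative sketch of how a lowering pass might consult get_relu_fused_node.
# `_example_can_fuse_relu` is a hypothetical helper, added for illustration only.
def _example_can_fuse_relu(node: torch.fx.Node) -> bool:
    # True when `node`'s only consumer is a relu that could be folded into it.
    return get_relu_fused_node(node) is not None

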
def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Returns True if the given node is a get_attr node for a tensor of the model
    """
    return isinstance(node, torch.fx.Node) and node.op == "get_attr"


def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
    return (
        is_get_attr_node(node)
        or is_param(exp_prog, node)
        or is_buffer(exp_prog, node)
        or is_lifted_tensor_constant(exp_prog, node)
    )


def get_param_tensor(
    exp_prog: ExportedProgram, node: torch.fx.Node
) -> Optional[torch.Tensor]:
    if node is None:
        return None
    elif is_param(exp_prog, node):
        return get_param(exp_prog, node)
    elif is_buffer(exp_prog, node):
        return get_buffer(exp_prog, node)
    elif is_lifted_tensor_constant(exp_prog, node):
        return get_lifted_tensor_constant(exp_prog, node)
    elif is_get_attr_node(node):
        # This is a hack to support both lifted and unlifted graphs
        try:
            return getattr(node.graph.owning_module, node.target)
        except AttributeError:
            return getattr(exp_prog.graph_module, node.target)
    raise RuntimeError(f"unsupported param type, {node.op}.")


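# Illustrative sketch of using is_param_node and get_param_tensor together to
# gather the concrete tensors behind parameter-like nodes of an exported
# program. `_example_collect_param_tensors` is a hypothetical helper, added
# for illustration only.
def _example_collect_param_tensors(exp_prog: ExportedProgram) -> dict:
    tensors = {}
    for n in exp_prog.graph_module.graph.nodes:
        if is_param_node(exp_prog, n):
            tensors[n.name] = get_param_tensor(exp_prog, n)
    return tensors

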
def get_source_fn(node: torch.fx.Node) -> Optional[torch.fx.Node]:
    """
    Returns the source fn of the given node; returns None if it cannot be determined
    """
    if (
        node.op != "call_function"
        or (source_fn_st := node.meta.get("source_fn_stack", None)) is None
    ):
        return None
    source_fn = source_fn_st[-1]
    return source_fn[1]

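
# Illustrative sketch of how the source_fn metadata might be used, assuming the
# node originated from a torch.nn.Linear module. `_example_is_from_linear` is a
# hypothetical helper, added for illustration only.
def _example_is_from_linear(node: torch.fx.Node) -> bool:
    return get_source_fn(node) is torch.nn.Linear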