xref: /aosp_15_r20/external/executorch/backends/transforms/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerfrom typing import Optional
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport torch
10*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir import ExportedProgram
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.utils import (
13*523fa7a6SAndroid Build Coastguard Worker    get_buffer,
14*523fa7a6SAndroid Build Coastguard Worker    get_lifted_tensor_constant,
15*523fa7a6SAndroid Build Coastguard Worker    get_param,
16*523fa7a6SAndroid Build Coastguard Worker    is_buffer,
17*523fa7a6SAndroid Build Coastguard Worker    is_lifted_tensor_constant,
18*523fa7a6SAndroid Build Coastguard Worker    is_param,
19*523fa7a6SAndroid Build Coastguard Worker)
20*523fa7a6SAndroid Build Coastguard Worker
21*523fa7a6SAndroid Build Coastguard Worker
def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Tell whether ``node`` is a torch.fx ``get_attr`` node.

    Any value that is not a ``torch.fx.Node`` (including ``None``) yields
    False. The check looks only at the node's opcode; it does not inspect
    what the referenced attribute actually is.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"
27*523fa7a6SAndroid Build Coastguard Worker
28*523fa7a6SAndroid Build Coastguard Worker
def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Tell whether ``node`` refers to model state of ``exp_prog``.

    A node qualifies when it is a ``get_attr`` node, or when the
    ``torch._export.utils`` helpers classify it as a parameter, a buffer,
    or a lifted tensor constant of ``exp_prog``. Checks run in that order
    and stop at the first match.
    """
    if is_get_attr_node(node):
        return True
    for predicate in (is_param, is_buffer, is_lifted_tensor_constant):
        if predicate(exp_prog, node):
            return True
    return False
36*523fa7a6SAndroid Build Coastguard Worker
37*523fa7a6SAndroid Build Coastguard Worker
38*523fa7a6SAndroid Build Coastguard Workerdef get_param_tensor(
39*523fa7a6SAndroid Build Coastguard Worker    exp_prog: ExportedProgram, node: torch.fx.Node
40*523fa7a6SAndroid Build Coastguard Worker) -> Optional[torch.Tensor]:
41*523fa7a6SAndroid Build Coastguard Worker    if node is None:
42*523fa7a6SAndroid Build Coastguard Worker        return None
43*523fa7a6SAndroid Build Coastguard Worker    elif is_param(exp_prog, node):
44*523fa7a6SAndroid Build Coastguard Worker        return get_param(exp_prog, node)
45*523fa7a6SAndroid Build Coastguard Worker    elif is_buffer(exp_prog, node):
46*523fa7a6SAndroid Build Coastguard Worker        return get_buffer(exp_prog, node)
47*523fa7a6SAndroid Build Coastguard Worker    elif is_lifted_tensor_constant(exp_prog, node):
48*523fa7a6SAndroid Build Coastguard Worker        return get_lifted_tensor_constant(exp_prog, node)
49*523fa7a6SAndroid Build Coastguard Worker    elif is_get_attr_node(node):
50*523fa7a6SAndroid Build Coastguard Worker        # This is a hack to support both lifted and unlifted graph
51*523fa7a6SAndroid Build Coastguard Worker        try:
52*523fa7a6SAndroid Build Coastguard Worker            return getattr(node.graph.owning_module, node.target)
53*523fa7a6SAndroid Build Coastguard Worker        except AttributeError:
54*523fa7a6SAndroid Build Coastguard Worker            return getattr(exp_prog.graph_module, node.target)
55*523fa7a6SAndroid Build Coastguard Worker    raise RuntimeError(f"unsupported param type, {node.op}.")
56