# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from executorch.exir import ExportedProgram

from torch._export.utils import (
    get_buffer,
    get_lifted_tensor_constant,
    get_param,
    is_buffer,
    is_lifted_tensor_constant,
    is_param,
)


def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Returns True if ``node`` is a ``torch.fx.Node`` whose op is ``"get_attr"``.

    Only the node's op is checked; the attribute it refers to is not
    inspected, so it is not guaranteed to be a tensor (a ``get_attr`` node
    may also reference e.g. a submodule).
    """
    return isinstance(node, torch.fx.Node) and node.op == "get_attr"


def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Returns True if ``node`` refers to stored model state.

    A node qualifies if it is a ``get_attr`` node, or if ``is_param``,
    ``is_buffer`` or ``is_lifted_tensor_constant`` (from
    ``torch._export.utils``) report it as such for ``exp_prog``.
    """
    return (
        is_get_attr_node(node)
        or is_param(exp_prog, node)
        or is_buffer(exp_prog, node)
        or is_lifted_tensor_constant(exp_prog, node)
    )


def _get_attr_by_qualified_name(module: object, target: str) -> object:
    """
    Resolves a possibly dotted attribute path ``target`` on ``module``.

    ``get_attr`` targets may be qualified names such as ``"sub.weight"``; a
    plain ``getattr(module, "sub.weight")`` would look for a single attribute
    literally named ``"sub.weight"`` and fail. Raises ``AttributeError`` if
    any component of the path is missing.
    """
    obj = module
    for atom in target.split("."):
        obj = getattr(obj, atom)
    return obj


def get_param_tensor(
    exp_prog: ExportedProgram, node: Optional[torch.fx.Node]
) -> Optional[torch.Tensor]:
    """
    Returns the value backing ``node``, or ``None`` if ``node`` is ``None``.

    Lookup order:
      1. parameter (``get_param``),
      2. buffer (``get_buffer``),
      3. lifted tensor constant (``get_lifted_tensor_constant``),
      4. ``get_attr`` node: the attribute named by ``node.target`` on the
         node's owning module, falling back to ``exp_prog.graph_module``.

    For case 4 the returned object is whatever the attribute holds; it is
    presumably a tensor for callers of this function, but that is not
    checked here.

    Raises:
        RuntimeError: if ``node`` is none of the supported kinds.
        AttributeError: if a ``get_attr`` target cannot be resolved on either
            module.
    """
    if node is None:
        return None
    elif is_param(exp_prog, node):
        return get_param(exp_prog, node)
    elif is_buffer(exp_prog, node):
        return get_buffer(exp_prog, node)
    elif is_lifted_tensor_constant(exp_prog, node):
        return get_lifted_tensor_constant(exp_prog, node)
    elif is_get_attr_node(node):
        # This is a hack to support both lifted and unlifted graph: the
        # attribute may live on the node's own graph module, or only on the
        # exported program's top-level graph module.
        try:
            return _get_attr_by_qualified_name(
                node.graph.owning_module, node.target
            )
        except AttributeError:
            return _get_attr_by_qualified_name(exp_prog.graph_module, node.target)
    raise RuntimeError(f"unsupported param type, {node.op}.")