# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

"""
The passes below were taken from bolt/nn/executorch/passes/quant_fusion.py
"""

T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default


class DequantDuplicator:
    """
    Callable applied to every argument of the node the pass is visiting.

    The first time a per-tensor dequantize node is seen it is recorded in
    ``pass_obj.dequant_map`` and returned unchanged. Every later time the
    same node is seen, a fresh dequantize node with the recorded arguments
    is inserted before ``pass_obj.current_node`` and returned instead.

    ``duplicated_dequant`` is set to True whenever a duplicate is created;
    the caller resets it before processing each node.
    """

    def __init__(self, pass_obj):
        # The owning pass; provides ``dequant_map``, ``mod`` and
        # ``current_node``, all of which are read lazily in ``__call__``.
        self.pass_obj = pass_obj
        self.duplicated_dequant = False

    def __call__(self, arg):
        """
        Return ``arg`` itself, or a new duplicate of it if ``arg`` is a
        per-tensor dequantize node that has already been consumed once.
        """
        if (
            not isinstance(arg, torch.fx.Node)
            or arg.op != "call_function"
            or arg.target != T_DQuantPerTensor  # TODO handle per channel case
        ):
            return arg

        dequant_map = self.pass_obj.dequant_map
        if arg not in dequant_map:
            # First consumer keeps the original dequant node.
            dequant_map[arg] = (arg.args, arg.kwargs, arg.meta)
            return arg

        args, kwargs, meta = dequant_map[arg]
        graph = self.pass_obj.mod.graph
        with graph.inserting_before(self.pass_obj.current_node):
            dup_dequant = graph.call_function(
                T_DQuantPerTensor,
                args=args,
                kwargs=kwargs,
            )
        # Shallow-copy the metadata dict so the duplicate does not share
        # (and later mutate) the original node's ``meta``. Assigning the
        # dict directly would make the ``"val"`` write below overwrite the
        # original node's entry as well.
        dup_dequant.meta = copy.copy(meta)
        dup_dequant.meta["val"] = copy.copy(meta["val"])
        self.duplicated_dequant = True
        return dup_dequant


class DuplicateDequantNodePass(ExportPass):
    r"""
    Duplicates all of the dequantize nodes. This is such that all
    quantized ops have their own unique dequant nodes. Since
    quantized ops are represented as dq -> op -> q, sharing dq nodes
    between quantized ops makes it impossible to partition against
    quantized ops. As a result we need to duplicate dq nodes for ops
    which share a dq node.

    In this example, the graph below:

                       --> op --> q
                      /
    dq -> op -> q -> dq
                      \
                       --> op --> q

    is transformed into:

                   --> dq --> op --> q
                  /
    dq -> op -> q
                  \
                   --> dq --> op --> q

    NOTE(review): only positional ``node.args`` (and lists nested one
    level inside them) are inspected; dequant nodes passed through
    ``node.kwargs`` are left shared.
    """

    def __init__(self):
        super().__init__()
        self.dequant_map = {}  # Map of dequant results to its node's arguments

    def call(self, mod):
        """
        Rewrite ``mod.graph`` in place so that each consumer of a
        per-tensor dequantize node gets its own copy of that node.

        Args:
            mod: the graph module to transform; it is mutated and
                recompiled.

        Returns:
            A ``PassResult`` wrapping ``mod``, always flagged as modified.
        """
        self.mod = mod
        # Start from a clean map so that reusing one pass instance does not
        # carry node entries (and references to old graphs) from a previous
        # run into this one.
        self.dequant_map = {}
        duplicator = DequantDuplicator(self)

        # Iterate over a snapshot because the graph is mutated in the loop.
        for node in list(mod.graph.nodes):
            self.current_node = node

            if node.op != "call_function":
                continue

            new_args = []
            duplicator.duplicated_dequant = False
            for arg in node.args:
                if isinstance(arg, list):
                    new_args.append([duplicator(item) for item in arg])
                else:
                    new_args.append(duplicator(arg))

            if not duplicator.duplicated_dequant:
                continue

            # At least one argument was swapped for a fresh dequant node, so
            # rebuild this node with the new arguments and retire the old one.
            with mod.graph.inserting_before(node):
                new_node = mod.graph.call_function(
                    node.target,
                    args=tuple(new_args),
                    kwargs=node.kwargs,
                )
            # Copy rather than alias the metadata dict so the replacement
            # node owns its own ``meta``.
            new_node.meta = copy.copy(node.meta)
            node.replace_all_uses_with(new_node)
            mod.graph.eliminate_dead_code()
            mod.recompile()

        return PassResult(mod, True)