# mypy: allow-untyped-defs
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from torch.ao.pruning.sparsifier.utils import fqn_to_module, module_to_fqn


SUPPORTED_MODULES = {nn.Embedding, nn.EmbeddingBag}


def _fetch_all_embeddings(model):
    """Fetches all Embedding and EmbeddingBag modules from the model
    via an iterative depth-first traversal."""
    embedding_modules = []
    stack = [model]
    while stack:
        module = stack.pop()
        for _, child in module.named_children():
            fqn_name = module_to_fqn(model, child)
            if type(child) in SUPPORTED_MODULES:
                embedding_modules.append((fqn_name, child))
            else:
                stack.append(child)
    return embedding_modules


def post_training_sparse_quantize(
    model,
    data_sparsifier_class,
    sparsify_first=True,
    select_embeddings: Optional[List[nn.Module]] = None,
    **sparse_config,
):
    """Takes in a model and applies sparsification and quantization only to embeddings & embedding bags.
    The quantization step can happen before or after sparsification depending on the `sparsify_first` argument.

    Args:
        - model (nn.Module)
            model whose embeddings need to be sparsified
        - data_sparsifier_class (type of data sparsifier)
            Type of sparsification that needs to be applied to the model
        - sparsify_first (bool)
            if True, sparsifies first and then quantizes;
            otherwise, quantizes first and then sparsifies
        - select_embeddings (List of embedding modules)
            List of embedding modules in the model to be sparsified & quantized.
            If None, all embedding modules will be sparsified
        - sparse_config (Dict)
            config that will be passed to the constructor of the data sparsifier object

    Note:
        1. When `sparsify_first=False`, quantization occurs first followed by sparsification.
            - before sparsifying, the embedding layers are dequantized
            - scales and zero-points are saved
            - embedding layers are sparsified and `squash_mask` is applied
            - embedding weights are requantized using the saved scales and zero-points
        2. When `sparsify_first=True`, sparsification occurs first followed by quantization.
            - embeddings are sparsified first
            - quantization is applied on the sparsified embeddings
    """
    data_sparsifier = data_sparsifier_class(**sparse_config)

    # if select_embeddings is None, perform it on all embeddings
    if select_embeddings is None:
        embedding_modules = _fetch_all_embeddings(model)

    else:
        embedding_modules = []
        assert isinstance(
            select_embeddings, List
        ), "select_embeddings must be a list of embedding modules"
        for emb in select_embeddings:
            assert (
                type(emb) in SUPPORTED_MODULES
            ), "each module in select_embeddings must be an nn.Embedding or nn.EmbeddingBag"
            fqn_name = module_to_fqn(model, emb)
            assert (
                fqn_name is not None
            ), "the embedding modules must be part of the input model"
            embedding_modules.append((fqn_name, emb))

    if sparsify_first:
        # sparsify
        for name, emb_module in embedding_modules:
            valid_name = name.replace(".", "_")
            data_sparsifier.add_data(name=valid_name, data=emb_module)

        data_sparsifier.step()
        data_sparsifier.squash_mask()

        # quantize
        for _, emb_module in embedding_modules:
            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

        torch.ao.quantization.prepare(model, inplace=True)
        torch.ao.quantization.convert(model, inplace=True)

    else:
        # quantize
        for _, emb_module in embedding_modules:
            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

        torch.ao.quantization.prepare(model, inplace=True)
        torch.ao.quantization.convert(model, inplace=True)

        # retrieve scales & zero_points so the weights can be requantized
        # after sparsification
        quantize_params: Dict[str, Dict] = {
            "scales": {},
            "zero_points": {},
            "dequant_weights": {},
            "axis": {},
            "dtype": {},
        }

        for name, _ in embedding_modules:
            quantized_emb = fqn_to_module(model, name)
            assert quantized_emb is not None  # satisfy mypy

            quantized_weight = quantized_emb.weight()  # type: ignore[operator]
            quantize_params["scales"][name] = quantized_weight.q_per_channel_scales()
            quantize_params["zero_points"][
                name
            ] = quantized_weight.q_per_channel_zero_points()
            quantize_params["dequant_weights"][name] = torch.dequantize(
                quantized_weight
            )
            quantize_params["axis"][name] = quantized_weight.q_per_channel_axis()
            quantize_params["dtype"][name] = quantized_weight.dtype

            # attach the dequantized weights to the sparsifier
            data_sparsifier.add_data(
                name=name.replace(".", "_"),
                data=quantize_params["dequant_weights"][name],
            )

        data_sparsifier.step()
        data_sparsifier.squash_mask()

        # requantize the sparsified weights using the saved per-channel params
        for name, _ in embedding_modules:
            quantized_emb = fqn_to_module(model, name)
            assert quantized_emb is not None  # satisfy mypy
            requantized_vector = torch.quantize_per_channel(
                quantize_params["dequant_weights"][name],
                scales=quantize_params["scales"][name],
                zero_points=quantize_params["zero_points"][name],
                dtype=quantize_params["dtype"][name],
                axis=quantize_params["axis"][name],
            )

            quantized_emb.set_weight(requantized_vector)  # type: ignore[operator]
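

# ---------------------------------------------------------------------------
# Usage sketch (not part of the public API): a minimal example of applying
# `post_training_sparse_quantize` to a toy model. It assumes the
# `DataNormSparsifier` from this same `_experimental.data_sparsifier`
# package; the toy model and the sparse_config values below are illustrative
# only, not a recommended configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier

    class _ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = nn.Embedding(100, 16)
            self.emb_bag = nn.EmbeddingBag(100, 16)

    model = _ToyModel()

    # Sparsify first (default), then quantize the pruned embedding weights.
    post_training_sparse_quantize(
        model,
        data_sparsifier_class=DataNormSparsifier,
        sparsify_first=True,
        sparsity_level=0.8,  # illustrative sparse_config kwargs
        sparse_block_shape=(1, 1),
    )

    # After the call, `model.emb` and `model.emb_bag` have been swapped for
    # their quantized counterparts holding the sparsified weights.
    print(type(model.emb), type(model.emb_bag))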