xref: /aosp_15_r20/external/pytorch/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from torch.ao.pruning.sparsifier.utils import fqn_to_module, module_to_fqn


SUPPORTED_MODULES = {nn.Embedding, nn.EmbeddingBag}

def _fetch_all_embeddings(model):
    """Fetches all Embedding and EmbeddingBag modules from the model.

    Returns a list of (fully qualified name, module) pairs.
    """
    embedding_modules = []
    stack = [model]
    while stack:
        module = stack.pop()
        for _, child in module.named_children():
            fqn_name = module_to_fqn(model, child)
            if type(child) in SUPPORTED_MODULES:
                embedding_modules.append((fqn_name, child))
            else:
                stack.append(child)
    return embedding_modules

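# A minimal sketch of what _fetch_all_embeddings yields; the toy model below is
# an illustrative assumption, not something this module constructs:
#
#     toy = nn.Sequential(nn.EmbeddingBag(10, 4), nn.Linear(4, 2))
#     _fetch_all_embeddings(toy)  # -> [("0", <EmbeddingBag instance>)]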


def post_training_sparse_quantize(
    model,
    data_sparsifier_class,
    sparsify_first=True,
    select_embeddings: Optional[List[nn.Module]] = None,
    **sparse_config,
):
34    """Takes in a model and applies sparsification and quantization to only embeddings & embeddingbags.
35    The quantization step can happen before or after sparsification depending on the `sparsify_first` argument.
36
37    Args:
38        - model (nn.Module)
39            model whose embeddings needs to be sparsified
40        - data_sparsifier_class (type of data sparsifier)
41            Type of sparsification that needs to be applied to model
42        - sparsify_first (bool)
43            if true, sparsifies first and then quantizes
44            otherwise, quantizes first and then sparsifies.
45        - select_embeddings (List of Embedding modules)
46            List of embedding modules to in the model to be sparsified & quantized.
47            If None, all embedding modules with be sparsified
48        - sparse_config (Dict)
49            config that will be passed to the constructor of data sparsifier object.
50
51    Note:
52        1. When `sparsify_first=False`, quantization occurs first followed by sparsification.
53            - before sparsifying, the embedding layers are dequantized.
54            - scales and zero-points are saved
55            - embedding layers are sparsified and `squash_mask` is applied
56            - embedding weights are requantized using the saved scales and zero-points
57        2. When `sparsify_first=True`, sparsification occurs first followed by quantization.
58            - embeddings are sparsified first
59            - quantization is applied on the sparsified embeddings
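
    Example::

        >>> # a minimal usage sketch; DataNormSparsifier is one data sparsifier
        >>> # that provides the add_data/step/squash_mask interface used here,
        >>> # and sparsity_level is one of its constructor arguments
        >>> from torch.ao.pruning._experimental.data_sparsifier.data_norm_sparsifier import DataNormSparsifier
        >>> model = nn.Sequential(nn.EmbeddingBag(100, 16), nn.Linear(16, 1))
        >>> post_training_sparse_quantize(model, DataNormSparsifier, sparsify_first=True, sparsity_level=0.8)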
60    """
    data_sparsifier = data_sparsifier_class(**sparse_config)

    # if select_embeddings is None, operate on all embeddings in the model
    if select_embeddings is None:
        embedding_modules = _fetch_all_embeddings(model)

    else:
        embedding_modules = []
        assert isinstance(
            select_embeddings, list
        ), "select_embeddings must be a list of embedding modules"
        for emb in select_embeddings:
            assert (
                type(emb) in SUPPORTED_MODULES
            ), "each module in select_embeddings must be an Embedding or EmbeddingBag"
            fqn_name = module_to_fqn(model, emb)
            assert (
                fqn_name is not None
            ), "each embedding module must be part of the input model"
            embedding_modules.append((fqn_name, emb))

    if sparsify_first:
        # sparsify: register each embedding module with the data sparsifier
        # (FQNs contain dots, which are replaced to form valid data names)
        for name, emb_module in embedding_modules:
            valid_name = name.replace(".", "_")
            data_sparsifier.add_data(name=valid_name, data=emb_module)

        data_sparsifier.step()
        data_sparsifier.squash_mask()

        # quantize the (now sparsified) embeddings
        for _, emb_module in embedding_modules:
            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

        torch.ao.quantization.prepare(model, inplace=True)
        torch.ao.quantization.convert(model, inplace=True)

    else:
        # quantize
        for _, emb_module in embedding_modules:
            emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

        torch.ao.quantization.prepare(model, inplace=True)
        torch.ao.quantization.convert(model, inplace=True)

        # save the per-channel quantization params (and the dequantized
        # weights) for each embedding
        quantize_params: Dict[str, Dict] = {
            "scales": {},
            "zero_points": {},
            "dequant_weights": {},
            "axis": {},
            "dtype": {},
        }

        for name, _ in embedding_modules:
            quantized_emb = fqn_to_module(model, name)
            assert quantized_emb is not None  # satisfy mypy

            quantized_weight = quantized_emb.weight()  # type: ignore[operator]
            quantize_params["scales"][name] = quantized_weight.q_per_channel_scales()
            quantize_params["zero_points"][
                name
            ] = quantized_weight.q_per_channel_zero_points()
            quantize_params["dequant_weights"][name] = torch.dequantize(
                quantized_weight
            )
            quantize_params["axis"][name] = quantized_weight.q_per_channel_axis()
            quantize_params["dtype"][name] = quantized_weight.dtype

            # attach the dequantized weights to the sparsifier
            data_sparsifier.add_data(
                name=name.replace(".", "_"),
                data=quantize_params["dequant_weights"][name],
            )

        data_sparsifier.step()
        data_sparsifier.squash_mask()

        # requantize the sparsified weights using the saved scales & zero-points
        for name, _ in embedding_modules:
            quantized_emb = fqn_to_module(model, name)
            assert quantized_emb is not None  # satisfy mypy
            requantized_vector = torch.quantize_per_channel(
                quantize_params["dequant_weights"][name],
                scales=quantize_params["scales"][name],
                zero_points=quantize_params["zero_points"][name],
                dtype=quantize_params["dtype"][name],
                axis=quantize_params["axis"][name],
            )

            quantized_emb.set_weight(requantized_vector)  # type: ignore[operator]
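

if __name__ == "__main__":
    # A standalone sketch (illustrative toy shapes and qparams, not part of the
    # library's API) of the save-qparams -> dequantize -> sparsify -> requantize
    # round-trip that the quantize-first branch above performs per embedding.
    w = torch.randn(4, 8)
    scales = torch.full((4,), 0.1)
    zero_points = torch.zeros(4, dtype=torch.int64)
    qw = torch.quantize_per_channel(w, scales, zero_points, axis=0, dtype=torch.qint8)

    dq = torch.dequantize(qw)  # float copy that a data sparsifier could mask
    dq[0] = 0  # stand-in for the sparsifier zeroing out a row

    requantized = torch.quantize_per_channel(
        dq,
        scales=qw.q_per_channel_scales(),
        zero_points=qw.q_per_channel_zero_points(),
        dtype=qw.dtype,
        axis=qw.q_per_channel_axis(),
    )
    # row 0 survives the round-trip as zeros
    print(torch.dequantize(requantized)[0])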