xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/fx/_model_report/model_report_observer.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3from torch.ao.quantization.observer import ObserverBase
4
5
class ModelReportObserver(ObserverBase):
    r"""Observer that records activation statistics used to build a model report.

    It keeps track of S = average_batch_activation_range/epoch_activation_range,
    of per channel min/max values, and of per channel percentile ratios.

    The purpose of this information is to prepare a report to present to users on whether
    Dynamic or Static Quantization is more appropriate for their model given the general
    distributions of their data.

    Args:
        ch_axis (int, optional): The channel axis for which the range and outlier stats are computed
            Default: 1
        comp_percentile (float, optional): The percentile to compare against 100 percentile to find outliers
            Should be between 0 and 1 exclusive
            Default: 0.9

    * :attr:`num_batches_tracked` specifies number of batches passed through the observer

    * :attr:`average_batch_activation_range` defines average across the ranges of each batch passed through

    * :attr:`epoch_activation_min` defines the minimum value passed through the observer

    * :attr:`epoch_activation_max` defines the maximum value passed through the observer

    * :attr:`ch_axis` defines the channel being used to compute per channel min max stats

    * :attr:`min_val` defines the per channel minimum values passed through

    * :attr:`max_val` defines the per channel maximum values passed through

    * :attr:`comp_percentile` defines comparison percentile to find outliers

    * :attr:`average_percentile_ratio` defines the per channel average of
      (100th percentile / comp_percentile) over the tracked batches

    * :attr:`percentile_batches_tracked` defines the number of batches, per channel, that
      contributed to :attr:`average_percentile_ratio`

    * :attr:`constant_channels` defines, per channel, the number of batches in which that
      channel was constant (its minimum equaled its maximum)

    Note: this tool is meant for FX Graph Mode Quantization
    """

    epoch_activation_min: torch.Tensor
    epoch_activation_max: torch.Tensor
    min_val: torch.Tensor
    max_val: torch.Tensor
    comp_percentile: torch.Tensor
    average_percentile_ratio: torch.Tensor
    percentile_batches_tracked: torch.Tensor
    constant_channels: torch.Tensor

    def __init__(self, ch_axis: int = 1, comp_percentile: float = 0.9):
        super().__init__(torch.qint8)
        self.num_batches_tracked = 0

        # keep track of the min and max of the range for average batch and epoch as a whole
        self.average_batch_activation_range: torch.Tensor = torch.tensor(float(0))
        self.register_buffer("epoch_activation_min", torch.tensor(float("inf")))
        self.register_buffer("epoch_activation_max", torch.tensor(float("-inf")))

        # keep track of per channel min max information using the given channel
        self.ch_axis: int = ch_axis
        self.register_buffer("min_val", torch.tensor([]))
        self.register_buffer("max_val", torch.tensor([]))

        # keep track of percentile ratio information per channel
        self.register_buffer("comp_percentile", torch.tensor([comp_percentile]))
        self.register_buffer("average_percentile_ratio", torch.tensor([]))
        self.register_buffer("percentile_batches_tracked", torch.tensor([]))
        self.register_buffer("constant_channels", torch.tensor([]))

    def forward(self, x):
        r"""Records statistics of ``x`` and returns ``x`` itself, unmodified.

        Args
            x: The tensor flowing through the observed point of the model
        """
        x_copy = x.detach()  # avoid keeping autograd tape
        x_copy = x_copy.to(self.epoch_activation_min.dtype)

        x_copy = self._calculate_range_stats(x_copy)
        x_copy = self._calculate_min_max_stats(x_copy)
        x_copy = self._calculate_percentile_stats(x_copy)

        # return the passed in value
        return x

    def _channels_first_2d(self, x_copy):
        r"""Rearranges the data into a 2D tensor of shape (channels, values).

        Axis 0 and ``self.ch_axis`` are swapped, the dtype is matched to
        ``self.min_val`` and every non-channel axis is flattened into axis 1.

        Args
            x_copy: A copy of the forward data

        Returns the rearranged 2D tensor
        """
        axis_order = list(range(x_copy.dim()))
        axis_order[self.ch_axis] = 0
        axis_order[0] = self.ch_axis
        y = x_copy.permute(axis_order)
        # Need to match dtype of min/max because the updates to buffers
        # are done in place and types need to match for comparisons
        y = y.to(self.min_val.dtype)
        return torch.flatten(y, start_dim=1)

    def _calculate_range_stats(self, x_copy):
        r"""Calculates and stores range stats with forward values.

        Updates the epoch min/max, the running average of the per batch range
        and the number of batches tracked.

        Args
            x_copy: A copy of the forward data

        Returns the passed in x_copy
        """
        # get the min, max values of the data
        min_val_cur, max_val_cur = torch.aminmax(x_copy)

        # calculate new epoch range values
        epoch_min_val = torch.min(self.epoch_activation_min, min_val_cur)
        epoch_max_val = torch.max(self.epoch_activation_max, max_val_cur)

        self.epoch_activation_min.copy_(epoch_min_val)
        self.epoch_activation_max.copy_(epoch_max_val)

        # calculate the average batch activation range (running mean)
        current_batch_range = max_val_cur - min_val_cur
        new_range = (
            self.average_batch_activation_range * self.num_batches_tracked
            + current_batch_range
        ) / (self.num_batches_tracked + 1)

        self.average_batch_activation_range = new_range
        self.num_batches_tracked += 1  # new batch was processed

        return x_copy

    def _calculate_min_max_stats(self, x_copy):
        r"""Calculates and stores the per_channel min, max stats with forward values.
        Does calculation based on channel axis: self.ch_axis

        Args
            x_copy: A copy of the forward data

        Returns the passed in x_copy
        """
        # get the current min and max vals
        min_val = self.min_val
        max_val = self.max_val

        y = self._channels_first_2d(x_copy)
        min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
        if min_val.numel() == 0 or max_val.numel() == 0:
            # first batch: nothing to merge with yet
            min_val, max_val = min_val_cur, max_val_cur
        else:
            min_val = torch.min(min_val_cur, min_val)
            max_val = torch.max(max_val_cur, max_val)

        # buffers start out empty, so they are resized before the in place copy
        self.min_val.resize_(min_val.shape)
        self.max_val.resize_(max_val.shape)
        self.min_val.copy_(min_val)
        self.max_val.copy_(max_val)

        return x_copy

    def _calculate_percentile_stats(self, x_copy):
        r"""Calculates and stores the per_channel percentile stats with forward values.
        Does calculation based on channel axis: self.ch_axis

        Args
            x_copy: A copy of the forward data

        Returns the passed in x_copy
        """
        # the quantile computation below is done on the cpu
        y = self._channels_first_2d(x_copy)
        y = y.to(dtype=self.min_val.dtype, device="cpu")

        # find the percentile values along the axis
        # we want both 100th percentile and comp_percentile
        # we also want to find 0th quartile to see if we have constant channel
        quantiles_list = [0, self.comp_percentile.item(), 1.00]
        quantiles_to_find = torch.tensor(quantiles_list, dtype=self.min_val.dtype)

        # find the quantiles
        # y is always laid out as (channels, values), so the reduction runs over
        # dim 1 no matter which axis of the original input held the channels
        desired_quantiles = torch.quantile(
            y, quantiles_to_find, dim=1, interpolation="lower"
        )
        zero_quantile = desired_quantiles[0]
        comp_quantile = desired_quantiles[1]
        hundredth_quantile = desired_quantiles[2]

        # a channel is skipped for this batch only if both quantile values are 0
        # NOTE(review): the case list below suggests that any 0 quantile should be
        # skipped, but only the both-zero case is masked out; behavior kept as is
        any_non_zero_quantile_value: torch.Tensor = (comp_quantile != 0) | (
            hundredth_quantile != 0
        )
        any_non_zero_quantile_value = (
            any_non_zero_quantile_value.int()
        )  # transform boolean values to int values

        # we also check if we have a constant channel
        any_constant_channels: torch.Tensor = (
            hundredth_quantile - zero_quantile
        ) == 0
        any_constant_channels = (
            any_constant_channels.int()
        )  # transform boolean values to int values

        # possibilities to get nan as an answer
        #   will ignore any of these three cases with 0s and just not deal with them for now
        # case (1) 0 in numerator: issue if 0 is largest, all negative, and rest are really negative
        # case (2) 0 in denominator: is possible unless case 3, we just ignore
        # case (3) 0 in both: not outlier, channel just kinda useless, ignore

        # get the ratio and get rid of nan values
        quantile_ratios = hundredth_quantile / comp_quantile
        quantile_ratios = torch.nan_to_num(quantile_ratios)
        # update averages, remembering to only update if didn't have zeros
        ratio_if_not_zero = any_non_zero_quantile_value * quantile_ratios

        # if num_batches and average_ratio are not initialized, we want to initialize them
        if (
            self.percentile_batches_tracked.shape[0] == 0
            or self.average_percentile_ratio.shape[0] == 0
        ):
            self.percentile_batches_tracked = torch.zeros_like(
                any_non_zero_quantile_value
            )
            self.average_percentile_ratio = torch.zeros_like(ratio_if_not_zero)

        # also initialize the constant channel var if that is not initialized separately
        if self.constant_channels.shape[0] == 0:
            self.constant_channels = torch.zeros_like(any_constant_channels)

        # get current num batches and average ratio
        num_batches = self.percentile_batches_tracked
        average_ratio = self.average_percentile_ratio

        # calculate new_number of batches, new_ratios, and get rid of nans because of 0 size batches
        new_number_of_batches: torch.Tensor = num_batches + any_non_zero_quantile_value
        new_ratios: torch.Tensor = (
            (average_ratio * num_batches) + ratio_if_not_zero
        ) / new_number_of_batches
        new_ratios = torch.nan_to_num(new_ratios)

        # update the per channel count of batches in which the channel was constant
        new_constant_count: torch.Tensor = (
            self.constant_channels + any_constant_channels
        )

        # update the values locally
        self.percentile_batches_tracked.copy_(new_number_of_batches)
        self.average_percentile_ratio.copy_(new_ratios)
        self.constant_channels.copy_(new_constant_count)

        return x_copy

    @torch.jit.export
    def get_batch_to_epoch_ratio(self):
        r"""Returns S = average_batch_activation_range / epoch_activation_range.

        Raises:
            ValueError: if the epoch range is 0, or if it is infinite, which is the
                case when no data has been observed or an infinite value was seen
        """
        epoch_activation_range = self.epoch_activation_max - self.epoch_activation_min

        if epoch_activation_range == 0:
            raise ValueError("Range for Epoch is 0")

        # an untouched observer has max = -inf and min = +inf, i.e. a range of -inf,
        # so both signs of infinity have to be rejected here
        if torch.isinf(epoch_activation_range):
            raise ValueError(
                "No data has been run through observer or infinity value present"
            )

        return self.average_batch_activation_range / epoch_activation_range

    @torch.jit.export
    def reset_batch_and_epoch_values(self):
        r"""Sets all tracked statistics back to their initial values for a new epoch.

        The new tensors are created on the device of ``self.max_val``.
        """
        # set all the values back to their original defaults for a new epoch
        # keep device
        device = self.max_val.device
        self.num_batches_tracked = 0
        self.average_batch_activation_range = torch.tensor(float(0), device=device)
        self.epoch_activation_min = torch.tensor(float("inf"), device=device)
        self.epoch_activation_max = torch.tensor(float("-inf"), device=device)
        self.min_val = torch.tensor([], device=device)
        self.max_val = torch.tensor([], device=device)
        self.average_percentile_ratio = torch.tensor([], device=device)
        self.percentile_batches_tracked = torch.tensor([], device=device)
        self.constant_channels = torch.tensor([], device=device)

    @torch.jit.export
    def calculate_qparams(self):
        r"""Always raises: this observer only gathers statistics for reporting."""
        raise Exception(  # noqa: TRY002
            "calculate_qparams should not be called for ModelReportObserver"
        )
286