# mypy: allow-untyped-defs
from collections import OrderedDict as OrdDict
from typing import Any, Dict, List, OrderedDict, Set, Tuple

import torch


# try to import tabulate
got_tabulate = True
try:
    from tabulate import tabulate
except ImportError:
    got_tabulate = False


# var to see if we could import matplotlib
got_matplotlib = True
try:
    import matplotlib.pyplot as plt
except ImportError:
    got_matplotlib = False


class ModelReportVisualizer:
    r"""
    The ModelReportVisualizer class aims to provide users with a way to visualize some of the statistics
    that were generated by the ModelReport API. However, at a higher level, the class aims to provide
    some level of visualization of statistics to PyTorch in order to make it easier to parse data and
    diagnose any potential issues with data or a specific model. With respect to the visualizations,
    the ModelReportVisualizer class currently supports several methods of visualizing data.

    Supported Visualization Methods Include:
    - Table format
    - Plot format (line graph)
    - Histogram format

    For all of the existing visualization methods, there is the option to filter data based on:
    - A module fqn prefix
    - Feature [required for the plot and histogram]

    * :attr:`generated_reports` The reports generated by the ModelReport class in the structure below
        Ensure that features that are the same across different reports contain the same name
        Ensure that objects representing the same features are the same type / dimension (where applicable)

    Note:
        Currently, the ModelReportVisualizer class supports visualization of data generated by the
        ModelReport class. However, this structure is extensible and should allow the visualization of
        other information as long as the information is structured in the following general format:

        Report Structure
        -- module_fqn [module with attached detectors]
            |
            -- feature keys [not every detector extracts same information]
                                    [same collected info has same keys, unless can be specific to detector]
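
        For example, a minimal generated_reports dictionary matching this format could look like the
        sketch below (module names, feature keys, and values are purely illustrative):

        >>> # xdoctest: +SKIP("illustrative structure only")
        >>> generated_reports = OrdDict(
        ...     [
        ...         ("block1.conv", {"per_tensor_max": torch.tensor(0.5)}),
        ...         ("block1.linear", {"per_tensor_max": torch.tensor(1.2)}),
        ...     ]
        ... )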


    The goal behind the class is that the generated visualizations can be used in conjunction with the generated
    report for people to get a better understanding of issues and what the fix might be. It is also meant to provide
    a good visualization platform, since it might be hard to parse through the dictionary returned by ModelReport
    as it grows in size.

    General Use Flow Expected (a short sketch of the final steps follows the list below)
    1.) Initialize ModelReport object with reports of interest by passing in initialized detector objects
    2.) Prepare your model with prepare_fx
    3.) Call model_report.prepare_detailed_calibration on your model to add relevant observers
    4.) Calibrate your model with data
    5.) Call model_report.generate_report on your model to generate report and optionally remove added observers
    6.) Use output of model_report.generate_report to initialize ModelReportVisualizer instance
    7.) Use instance to view different views of data as desired, applying filters as needed
    8.) Either see the super detailed information or just the actual printed or shown table / plot / histogram
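
    A rough sketch of steps 6-8, assuming the earlier ModelReport steps have already produced a
    generated_reports dictionary (the variable names below are placeholders):

    >>> # xdoctest: +SKIP("undefined variables")
    >>> visualizer = ModelReportVisualizer(generated_reports)
    >>> visualizer.get_all_unique_feature_names()  # discover what can be filtered or plotted
    >>> visualizer.generate_table_visualization(module_fqn_filter="block1")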

    """

    # keys for table dict
    TABLE_TENSOR_KEY = "tensor_level_info"
    TABLE_CHANNEL_KEY = "channel_level_info"

    # Constants for header vals
    NUM_NON_FEATURE_TENSOR_HEADERS = 2
    NUM_NON_FEATURE_CHANNEL_HEADERS = 3

    # Constants for row index in header
    CHANNEL_NUM_INDEX = 2

    def __init__(self, generated_reports: OrderedDict[str, Any]):
        r"""
        Initializes the ModelReportVisualizer instance with the necessary reports.

        Args:
            generated_reports (OrderedDict[str, Any]): The reports generated by the ModelReport class
                can also be a dictionary generated in another manner, as long as the format is the same
        """
        self.generated_reports = generated_reports

    def get_all_unique_module_fqns(self) -> Set[str]:
        r"""
        The purpose of this method is to provide a user the set of all module_fqns so that if
        they wish to use some of the filtering capabilities of the ModelReportVisualizer class,
        they don't need to manually parse the generated_reports dictionary to get this information.

        Returns all the unique module fqns present in the reports the ModelReportVisualizer
        instance was initialized with.
        """
        # returns the keys of the ordered dict
        return set(self.generated_reports.keys())

    def get_all_unique_feature_names(
        self, plottable_features_only: bool = True
    ) -> Set[str]:
        r"""
        The purpose of this method is to provide a user the set of all feature names so that if
        they wish to use the filtering capabilities of generate_table_visualization(), or use either of
        generate_plot_visualization() or generate_histogram_visualization(), they don't need to manually
        parse the generated_reports dictionary to get this information.

        Args:
            plottable_features_only (bool): True if the user is only looking for plottable features,
                False otherwise
                plottable features are those that are tensor values
                Default: True (only return those feature names that are plottable)

        Returns all the unique feature names present in the reports the ModelReportVisualizer
        instance was initialized with.
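
        Example Use (the feature names below are hypothetical; actual names depend on the detectors used):
            >>> # xdoctest: +SKIP("undefined variables")
            >>> mod_report_visualizer.get_all_unique_feature_names()
            {'per_channel_min', 'per_channel_max'}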
        """
        unique_feature_names = set()
        for module_fqn in self.generated_reports:
            # get dict of the features
            feature_dict: Dict[str, Any] = self.generated_reports[module_fqn]

            # loop through features
            for feature_name in feature_dict:
                # if we need plottable, ensure type of val is tensor
                if (
                    not plottable_features_only
                    or type(feature_dict[feature_name]) == torch.Tensor
                ):
                    unique_feature_names.add(feature_name)

        # return our compiled set of unique feature names
        return unique_feature_names

    def _get_filtered_data(
        self, feature_filter: str, module_fqn_filter: str
    ) -> OrderedDict[str, Any]:
        r"""
        Filters the data and returns it in the same ordered dictionary format so the relevant views can be displayed.

        Args:
            feature_filter (str): The feature filter, if we want to filter the set of data to only include
                a certain set of features that include feature_filter
                If feature_filter = "", then we do not filter based on any features
            module_fqn_filter (str): The filter on prefix for the module fqn. All modules that have fqn with
                this prefix will be included
                If module_fqn_filter = "" we do not filter based on module fqn, and include all modules

        First, the data is filtered based on module_fqn, and then filtered based on feature
        Returns an OrderedDict (sorted in order of model) mapping:
            module_fqns -> feature_names -> values
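
        For example, a call like the one below (module and feature names are illustrative) keeps only
        the modules whose fqn contains "block1" and only their features containing "per_channel_min":

        >>> # xdoctest: +SKIP("undefined variables")
        >>> visualizer._get_filtered_data("per_channel_min", "block1")
        OrderedDict([('block1.conv', {'per_channel_min': tensor([0.1000, 0.2000])})])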
        """
        # create return dict
        filtered_dict: OrderedDict[str, Any] = OrdDict()

        for module_fqn in self.generated_reports:
            # first filter based on module
            if module_fqn_filter == "" or module_fqn_filter in module_fqn:
                # create entry for module and loop through features
                filtered_dict[module_fqn] = {}
                module_reports = self.generated_reports[module_fqn]
                for feature_name in module_reports:
                    # check if filtering on features and do so if desired
                    if feature_filter == "" or feature_filter in feature_name:
                        filtered_dict[module_fqn][feature_name] = module_reports[
                            feature_name
                        ]

        # we have populated the filtered dict, and must return it

        return filtered_dict

    def _generate_tensor_table(
        self,
        filtered_data: OrderedDict[str, Dict[str, Any]],
        tensor_features: List[str],
    ) -> Tuple[List, List]:
        r"""
        Takes in the filtered data and features list and generates the tensor headers and table

        Currently meant to generate the headers and table for the tensor information.

        Args:
            filtered_data (OrderedDict[str, Dict[str, Any]]): An OrderedDict (sorted in order of model) mapping:
                module_fqns -> feature_names -> values
            tensor_features (List[str]): A list of the tensor level features

        Returns a tuple with:
            A list of the headers of the tensor table
            A list of lists containing the table information row by row
            Each row contains the data for one module; the headers are not included in these rows
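
        For instance, with a single tensor level feature (all names and values below are illustrative),
        the return value could look like:

        >>> # xdoctest: +SKIP("illustrative return value only")
        >>> headers, table = visualizer._generate_tensor_table(filtered_data, ["per_tensor_max"])
        >>> headers
        ['idx', 'layer_fqn', 'per_tensor_max']
        >>> table
        [[0, 'block1.conv', 0.5], [1, 'block1.linear', 1.2]]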
        """
        # now we compose the tensor information table
        tensor_table: List[List[Any]] = []
        tensor_headers: List[str] = []

        # append the table row to the table only if we have features
        if len(tensor_features) > 0:
            # now we add all the data
            for index, module_fqn in enumerate(filtered_data):
                # we make a new row for the tensor table
                tensor_table_row = [index, module_fqn]
                for feature in tensor_features:
                    # we iterate in same order of added features

                    if feature in filtered_data[module_fqn]:
                        # add value if applicable to module
                        feature_val = filtered_data[module_fqn][feature]
                    else:
                        # add that it is not applicable
                        feature_val = "Not Applicable"

                    # if it's a tensor we want to extract val
                    if isinstance(feature_val, torch.Tensor):
                        feature_val = feature_val.item()

                    # we add to our list of values
                    tensor_table_row.append(feature_val)

                tensor_table.append(tensor_table_row)

        # add the headers if we actually have something, otherwise just leave them empty
        if len(tensor_table) != 0:
            tensor_headers = ["idx", "layer_fqn"] + tensor_features

        return (tensor_headers, tensor_table)

    def _generate_channels_table(
        self,
        filtered_data: OrderedDict[str, Any],
        channel_features: List[str],
        num_channels: int,
    ) -> Tuple[List, List]:
        r"""
        Takes in the filtered data and features list and generates the channels headers and table

        Currently meant to generate the headers and table for the channel information.

        Args:
            filtered_data (OrderedDict[str, Any]): An OrderedDict (sorted in order of model) mapping:
                module_fqns -> feature_names -> values
            channel_features (List[str]): A list of the channel level features
            num_channels (int): Number of channels in the channel data

        Returns a tuple with:
            A list of the headers of the channel table
            A list of lists containing the table information row by row
            Each row contains the data for one channel of one module; the headers are not included in these rows
        """
        # now we compose the table for the channel information table
        channel_table: List[List[Any]] = []
        channel_headers: List[str] = []

        # counter to keep track of number of entries in the table
        channel_table_entry_counter: int = 0

        if len(channel_features) > 0:
            # now we add all channel data
            for module_fqn in filtered_data:
                # we iterate over all channels
                for channel in range(num_channels):
                    # we make a new row for the channel
                    new_channel_row = [channel_table_entry_counter, module_fqn, channel]
                    for feature in channel_features:
                        if feature in filtered_data[module_fqn]:
                            # add value if applicable to module
                            feature_val = filtered_data[module_fqn][feature][channel]
                        else:
                            # add that it is not applicable
                            feature_val = "Not Applicable"

                        # if it's a tensor we want to extract val
                        if type(feature_val) is torch.Tensor:
                            feature_val = feature_val.item()

                        # add value to channel specific row
                        new_channel_row.append(feature_val)

                    # add to table and increment row index counter
                    channel_table.append(new_channel_row)
                    channel_table_entry_counter += 1

        # add the headers if we actually have something, otherwise just leave them empty
        if len(channel_table) != 0:
            channel_headers = ["idx", "layer_fqn", "channel"] + channel_features

        return (channel_headers, channel_table)

    def generate_filtered_tables(
        self, feature_filter: str = "", module_fqn_filter: str = ""
    ) -> Dict[str, Tuple[List, List]]:
        r"""
        Takes in optional filter values and generates two tables with desired information.

        The generated tables are presented in a list-of-lists format

        The reason for the two tables is that they handle different things:
        1.) the first table handles all tensor level information
        2.) the second table handles and displays all channel based information

        The reasoning for this is that having all the info in one table can make it ambiguous which collected
            statistics are global, and which are actually per-channel, so it's better to split it up into two
            tables. This also makes the information much easier to digest given the plethora of statistics collected

        Tensor table columns:
            idx  layer_fqn  feature_1   feature_2   feature_3   .... feature_n
            ----  ---------  ---------   ---------   ---------        ---------

        Per-Channel table columns:
            idx  layer_fqn  channel  feature_1   feature_2   feature_3   .... feature_n
            ----  ---------  -------  ---------   ---------   ---------        ---------

        Args:
            feature_filter (str, optional): Filters the features presented to only those that
                contain this filter substring
                Default = "", results in all the features being printed
            module_fqn_filter (str, optional): Only includes modules that contain this string
                Default = "", results in all the modules in the reports being visible in the table

        Returns:
            (Dict[str, Tuple[List, List]]) A dict containing two keys:
            "tensor_level_info", "channel_level_info"
                Each key maps to a tuple with:
                    A list of the headers of each table
                    A list of lists containing the table information row by row
                    Each row holds the data for one module (tensor table) or for one channel of
                    one module (per-channel table); the headers are not included in these rows

        Example Use:
            >>> # xdoctest: +SKIP("undefined variables")
            >>> mod_report_visualizer.generate_filtered_tables(
            ...     feature_filter = "per_channel_min",
            ...     module_fqn_filter = "block1"
            ... ) # generates table with per_channel_min info for all modules in block 1 of the model
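
        The returned dict can then be unpacked and rendered however desired; for example (a sketch
        assuming tabulate is installed):
            >>> # xdoctest: +SKIP("undefined variables")
            >>> table_dict = mod_report_visualizer.generate_filtered_tables()
            >>> headers, rows = table_dict["channel_level_info"]
            >>> print(tabulate(rows, headers=headers))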
        """
        # first get the filtered data
        filtered_data: OrderedDict[str, Any] = self._get_filtered_data(
            feature_filter, module_fqn_filter
        )

        # now we split into tensor and per-channel data
        tensor_features: Set[str] = set()
        channel_features: Set[str] = set()

        # keep track of the number of channels we have
        num_channels: int = 0

        for module_fqn in filtered_data:
            for feature_name in filtered_data[module_fqn]:
                # get the data for that specific feature
                feature_data = filtered_data[module_fqn][feature_name]

                # check if not zero dim tensor
                is_tensor: bool = isinstance(feature_data, torch.Tensor)
                is_not_zero_dim: bool = is_tensor and len(feature_data.shape) != 0

                if is_not_zero_dim or isinstance(feature_data, list):
                    # a non-zero dim tensor or a list means per-channel data
                    channel_features.add(feature_name)
                    num_channels = len(feature_data)
                else:
                    # otherwise it is per-tensor data
                    tensor_features.add(feature_name)

        # we make them lists for iteration purposes
        tensor_features_list: List[str] = sorted(tensor_features)
        channel_features_list: List[str] = sorted(channel_features)

        # get the tensor info
        tensor_headers, tensor_table = self._generate_tensor_table(
            filtered_data, tensor_features_list
        )

        # get the channel info
        channel_headers, channel_table = self._generate_channels_table(
            filtered_data, channel_features_list, num_channels
        )

        # let's now create the dictionary to return
        table_dict = {
            self.TABLE_TENSOR_KEY: (tensor_headers, tensor_table),
            self.TABLE_CHANNEL_KEY: (channel_headers, channel_table),
        }

        # return the two tables
        return table_dict

    def generate_table_visualization(
        self, feature_filter: str = "", module_fqn_filter: str = ""
    ):
        r"""
        Takes in optional filter values and prints out formatted tables of the information.

        The reason two tables are printed out instead of one large one is that they handle different things:
        1.) the first table handles all tensor level information
        2.) the second table handles and displays all channel based information

        The reasoning for this is that having all the info in one table can make it ambiguous which collected
            statistics are global, and which are actually per-channel, so it's better to split it up into two
            tables. This also makes the information much easier to digest given the plethora of statistics collected

        Tensor table columns:
         idx  layer_fqn  feature_1   feature_2   feature_3   .... feature_n
        ----  ---------  ---------   ---------   ---------        ---------

        Per-Channel table columns:

         idx  layer_fqn  channel  feature_1   feature_2   feature_3   .... feature_n
        ----  ---------  -------  ---------   ---------   ---------        ---------

        Args:
            feature_filter (str, optional): Filters the features presented to only those that
                contain this filter substring
                Default = "", results in all the features being printed
            module_fqn_filter (str, optional): Only includes modules that contain this string
                Default = "", results in all the modules in the reports being visible in the table

        Example Use:
            >>> # xdoctest: +SKIP("undefined variables")
            >>> mod_report_visualizer.generate_table_visualization(
            ...     feature_filter = "per_channel_min",
            ...     module_fqn_filter = "block1"
            ... )
            >>> # prints out neatly formatted table with per_channel_min info
            >>> # for all modules in block 1 of the model
        """
        # see if we got tabulate
        if not got_tabulate:
            print("Make sure to install tabulate and try again.")
            return None

        # get the table dict and the specific tables of interest
        table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter)
        tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY]
        channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY]

        # get the table string and print it out
        # now we have populated the tables for each one
        # let's create the strings to be returned
        table_str = ""
        # the tables will have some header columns that are non-feature
        # ex. table index, module name, channel index, etc.
        # we want to look at header columns for features, that come after those headers
        if len(tensor_headers) > self.NUM_NON_FEATURE_TENSOR_HEADERS:
            # if we have at least one tensor level feature to be added we add the tensor table
            table_str += "Tensor Level Information \n"
            table_str += tabulate(tensor_table, headers=tensor_headers)
        if len(channel_headers) > self.NUM_NON_FEATURE_CHANNEL_HEADERS:
            # if we have at least one channel level feature to be added we add the channel table
            table_str += "\n\n Channel Level Information \n"
            table_str += tabulate(channel_table, headers=channel_headers)

        # if no features at all, let user know
        if table_str == "":
            table_str = "No data points to generate table with."

        print(table_str)

    def _get_plottable_data(
        self, feature_filter: str, module_fqn_filter: str
    ) -> Tuple[List, List[List], bool]:
        r"""
        Takes in the feature filters and module filters and outputs the x and y data for plotting

        Args:
            feature_filter (str): Filters the features presented to only those that
                contain this filter substring
            module_fqn_filter (str): Only includes modules that contain this string

        Returns a tuple of three elements
            The first is a list containing relevant x-axis data
            The second is a list containing the corresponding y-axis data
            The third is a bool that is True if the data is per-channel
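
        For instance (an illustrative sketch assuming two modules with two channels each and a single
        per-channel feature), the return value could look like:

        >>> # xdoctest: +SKIP("illustrative return value only")
        >>> x_data, y_data, per_channel = visualizer._get_plottable_data("per_channel_min", "")
        >>> (x_data, y_data, per_channel)
        ([0, 1], [[0.1, 0.3], [0.2, 0.4]], True)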
        """
        # get the table dict and the specific tables of interest
        table_dict = self.generate_filtered_tables(feature_filter, module_fqn_filter)
        tensor_headers, tensor_table = table_dict[self.TABLE_TENSOR_KEY]
        channel_headers, channel_table = table_dict[self.TABLE_CHANNEL_KEY]

        # make sure it is only 1 feature that is being plotted
        # get the number of features in each of these
        tensor_info_features_count = (
            len(tensor_headers) - ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS
        )
        channel_info_features_count = (
            len(channel_headers) - ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS
        )

        # see if valid tensor or channel plot
        is_valid_per_tensor_plot: bool = tensor_info_features_count == 1
        is_valid_per_channel_plot: bool = channel_info_features_count == 1

        # default to the tensor table and its offset; swapped below if this is a per-channel plot
        feature_column_offset = ModelReportVisualizer.NUM_NON_FEATURE_TENSOR_HEADERS
        table = tensor_table

        # if a per_channel plot, we have different offset and table
        if is_valid_per_channel_plot:
            feature_column_offset = (
                ModelReportVisualizer.NUM_NON_FEATURE_CHANNEL_HEADERS
            )
            table = channel_table

        x_data: List = []
        y_data: List[List] = []
        # the feature will either be a tensor feature or channel feature
        if is_valid_per_tensor_plot:
            for table_row_num, row in enumerate(table):
                # get x_value to append
                x_val_to_append = table_row_num
                # the index of the feature will be 0 + num non feature columns
                tensor_feature_index = feature_column_offset
                row_value = row[tensor_feature_index]
                if not type(row_value) == str:
                    x_data.append(x_val_to_append)
                    y_data.append(row_value)
        elif is_valid_per_channel_plot:
            # gather the x_data and multiple y_data
            # calculate the number of channels
            num_channels: int = max(row[self.CHANNEL_NUM_INDEX] for row in table) + 1
            for channel in range(num_channels):
                y_data.append([])  # separate data list per channel

            for table_row_num, row in enumerate(table):
                # each module shares one x value, so use the module index
                # rather than the raw row index
                current_channel = row[self.CHANNEL_NUM_INDEX]  # channel of the current row
                new_module_index: int = table_row_num // num_channels
                x_val_to_append = new_module_index

                # the index of the feature will be 0 + num non feature columns
                tensor_feature_index = feature_column_offset
                row_value = row[tensor_feature_index]
                if not type(row_value) == str:
                    # only append if this is a new x index
                    if len(x_data) == 0 or x_data[-1] != x_val_to_append:
                        x_data.append(x_val_to_append)

                    # append value for that channel
                    y_data[current_channel].append(row_value)
        else:
            # the filters did not match exactly one feature
            error_str = "Make sure to pick only a single feature with your filter to plot a graph."
            error_str += " We recommend calling get_all_unique_feature_names() to find unique feature names."
            error_str += " Pick one of those features to plot."
            raise ValueError(error_str)

        # return x, y values, and if data is per-channel
        return (x_data, y_data, is_valid_per_channel_plot)

    def generate_plot_visualization(
        self, feature_filter: str, module_fqn_filter: str = ""
    ):
        r"""
        Takes in a feature and optional module_filter and plots the desired data.

        For per channel features, it averages the value across the channels and plots a point
        per module. The reason for this is that for models with hundreds of channels, it can
        be hard to differentiate one channel line from another, and so the point of generating
        a single average point per module is to give a sense of general trends that encourage
        further deep dives.

        Note:
            Only features in the report that have tensor value data are plottable by this class
            When the tensor information is plotted, it will plot:
                idx as the x val, feature value as the y_val
            When the channel information is plotted, it will plot:
                the first idx of each module as the x val, feature value as the y_val [for each channel]
                The reason for this is that we want to be able to compare values across the
                channels for same layer, and it will be hard if values are staggered by idx
                This means each module is represented by only 1 x value
        Args:
            feature_filter (str): Filters the features presented to only those that
                contain this filter substring
            module_fqn_filter (str, optional): Only includes modules that contain this string
                Default = "", results in all the modules in the reports being visible in the table

        Example Use:
            >>> # xdoctest: +SKIP("undefined variables")
            >>> mod_report_visualizer.generate_plot_visualization(
            ...     feature_filter = "per_channel_min",
            ...     module_fqn_filter = "block1"
            ... )
            >>> # outputs a line plot of the average per_channel_min value per module
            >>> # for all modules in block1 of the model, averaged across the channels
            >>> # and plotted across the in-order modules on the x-axis
        """
        # check if we have matplotlib and let the user know to install it if we don't
        if not got_matplotlib:
            print("make sure to install matplotlib and try again.")
            return None

        # get the x and y data and if per channel
        x_data, y_data, data_per_channel = self._get_plottable_data(
            feature_filter, module_fqn_filter
        )

        # plot based on whether data is per channel or not
        ax = plt.subplot()
        ax.set_ylabel(feature_filter)
        ax.set_title(feature_filter + " Plot")
        plt.xticks(x_data)  # only show ticks for actual points

        if data_per_channel:
            ax.set_xlabel("First idx of module")
            # set the legend as well
            # plot a single line that is the average of the channel values
            num_modules = len(
                y_data[0]
            )  # all y_data have same length, so get num modules
            num_channels = len(
                y_data
            )  # we want num channels to be able to calculate average later

            # for each module, average the feature value across all channels
            avg_vals = [
                sum(channel_vals[index] for channel_vals in y_data) / num_channels
                for index in range(num_modules)
            ]

            # plot the averaged line
            ax.plot(
                x_data, avg_vals, label=f"Average Value Across {num_channels} Channels"
            )
            ax.legend(loc="upper right")
        else:
            ax.set_xlabel("idx")
            ax.plot(x_data, y_data)

        # actually show the plot
        plt.show()

    def generate_histogram_visualization(
        self, feature_filter: str, module_fqn_filter: str = "", num_bins: int = 10
    ):
        r"""
        Takes in a feature and optional module_filter and plots the histogram of desired data.

        Note:
            Only features in the report that have tensor value data can be viewed as a histogram
            If you want to plot a histogram from all the channel values of a specific feature for
                a specific model, make sure to specify both the model and the feature properly
                in the filters and you should be able to see a distribution of the channel data

        Args:
            feature_filter (str): Filters the features presented to only those that
                contain this filter substring
            module_fqn_filter (str, optional): Only includes modules that contain this string
                Default = "", results in all the modules in the reports being visible in the table
            num_bins (int, optional): The number of bins to create the histogram with
                Default = 10, the values will be split into 10 equal sized bins

        Example Use:
            >>> # xdoctest: +SKIP
            >>> mod_report_visualizer.generate_histogram_visualization(
            ...     feature_filter = "per_channel_min",
            ...     module_fqn_filter = "block1"
            ... )
            >>> # outputs histogram of per_channel_min information for all modules in block1 of the model
            >>> # information is gathered across all channels for all modules in block 1 for the
            >>> # per_channel_min and is displayed in a histogram of equally sized bins
        """
        # check if we have matplotlib and let the user know to install it if we don't
        if not got_matplotlib:
            print("make sure to install matplotlib and try again.")
            return None

        # get the x and y data and if per channel
        x_data, y_data, data_per_channel = self._get_plottable_data(
            feature_filter, module_fqn_filter
        )

        # for histogram, we just care about plotting the y data
        # plot based on whether data is per channel or not
        ax = plt.subplot()
        ax.set_xlabel(feature_filter)
        ax.set_ylabel("Frequency")
        ax.set_title(feature_filter + " Histogram")

        if data_per_channel:
            # set the legend as well
            # combine all the data
            all_data = []
            for channel_info in y_data:
                all_data.extend(channel_info)

            val, bins, _ = plt.hist(
                all_data,
                bins=num_bins,
                stacked=True,
                rwidth=0.8,
            )
            plt.xticks(bins)
        else:
            val, bins, _ = plt.hist(
                y_data,
                bins=num_bins,
                stacked=False,
                rwidth=0.8,
            )
            plt.xticks(bins)

        plt.show()