# xref: /aosp_15_r20/external/pytorch/torchgen/_autoheuristic/train.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3import argparse
4import json
5import warnings
6
7import pandas as pd  # type: ignore[import-untyped]
8
9from torch._inductor.autoheuristic.autoheuristic_utils import (
10    CHOICE_COL,
11    get_metadata_str_from_log,
12)
13
14
15# TODO (AlnisM): Fix these warnings
16warnings.filterwarnings(
17    "ignore",
18    message="The behavior of DataFrame concatenation with empty or all-NA entries is deprecated",
19)
20warnings.filterwarnings(
21    "ignore",
22    message="DataFrameGroupBy.apply operated on the grouping columns.",
23)
24
25
class AHTrain:
    """
    Base class for AutoHeuristic training.

    Drives the end-to-end flow of turning an AutoHeuristic log into a
    generated heuristic: parse CLI arguments, load and preprocess the
    logged data, and write the learned heuristic out as a Python file.

    Subclasses are expected to provide ``main()`` (the actual training,
    invoked by ``generate_heuristic()``) and ``get_df()`` (used by
    ``add_real_datasets()``), and may override the hooks ``filter_df``,
    ``add_new_features``, ``codegen_boilerplate`` and
    ``gen_predict_fn_def``.
    """

    def __init__(self) -> None:
        # Build the parser eagerly so subclasses can register additional
        # arguments before generate_heuristic() calls parse_args().
        self.parser = argparse.ArgumentParser()
        self.add_base_arguments()
        # Populated by generate_heuristic().
        self.args = None

    def add_base_arguments(self):
        """Register the command-line arguments common to all trainers."""
        self.parser.add_argument(
            "dataset",
            type=str,
            help="Path to text file containing data collected with AutoHeuristic.",
        )
        self.parser.add_argument(
            "--nrows",
            type=int,
            default=None,
            help="Only read first n rows of the dataset.",
        )
        self.parser.add_argument(
            "--heuristic-name",
            type=str,
            default="learned_heuristic",
            help="Name of the heuristic to be generated.",
        )
        self.parser.add_argument(
            "--data",
            nargs=2,
            action="append",
            metavar=("TYPE", "PATH"),
            help="Specify name of datasets and file paths to be evaluated.",
        )
        self.parser.add_argument(
            "--save-dot",
            action="store_true",
            help="Export heuristic to graphviz dot.",
        )
        self.parser.add_argument(
            "--ranking",
            type=int,
            default=None,
            help="""
                Makes AutoHeuristic learn a heuristic that ranks choices instead of predicting a single choice.
                The argument is the number of choices the heuristic will provide.
            """,
        )

    def parse_args(self):
        """Parse CLI arguments with the (possibly subclass-extended) parser."""
        return self.parser.parse_args()

    def parse_log(self, log_path, nrows=None):
        """
        Load a log file and split it into its parts.

        Returns a tuple
        (df, metadata, features, categorical_features, choices).
        """
        (df, metadata) = self.deserialize_data(log_path)
        numerical_features = metadata["numerical_features"]
        categorical_features = metadata["categorical_features"]
        # NOTE(review): choices are collected from the full dataframe before
        # the head(nrows) truncation below, so they may include choices that
        # no longer occur in the truncated data -- confirm this is intended.
        choices = df[CHOICE_COL].unique().tolist()
        features = numerical_features + categorical_features
        if nrows is not None:
            df = df.head(nrows)
        df = self.filter_df(df)
        return (df, metadata, features, categorical_features, choices)

    def generate_heuristic(self):
        """Entry point: parse CLI args and run the subclass-provided main()."""
        self.args = self.parse_args()
        self.main(
            self.args.dataset,
            self.args.data,
            self.args.nrows,
            self.args.heuristic_name,
            self.args.save_dot,
            self.args.ranking is not None,
        )

    def filter_df(self, df):
        """Hook: subclasses may drop rows here; the base keeps everything."""
        return df

    def add_new_features(self, results):
        """Hook: subclasses may derive extra features.

        Returns (df, new_feature_names); the base adds nothing.
        """
        return (results, [])

    def add_real_datasets(self, datasets, other_datasets, cat_feature2cats):
        """Load extra (name, path) datasets into `datasets`, reusing the
        categorical encodings (`cat_feature2cats`) learned from training."""
        if other_datasets:
            for name, path in other_datasets:
                (df_other, choices, _, _, _) = self.get_df(
                    path, cat_feature2cats=cat_feature2cats, apply_filters=False
                )
                datasets[name] = df_other

    def handle_categorical_features(
        self, cat_feature2cats, categorical_features, results
    ):
        """
        One-hot encode the categorical features of `results`.

        Returns (encoded_df, cat_feature2cats, dummy_col_2_col_val) where
        `dummy_col_2_col_val` maps each dummy column name back to its
        (original_column, value) pair.
        """
        # Doing this here because if we create another df for testing purposes
        # and that other df does not contain all categories for a categorical feature,
        # pd.dummies will not create columns for the missing categories
        if not cat_feature2cats:
            cat_feature2cats = {}
        for cat_feature in categorical_features:
            if cat_feature in cat_feature2cats:
                # Reuse the category set learned earlier (e.g. from training
                # data) so dummy columns line up across dataframes.
                categories = cat_feature2cats[cat_feature]
            else:
                categories = results[cat_feature].unique()
                cat_feature2cats[cat_feature] = categories
            results[cat_feature] = pd.Categorical(
                results[cat_feature], categories=categories
            )

        dummy_col_2_col_val = {}
        for col in categorical_features:
            unique_vals = results[col].unique()
            for val in unique_vals:
                dummy_col_2_col_val[f"{col}_{val}"] = (col, val)
        # one-hot encode categorical features
        results = pd.get_dummies(results, columns=categorical_features)
        return (results, cat_feature2cats, dummy_col_2_col_val)

    def gen_precondition(self, opt_name, shared_memory, device_capa):
        """Generate the source of check_precondition() for the heuristic."""
        return f"""    def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool:
        return (
            metadata.name == self.get_name()
            and metadata.shared_memory == {shared_memory}
            and str(metadata.device_capa) == "{device_capa}"
        )"""

    def codegen_boilerplate(
        self, heuristic_name, opt_name, threshold, shared_memory, device_capa, dt
    ):
        """Hook: subclasses emit the boilerplate of the generated file."""
        pass

    def gen_predict_fn_def(self):
        """Hook: subclasses emit the signature of the generated predict fn."""
        pass

    def write_heuristic_to_file(self, lines, heuristic_name):
        """Write the generated heuristic source into the artifacts directory.

        The path is relative to this script's expected location in the repo.
        """
        output_file = (
            f"../../../torch/_inductor/autoheuristic/artifacts/_{heuristic_name}.py"
        )
        # Fixed: removed the redundant `path = f"{output_file}"` re-wrap of an
        # already-formatted string.
        with open(output_file, "w") as f:
            f.write("\n".join(lines) + "\n")

    def deserialize_data(self, log_path):
        """Read an AutoHeuristic log: line 1 is JSON metadata, the rest CSV."""
        json_string = get_metadata_str_from_log(log_path)
        metadata = self.deserialize_metadata(json_string)

        df = pd.read_csv(log_path, skiprows=1, on_bad_lines="skip")
        return (df, metadata)

    def deserialize_metadata(self, json_string):
        """Parse the JSON metadata header of a log file."""
        return json.loads(json_string)
175
176
if __name__ == "__main__":
    # Construct the base trainer and run the full CLI-driven pipeline.
    AHTrain().generate_heuristic()
180