import argparse
import math
import pickle
import random
from dataclasses import dataclass
from itertools import chain
from pathlib import Path
from typing import Dict, List

import common
import pandas as pd
import torchtext
from torchtext.functional import to_tensor
from tqdm import tqdm

import torch
import torch.nn as nn


XLMR_BASE = torchtext.models.XLMR_BASE_ENCODER
# This should not be here but it works for now
device = "cuda" if torch.cuda.is_available() else "cpu"

HAS_IMBLEARN = False
try:
    import imblearn

    HAS_IMBLEARN = True
except ImportError:
    HAS_IMBLEARN = False

# 94% of all files are captured at len 5, good hyperparameter to play around with.
MAX_LEN_FILE = 6

UNKNOWN_TOKEN = "<Unknown>"

# Utilities for working with a truncated file graph


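# Example (illustrative path, using the default max_len=5): a file like
# "torch/csrc/jit/passes/onnx/foo.cpp" truncates to
# "torch/csrc/jit/passes/onnx", so files sharing a deep prefix collapse
# into a single bucket of the file vocabulary.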
def truncate_file(file: Path, max_len: int = 5):
    return ("/").join(file.parts[:max_len])


def build_file_set(all_files: List[Path], max_len: int):
    truncated_files = [truncate_file(file, max_len) for file in all_files]
    return set(truncated_files)


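# Raw, batched inputs for the classifier: parallel lists of commit titles,
# space-separated changed-file strings, and author names.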
@dataclass
class CommitClassifierInputs:
    title: List[str]
    files: List[str]
    author: List[str]


@dataclass
class CategoryConfig:
    categories: List[str]
    input_dim: int = 768
    inner_dim: int = 128
    dropout: float = 0.1
    activation = nn.ReLU
    embedding_dim: int = 8
    file_embedding_dim: int = 32


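# The classifier combines three signals: commit titles are run through a
# frozen XLM-R encoder, changed-file paths through an EmbeddingBag over the
# truncated-path vocabulary, and authors through a small embedding table.
# Each branch is projected to per-category logits and the three logit
# vectors are summed.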
class CommitClassifier(nn.Module):
    def __init__(
        self,
        encoder_base: torchtext.models.XLMR_BASE_ENCODER,
        author_map: Dict[str, int],
        file_map: Dict[str, int],
        config: CategoryConfig,
    ):
        super().__init__()
        self.encoder = encoder_base.get_model().requires_grad_(False)
        self.transform = encoder_base.transform()
        self.author_map = author_map
        self.file_map = file_map
        self.categories = config.categories
        self.num_authors = len(author_map)
        self.num_files = len(file_map)
        self.embedding_table = nn.Embedding(self.num_authors, config.embedding_dim)
        self.file_embedding_bag = nn.EmbeddingBag(
            self.num_files, config.file_embedding_dim, mode="sum"
        )
        self.dense_title = nn.Linear(config.input_dim, config.inner_dim)
        self.dense_files = nn.Linear(config.file_embedding_dim, config.inner_dim)
        self.dense_author = nn.Linear(config.embedding_dim, config.inner_dim)
        self.dropout = nn.Dropout(config.dropout)
        self.out_proj_title = nn.Linear(config.inner_dim, len(self.categories))
        self.out_proj_files = nn.Linear(config.inner_dim, len(self.categories))
        self.out_proj_author = nn.Linear(config.inner_dim, len(self.categories))
        self.activation_fn = config.activation()

    def forward(self, input_batch: CommitClassifierInputs):
        # Encode input title
        title: List[str] = input_batch.title
        model_input = to_tensor(self.transform(title), padding_value=1).to(device)
        title_features = self.encoder(model_input)
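        # Use the hidden state of the first token as the sentence-level
        # embedding of the title.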
        title_embed = title_features[:, 0, :]
        title_embed = self.dropout(title_embed)
        title_embed = self.dense_title(title_embed)
        title_embed = self.activation_fn(title_embed)
        title_embed = self.dropout(title_embed)
        title_embed = self.out_proj_title(title_embed)

        files: List[str] = input_batch.files
        batch_file_indexes = []
        for file in files:
            paths = [
                truncate_file(Path(file_part), MAX_LEN_FILE)
                for file_part in file.split(" ")
            ]
            batch_file_indexes.append(
                [
                    self.file_map.get(path, self.file_map[UNKNOWN_TOKEN])
                    for path in paths
                ]
            )

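        # nn.EmbeddingBag takes a flat 1-D tensor of indices plus per-sample
        # start offsets, so flatten the per-commit index lists and derive the
        # offsets from their cumulative lengths.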
        flat_indexes = torch.tensor(
            list(chain.from_iterable(batch_file_indexes)),
            dtype=torch.long,
            device=device,
        )
        offsets = [0]
        offsets.extend(len(indexes) for indexes in batch_file_indexes[:-1])
        offsets = torch.tensor(offsets, dtype=torch.long, device=device)
        offsets = offsets.cumsum(dim=0)

        files_embed = self.file_embedding_bag(flat_indexes, offsets)
        files_embed = self.dense_files(files_embed)
        files_embed = self.activation_fn(files_embed)
        files_embed = self.dropout(files_embed)
        files_embed = self.out_proj_files(files_embed)

        # Add author embedding
        authors: List[str] = input_batch.author
        author_ids = [
            self.author_map.get(author, self.author_map[UNKNOWN_TOKEN])
            for author in authors
        ]
        author_ids = torch.tensor(author_ids).to(device)
        author_embed = self.embedding_table(author_ids)
        author_embed = self.dense_author(author_embed)
        author_embed = self.activation_fn(author_embed)
        author_embed = self.dropout(author_embed)
        author_embed = self.out_proj_author(author_embed)

        return title_embed + files_embed + author_embed

    def convert_index_to_category_name(self, most_likely_index):
        if isinstance(most_likely_index, int):
            return self.categories[most_likely_index]
        elif isinstance(most_likely_index, torch.Tensor):
            return [self.categories[i] for i in most_likely_index]

    def get_most_likely_category_name(self, inpt):
        # Input is a CommitClassifierInputs batch with title, files, and author fields
        logits = self.forward(inpt)
        most_likely_index = torch.argmax(logits, dim=1)
        return self.convert_index_to_category_name(most_likely_index)


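# Load cached train/val splits from CSV when present; otherwise split the
# labeled rows of commitlist.csv (sending "Uncategorized" rows to the test
# set and dropping "skip" rows), then cache the splits back to disk.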
def get_train_val_data(data_folder: Path, regen_data: bool, train_percentage=0.95):
    if (
        not regen_data
        and Path(data_folder / "train_df.csv").exists()
        and Path(data_folder / "val_df.csv").exists()
    ):
        train_data = pd.read_csv(data_folder / "train_df.csv")
        val_data = pd.read_csv(data_folder / "val_df.csv")
        return train_data, val_data
    else:
        print("Train/val/test split not found; generating from scratch.")
        commit_list_df = pd.read_csv(data_folder / "commitlist.csv")
        test_df = commit_list_df[commit_list_df["category"] == "Uncategorized"]
        all_train_df = commit_list_df[commit_list_df["category"] != "Uncategorized"]
        # We are going to drop skip from the training set since it is so imbalanced
        print(
            "Removing the skip category. You might want to change this, but it makes for a more helpful labeling classifier."
        )
        all_train_df = all_train_df[all_train_df["category"] != "skip"]
        all_train_df = all_train_df.sample(frac=1).reset_index(drop=True)
        split_index = math.floor(train_percentage * len(all_train_df))
        train_df = all_train_df[:split_index]
        val_df = all_train_df[split_index:]
        print("Train data size: ", len(train_df))
        print("Val data size: ", len(val_df))

        test_df.to_csv(data_folder / "test_df.csv", index=False)
        train_df.to_csv(data_folder / "train_df.csv", index=False)
        val_df.to_csv(data_folder / "val_df.csv", index=False)
        return train_df, val_df


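# Vocabulary helpers: map author names and truncated file paths to integer
# ids, reserving UNKNOWN_TOKEN for values unseen at training time. The maps
# are pickled so inference reuses exactly the same ids.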
def get_author_map(data_folder: Path, regen_data, assert_stored=False):
    if not regen_data and Path(data_folder / "author_map.pkl").exists():
        with open(data_folder / "author_map.pkl", "rb") as f:
            return pickle.load(f)
    else:
        if assert_stored:
            raise FileNotFoundError(
                "Author map not found. When loading for inference you need a stored author map!"
            )
        print("Regenerating Author Map")
        all_data = pd.read_csv(data_folder / "commitlist.csv")
        authors = all_data.author.unique().tolist()
        authors.append(UNKNOWN_TOKEN)
        author_map = {author: i for i, author in enumerate(authors)}
        with open(data_folder / "author_map.pkl", "wb") as f:
            pickle.dump(author_map, f)
        return author_map


def get_file_map(data_folder: Path, regen_data, assert_stored=False):
    if not regen_data and Path(data_folder / "file_map.pkl").exists():
        with open(data_folder / "file_map.pkl", "rb") as f:
            return pickle.load(f)
    else:
        if assert_stored:
            raise FileNotFoundError(
                "File map not found. When loading for inference you need a stored file map!"
            )
        print("Regenerating File Map")
        all_data = pd.read_csv(data_folder / "commitlist.csv")
        # Let's explore the changed files
        files = all_data.files_changed.to_list()

        all_files = []
        for file in files:
            paths = [Path(file_part) for file_part in file.split(" ")]
            all_files.extend(paths)
        all_files.append(Path(UNKNOWN_TOKEN))
        file_set = build_file_set(all_files, MAX_LEN_FILE)
        file_map = {file: i for i, file in enumerate(file_set)}
        with open(data_folder / "file_map.pkl", "wb") as f:
            pickle.dump(file_map, f)
        return file_map


# Generate a dataset for training


def get_title_files_author_categories_zip_list(dataframe: pd.DataFrame):
    title = dataframe.title.to_list()
    files_str = dataframe.files_changed.to_list()
    author = dataframe.author.fillna(UNKNOWN_TOKEN).to_list()
    category = dataframe.category.to_list()
    return list(zip(title, files_str, author, category))


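# Collate a list of (title, files, author, category) tuples into a
# CommitClassifierInputs batch plus a tensor of category-index targets.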
def generate_batch(batch):
    title, files, author, category = zip(*batch)
    title = list(title)
    files = list(files)
    author = list(author)
    category = list(category)
    targets = torch.tensor([common.categories.index(cat) for cat in category]).to(
        device
    )
    return CommitClassifierInputs(title, files, author), targets


def train_step(batch, model, optimizer, loss):
    inpt, targets = batch
    optimizer.zero_grad()
    output = model(inpt)
    l = loss(output, targets)
    l.backward()
    optimizer.step()
    return l


@torch.no_grad()
def eval_step(batch, model, loss):
    inpt, targets = batch
    output = model(inpt)
    l = loss(output, targets)
    return l


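# Optionally rebalance the training set with imblearn's RandomOverSampler,
# then randomly subsample back down to 2x the original size so an epoch does
# not grow unbounded. Falls back to the original dataset when imblearn is
# not installed.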
def balance_dataset(dataset: List):
    if not HAS_IMBLEARN:
        return dataset
    title, files, author, category = zip(*dataset)
    category = [common.categories.index(cat) for cat in category]
    inpt_data = list(zip(title, files, author))
    from imblearn.over_sampling import RandomOverSampler

    # from imblearn.under_sampling import RandomUnderSampler
    rus = RandomOverSampler(random_state=42)
    X, y = rus.fit_resample(inpt_data, category)
    merged = list(zip(X, y))
    merged = random.sample(merged, k=2 * len(dataset))
    X, y = zip(*merged)
    rebuilt_dataset = []
    for i in range(len(X)):
        rebuilt_dataset.append((*X[i], common.categories[y[i]]))
    return rebuilt_dataset


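# Inverse-frequency class weights: each category's count is clamped between
# smoothed averages derived from the most- and least-common category counts,
# so extreme head or tail categories do not dominate the cross-entropy loss.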
def gen_class_weights(dataset: List):
    from collections import Counter

    epsilon = 1e-1
    title, files, author, category = zip(*dataset)
    category = [common.categories.index(cat) for cat in category]
    counter = Counter(category)
    percentile_33 = len(category) // 3
    most_common = counter.most_common(percentile_33)
    least_common = counter.most_common()[-percentile_33:]
    smoothed_top = sum(i[1] + epsilon for i in most_common) / len(most_common)
    smoothed_bottom = sum(i[1] + epsilon for i in least_common) / len(least_common) // 3
    class_weights = torch.tensor(
        [
            1.0 / (min(max(counter[i], smoothed_bottom), smoothed_top) + epsilon)
            for i in range(len(common.categories))
        ],
        device=device,
    )
    return class_weights


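# End-to-end training: build the vocabularies, train for a fixed number of
# epochs with class-weighted cross-entropy, report validation accuracy, and
# save the classifier's state_dict to save_path.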
def train(save_path: Path, data_folder: Path, regen_data: bool, resample: bool):
    train_data, val_data = get_train_val_data(data_folder, regen_data)
    train_zip_list = get_title_files_author_categories_zip_list(train_data)
    val_zip_list = get_title_files_author_categories_zip_list(val_data)

    classifier_config = CategoryConfig(common.categories)
    author_map = get_author_map(data_folder, regen_data)
    file_map = get_file_map(data_folder, regen_data)
    commit_classifier = CommitClassifier(
        XLMR_BASE, author_map, file_map, classifier_config
    ).to(device)

    # Let's train this bag of bits
    class_weights = gen_class_weights(train_zip_list)
    loss = torch.nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.Adam(commit_classifier.parameters(), lr=3e-3)

    num_epochs = 25
    batch_size = 256

    if resample:
        # Off by default; enable with --resample (requires imblearn)
        train_zip_list = balance_dataset(train_zip_list)
    data_size = len(train_zip_list)

    print(f"Training on {data_size} examples.")
    # We can fit all of val into one batch
    val_batch = generate_batch(val_zip_list)

    for i in tqdm(range(num_epochs), desc="Epochs"):
        start = 0
        random.shuffle(train_zip_list)
        while start < data_size:
            end = start + batch_size
            # clamp the final batch so it does not run past the end of the data
            if end > data_size:
                end = data_size
            train_batch = train_zip_list[start:end]
            train_batch = generate_batch(train_batch)
            l = train_step(train_batch, commit_classifier, optimizer, loss)
            start = end

        val_l = eval_step(val_batch, commit_classifier, loss)
        tqdm.write(
            f"Finished epoch {i} with a train loss of: {l.item()} and a val_loss of: {val_l.item()}"
        )

    with torch.no_grad():
        commit_classifier.eval()
        val_inpts, val_targets = val_batch
        val_output = commit_classifier(val_inpts)
        val_preds = torch.argmax(val_output, dim=1)
        val_acc = torch.sum(val_preds == val_targets).item() / len(val_preds)
        print(f"Final Validation accuracy is {val_acc}")

    print(f"Jobs done! Saving to {save_path}")
    torch.save(commit_classifier.state_dict(), save_path)


def main():
    parser = argparse.ArgumentParser(
        description="Tool to create a classifier for helping to categorize commits"
    )

    parser.add_argument("--train", action="store_true", help="Train a new classifier")
    parser.add_argument("--commit_data_folder", default="results/classifier/")
    parser.add_argument(
        "--save_path", default="results/classifier/commit_classifier.pt"
    )
    parser.add_argument(
        "--regen_data",
        action="store_true",
        help="Regenerate the training data; helps if you have labeled more examples and want to re-train.",
    )
    parser.add_argument(
        "--resample",
        action="store_true",
        help="Resample the training data to be balanced. (Only works if imblearn is installed.)",
    )
    args = parser.parse_args()

    if args.train:
        train(
            Path(args.save_path),
            Path(args.commit_data_folder),
            args.regen_data,
            args.resample,
        )
        return

    print(
        "Currently this file only trains a new classifier. Please pass --train to train one."
    )


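# Example invocation (paths are the argparse defaults; commitlist.csv must
# already exist under the data folder):
#   python classifier.py --train --regen_data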
if __name__ == "__main__":
    main()