xref: /aosp_15_r20/external/pytorch/benchmarks/functional_autograd_benchmark/vision_models.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from typing import cast
2
3import torchvision_models as models
4from utils import check_for_functorch, extract_weights, GetterReturnType, load_weights
5
6import torch
7from torch import Tensor
8
9
# Whether functorch is available, as reported by utils.check_for_functorch().
# When True, the model getters in this module call
# functorch.experimental.replace_all_batch_norm_modules_ on their models
# before extracting weights.
has_functorch = check_for_functorch()
11
12
def get_resnet18(device: torch.device) -> GetterReturnType:
    """Build the ResNet-18 cross-entropy benchmark.

    A randomly initialised ``models.resnet18`` is moved to ``device`` and paired
    with a fixed synthetic batch of 32 random 3x224x224 images and integer
    labels in [0, 10).

    Returns:
        ``(forward, params)``. ``params`` are the tensors returned by
        ``extract_weights`` for the model. ``forward(*new_params)`` writes the
        given tensors back into the model via ``load_weights`` and returns the
        cross-entropy loss on the fixed batch.
    """
    batch_size = 32
    net = models.resnet18(pretrained=False)

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(net)

    loss_fn = torch.nn.CrossEntropyLoss()
    net.to(device)
    weights, weight_names = extract_weights(net)

    # Synthetic data is drawn once, so every forward() call sees the same batch.
    images = torch.rand([batch_size, 3, 224, 224], device=device)
    # rand() is in [0, 1); scaling by 10 and truncating yields class ids 0..9.
    targets = (torch.rand(batch_size, device=device) * 10).long()

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(net, weight_names, new_params)
        return loss_fn(net(images), targets)

    return forward, weights
37
38
def get_fcn_resnet(device: torch.device) -> GetterReturnType:
    """Build the FCN-ResNet50 dense-regression benchmark.

    A randomly initialised ``models.fcn_resnet50`` (no pretrained weights or
    backbone) is moved to ``device``. It is paired with a fixed batch of 8
    random 3x480x480 images and random 21x480x480 targets, compared with MSE
    against the model's ``"out"`` output.

    Returns:
        ``(forward, params)``. ``params`` are the tensors returned by
        ``extract_weights`` for the model. ``forward(*new_params)`` writes the
        given tensors back into the model via ``load_weights`` and returns the
        MSE loss on the fixed batch.
    """
    batch_size = 8
    side = 480  # spatial size of both the inputs and the dense targets
    num_classes = 21  # Given model has 21 classes

    net = models.fcn_resnet50(pretrained=False, pretrained_backbone=False)

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(net)
        # disable dropout for consistency checking
        net.eval()

    loss_fn = torch.nn.MSELoss()
    net.to(device)
    weights, weight_names = extract_weights(net)

    images = torch.rand([batch_size, 3, side, side], device=device)
    dense_targets = torch.rand([batch_size, num_classes, side, side], device=device)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(net, weight_names, new_params)
        predictions = net(images)["out"]
        return loss_fn(predictions, dense_targets)

    return forward, weights
66
67
def get_detr(device: torch.device) -> GetterReturnType:
    """Build the DETR set-prediction-loss benchmark.

    A ``models.DETR`` model and a ``models.SetCriterion`` (with a
    ``models.HungarianMatcher``) are moved to ``device``. They are paired with a
    fixed synthetic batch of 2 random 3x800x1200 images. Each image has a
    target dict holding:

    * ``"labels"``: 5..9 class ids, each in [5, 10).
    * ``"boxes"``: a float tensor of shape ``(n_targets, 4)`` with integer-valued
      coordinates in [100, 800). Each row is ordered so that column 0 <= column 2
      and column 1 <= column 3, i.e. (x0, y0, x1, y1) corners.

    Returns:
        ``(forward, params)``. ``params`` are the tensors returned by
        ``extract_weights`` for the model. ``forward(*new_params)`` writes the
        given tensors back into the model via ``load_weights``, evaluates the
        criterion, and returns the sum of each loss term multiplied by its
        weight in ``criterion.weight_dict``. Terms without a weight are skipped.
    """
    # All values below are from CLI defaults in https://github.com/facebookresearch/detr
    N = 2
    num_classes = 91
    hidden_dim = 256
    nheads = 8
    num_encoder_layers = 6
    num_decoder_layers = 6

    model = models.DETR(
        num_classes=num_classes,
        hidden_dim=hidden_dim,
        nheads=nheads,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
    )

    if has_functorch:
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(model)

    losses = ["labels", "boxes", "cardinality"]
    eos_coef = 0.1
    bbox_loss_coef = 5
    giou_loss_coef = 2
    weight_dict = {
        "loss_ce": 1,
        "loss_bbox": bbox_loss_coef,
        "loss_giou": giou_loss_coef,
    }
    # NOTE(review): positional args are presumably (cost_class, cost_bbox,
    # cost_giou), matching DETR's --set_cost_* CLI defaults -- confirm against
    # HungarianMatcher's signature.
    matcher = models.HungarianMatcher(1, 5, 2)
    criterion = models.SetCriterion(
        num_classes=num_classes,
        matcher=matcher,
        weight_dict=weight_dict,
        eos_coef=eos_coef,
        losses=losses,
    )

    model = model.to(device)
    criterion = criterion.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand(N, 3, 800, 1200, device=device)
    labels = []
    for _ in range(N):
        n_targets: int = int(torch.randint(5, 10, size=()).item())
        label = torch.randint(5, 10, size=(n_targets,), device=device)
        raw = torch.randint(100, 800, size=(n_targets, 4), device=device)
        # Order each coordinate pair so column 0 <= column 2 and column 1 <= column 3.
        # Do NOT use a tuple swap on tensor elements
        # (``b[t, 0], b[t, 2] = b[t, 2], b[t, 0]``). Indexing returns views, so
        # the second write reads the already-overwritten value and both entries
        # end up equal. The vectorised form also avoids a per-element host sync
        # on CUDA.
        boxes = torch.stack(
            (
                torch.minimum(raw[:, 0], raw[:, 2]),
                torch.minimum(raw[:, 1], raw[:, 3]),
                torch.maximum(raw[:, 0], raw[:, 2]),
                torch.maximum(raw[:, 1], raw[:, 3]),
            ),
            dim=1,
        )
        labels.append({"labels": label, "boxes": boxes.float()})

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(out, labels)
        coefs = criterion.weight_dict
        # Weighted sum over the loss terms that have a coefficient; others are ignored.
        final_loss = cast(
            Tensor,
            sum(value * coefs[key] for key, value in loss.items() if key in coefs),
        )
        return final_loss

    return forward, params
141