from typing import cast

import torchvision_models as models
from utils import check_for_functorch, extract_weights, GetterReturnType, load_weights

import torch
from torch import Tensor


has_functorch = check_for_functorch()


def get_resnet18(device: torch.device) -> GetterReturnType:
    """Build a ResNet-18 cross-entropy loss closure over a fixed random batch.

    Args:
        device: Device on which the model, inputs and labels are placed.

    Returns:
        ``(forward, params)``. ``params`` are the weights obtained from
        ``extract_weights(model)``. ``forward(*new_params)`` loads
        ``new_params`` into the model with ``load_weights`` and returns the
        loss of the model on the fixed inputs.
    """
    N = 32
    model = models.resnet18(pretrained=False)

    if has_functorch:
        # Imported lazily so this module still imports without functorch.
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(model)

    criterion = torch.nn.CrossEntropyLoss()
    model.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand([N, 3, 224, 224], device=device)
    # torch.rand is in [0, 1), so labels are integer class ids in [0, 10).
    labels = torch.rand(N, device=device).mul(10).long()

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(out, labels)
        return loss

    return forward, params


def get_fcn_resnet(device: torch.device) -> GetterReturnType:
    """Build an FCN-ResNet50 MSE loss closure over a fixed random batch.

    Args:
        device: Device on which the model, inputs and targets are placed.

    Returns:
        ``(forward, params)``. ``params`` are the weights obtained from
        ``extract_weights(model)``. ``forward(*new_params)`` loads
        ``new_params`` into the model, runs it on the fixed inputs and returns
        the MSE between the model's ``"out"`` entry and random dense targets.
    """
    N = 8
    criterion = torch.nn.MSELoss()
    model = models.fcn_resnet50(pretrained=False, pretrained_backbone=False)

    if has_functorch:
        # Imported lazily so this module still imports without functorch.
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(model)
        # disable dropout for consistency checking
        model.eval()

    model.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand([N, 3, 480, 480], device=device)
    # Given model has 21 classes
    labels = torch.rand([N, 21, 480, 480], device=device)

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)["out"]

        loss = criterion(out, labels)
        return loss

    return forward, params


def get_detr(device: torch.device) -> GetterReturnType:
    """Build a DETR set-prediction loss closure over a fixed random batch.

    Each of the ``N`` images gets a random number of targets, drawn from
    ``[5, 10)``. Every target has a class label in ``[5, 10)`` and a box whose
    four integer coordinates are drawn from ``[100, 800)``. The coordinates are
    ordered so that column 0 <= column 2 and column 1 <= column 3.

    Args:
        device: Device on which the model, criterion, inputs and targets live.

    Returns:
        ``(forward, params)``. ``params`` are the weights obtained from
        ``extract_weights(model)``. ``forward(*new_params)`` loads
        ``new_params`` into the model and runs the ``SetCriterion``. It then
        returns the sum of each loss term whose key appears in
        ``criterion.weight_dict``, scaled by that weight.
    """
    # All values below are from CLI defaults in https://github.com/facebookresearch/detr
    N = 2
    num_classes = 91
    hidden_dim = 256
    nheads = 8
    num_encoder_layers = 6
    num_decoder_layers = 6

    model = models.DETR(
        num_classes=num_classes,
        hidden_dim=hidden_dim,
        nheads=nheads,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
    )

    if has_functorch:
        # Imported lazily so this module still imports without functorch.
        from functorch.experimental import replace_all_batch_norm_modules_

        replace_all_batch_norm_modules_(model)

    losses = ["labels", "boxes", "cardinality"]
    eos_coef = 0.1
    bbox_loss_coef = 5
    giou_loss_coef = 2
    weight_dict = {
        "loss_ce": 1,
        "loss_bbox": bbox_loss_coef,
        "loss_giou": giou_loss_coef,
    }
    # NOTE(review): positional matching costs, presumably (class, bbox, giou)
    # as in the upstream DETR CLI defaults -- confirm against
    # models.HungarianMatcher's signature.
    matcher = models.HungarianMatcher(1, 5, 2)
    criterion = models.SetCriterion(
        num_classes=num_classes,
        matcher=matcher,
        weight_dict=weight_dict,
        eos_coef=eos_coef,
        losses=losses,
    )

    model = model.to(device)
    criterion = criterion.to(device)
    params, names = extract_weights(model)

    inputs = torch.rand(N, 3, 800, 1200, device=device)
    labels = []
    for _ in range(N):
        n_targets: int = int(torch.randint(5, 10, size=()).item())
        label = torch.randint(5, 10, size=(n_targets,), device=device)
        boxes = torch.randint(100, 800, size=(n_targets, 4), device=device)
        # Order each box so that column 0 <= column 2 and column 1 <= column 3.
        # This is done vectorized on purpose. A per-element tuple swap like
        # `boxes[t, 0], boxes[t, 2] = boxes[t, 2], boxes[t, 0]` does NOT swap:
        # integer indexing yields views into `boxes`, so the second assignment
        # reads the already-overwritten value. It also forced a device->host
        # sync for every `if boxes[t, 0] > boxes[t, 2]` comparison.
        lo = torch.minimum(boxes[:, :2], boxes[:, 2:])
        hi = torch.maximum(boxes[:, :2], boxes[:, 2:])
        boxes = torch.cat([lo, hi], dim=1)
        labels.append({"labels": label, "boxes": boxes.float()})

    def forward(*new_params: Tensor) -> Tensor:
        load_weights(model, names, new_params)
        out = model(inputs)

        loss = criterion(out, labels)
        weight_dict = criterion.weight_dict
        # Weighted sum of the loss terms that have a configured weight;
        # terms without a weight are ignored.
        final_loss = cast(
            Tensor,
            sum(loss[k] * weight_dict[k] for k in loss.keys() if k in weight_dict),
        )
        return final_loss

    return forward, params