1# Taken from https://github.com/pytorch/vision 2# So that we don't need torchvision to be installed 3from collections import OrderedDict 4 5import torch 6from torch import nn 7from torch.jit.annotations import Dict 8from torch.nn import functional as F 9 10 11try: 12 from scipy.optimize import linear_sum_assignment 13 14 scipy_available = True 15except Exception: 16 scipy_available = False 17 18 19def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 20 """3x3 convolution with padding""" 21 return nn.Conv2d( 22 in_planes, 23 out_planes, 24 kernel_size=3, 25 stride=stride, 26 padding=dilation, 27 groups=groups, 28 bias=False, 29 dilation=dilation, 30 ) 31 32 33def conv1x1(in_planes, out_planes, stride=1): 34 """1x1 convolution""" 35 return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 36 37 38class BasicBlock(nn.Module): 39 expansion = 1 40 41 def __init__( 42 self, 43 inplanes, 44 planes, 45 stride=1, 46 downsample=None, 47 groups=1, 48 base_width=64, 49 dilation=1, 50 norm_layer=None, 51 ): 52 super().__init__() 53 if norm_layer is None: 54 norm_layer = nn.BatchNorm2d 55 if groups != 1 or base_width != 64: 56 raise ValueError("BasicBlock only supports groups=1 and base_width=64") 57 if dilation > 1: 58 raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 59 # Both self.conv1 and self.downsample layers downsample the input when stride != 1 60 self.conv1 = conv3x3(inplanes, planes, stride) 61 self.bn1 = norm_layer(planes) 62 self.relu = nn.ReLU(inplace=True) 63 self.conv2 = conv3x3(planes, planes) 64 self.bn2 = norm_layer(planes) 65 self.downsample = downsample 66 self.stride = stride 67 68 def forward(self, x): 69 identity = x 70 71 out = self.conv1(x) 72 out = self.bn1(out) 73 out = self.relu(out) 74 75 out = self.conv2(out) 76 out = self.bn2(out) 77 78 if self.downsample is not None: 79 identity = self.downsample(x) 80 81 out += identity 82 out = self.relu(out) 83 84 return out 85 86 87class Bottleneck(nn.Module): 88 # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 89 # while original implementation places the stride at the first 1x1 convolution(self.conv1) 90 # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 91 # This variant is also known as ResNet V1.5 and improves accuracy according to 92 # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 93 94 expansion = 4 95 96 def __init__( 97 self, 98 inplanes, 99 planes, 100 stride=1, 101 downsample=None, 102 groups=1, 103 base_width=64, 104 dilation=1, 105 norm_layer=None, 106 ): 107 super().__init__() 108 if norm_layer is None: 109 norm_layer = nn.BatchNorm2d 110 width = int(planes * (base_width / 64.0)) * groups 111 # Both self.conv2 and self.downsample layers downsample the input when stride != 1 112 self.conv1 = conv1x1(inplanes, width) 113 self.bn1 = norm_layer(width) 114 self.conv2 = conv3x3(width, width, stride, groups, dilation) 115 self.bn2 = norm_layer(width) 116 self.conv3 = conv1x1(width, planes * self.expansion) 117 self.bn3 = norm_layer(planes * self.expansion) 118 self.relu = nn.ReLU(inplace=True) 119 self.downsample = downsample 120 self.stride = stride 121 122 def forward(self, x): 123 identity = x 124 125 out = self.conv1(x) 126 out = self.bn1(out) 127 out = self.relu(out) 128 129 out = self.conv2(out) 130 out = self.bn2(out) 131 out = self.relu(out) 132 133 out = self.conv3(out) 134 out = self.bn3(out) 135 136 if self.downsample is not None: 137 identity = self.downsample(x) 138 139 out += identity 140 out = self.relu(out) 141 142 return out 143 144 145class ResNet(nn.Module): 146 def __init__( 147 self, 148 block, 149 layers, 150 num_classes=1000, 151 zero_init_residual=False, 152 groups=1, 153 width_per_group=64, 154 replace_stride_with_dilation=None, 155 norm_layer=None, 156 ): 157 super().__init__() 158 if norm_layer is None: 159 norm_layer = nn.BatchNorm2d 160 self._norm_layer = norm_layer 161 162 self.inplanes = 64 163 self.dilation = 1 164 if replace_stride_with_dilation is None: 165 # each element in the tuple indicates if we should replace 166 # the 2x2 stride with a dilated convolution instead 167 replace_stride_with_dilation = [False, False, False] 168 if len(replace_stride_with_dilation) != 3: 169 raise ValueError( 170 "replace_stride_with_dilation should be None " 171 f"or a 3-element tuple, got {replace_stride_with_dilation}" 172 ) 173 self.groups = groups 174 self.base_width = width_per_group 175 self.conv1 = nn.Conv2d( 176 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False 177 ) 178 self.bn1 = norm_layer(self.inplanes) 179 self.relu = nn.ReLU(inplace=True) 180 self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 181 self.layer1 = self._make_layer(block, 64, layers[0]) 182 self.layer2 = self._make_layer( 183 block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] 184 ) 185 self.layer3 = self._make_layer( 186 block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] 187 ) 188 self.layer4 = self._make_layer( 189 block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] 190 ) 191 self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 192 self.fc = nn.Linear(512 * block.expansion, num_classes) 193 194 for m in self.modules(): 195 if isinstance(m, nn.Conv2d): 196 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 197 elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 198 nn.init.constant_(m.weight, 1) 199 nn.init.constant_(m.bias, 0) 200 201 # Zero-initialize the last BN in each residual branch, 202 # so that the residual branch starts with zeros, and each residual block behaves like an identity. 203 # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 204 if zero_init_residual: 205 for m in self.modules(): 206 if isinstance(m, Bottleneck): 207 nn.init.constant_(m.bn3.weight, 0) 208 elif isinstance(m, BasicBlock): 209 nn.init.constant_(m.bn2.weight, 0) 210 211 def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 212 norm_layer = self._norm_layer 213 downsample = None 214 previous_dilation = self.dilation 215 if dilate: 216 self.dilation *= stride 217 stride = 1 218 if stride != 1 or self.inplanes != planes * block.expansion: 219 downsample = nn.Sequential( 220 conv1x1(self.inplanes, planes * block.expansion, stride), 221 norm_layer(planes * block.expansion), 222 ) 223 224 layers = [] 225 layers.append( 226 block( 227 self.inplanes, 228 planes, 229 stride, 230 downsample, 231 self.groups, 232 self.base_width, 233 previous_dilation, 234 norm_layer, 235 ) 236 ) 237 self.inplanes = planes * block.expansion 238 for _ in range(1, blocks): 239 layers.append( 240 block( 241 self.inplanes, 242 planes, 243 groups=self.groups, 244 base_width=self.base_width, 245 dilation=self.dilation, 246 norm_layer=norm_layer, 247 ) 248 ) 249 250 return nn.Sequential(*layers) 251 252 def _forward_impl(self, x): 253 # See note [TorchScript super()] 254 x = self.conv1(x) 255 x = self.bn1(x) 256 x = self.relu(x) 257 x = self.maxpool(x) 258 259 x = self.layer1(x) 260 x = self.layer2(x) 261 x = self.layer3(x) 262 x = self.layer4(x) 263 264 x = self.avgpool(x) 265 x = torch.flatten(x, 1) 266 x = self.fc(x) 267 268 return x 269 270 def forward(self, x): 271 return self._forward_impl(x) 272 273 274def _resnet(arch, block, layers, pretrained, progress, **kwargs): 275 model = ResNet(block, layers, **kwargs) 276 # if pretrained: 277 # state_dict = load_state_dict_from_url(model_urls[arch], 278 # progress=progress) 279 # model.load_state_dict(state_dict) 280 return model 281 282 283def resnet18(pretrained=False, progress=True, **kwargs): 284 r"""ResNet-18 model from 285 `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ 286 Args: 287 pretrained (bool): If True, returns a model pre-trained on ImageNet 288 progress (bool): If True, displays a progress bar of the download to stderr 289 """ 290 return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) 291 292 293def resnet50(pretrained=False, progress=True, **kwargs): 294 r"""ResNet-50 model from 295 `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ 296 Args: 297 pretrained (bool): If True, returns a model pre-trained on ImageNet 298 progress (bool): If True, displays a progress bar of the download to stderr 299 """ 300 return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) 301 302 303class IntermediateLayerGetter(nn.ModuleDict): 304 """ 305 Module wrapper that returns intermediate layers from a model 306 It has a strong assumption that the modules have been registered 307 into the model in the same order as they are used. 308 This means that one should **not** reuse the same nn.Module 309 twice in the forward if you want this to work. 310 Additionally, it is only able to query submodules that are directly 311 assigned to the model. So if `model` is passed, `model.feature1` can 312 be returned, but not `model.feature1.layer2`. 313 Args: 314 model (nn.Module): model on which we will extract the features 315 return_layers (Dict[name, new_name]): a dict containing the names 316 of the modules for which the activations will be returned as 317 the key of the dict, and the value of the dict is the name 318 of the returned activation (which the user can specify). 319 Examples:: 320 >>> m = torchvision.models.resnet18(pretrained=True) 321 >>> # extract layer1 and layer3, giving as names `feat1` and feat2` 322 >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, 323 >>> {'layer1': 'feat1', 'layer3': 'feat2'}) 324 >>> out = new_m(torch.rand(1, 3, 224, 224)) 325 >>> print([(k, v.shape) for k, v in out.items()]) 326 >>> [('feat1', torch.Size([1, 64, 56, 56])), 327 >>> ('feat2', torch.Size([1, 256, 14, 14]))] 328 """ 329 330 _version = 2 331 __annotations__ = { 332 "return_layers": Dict[str, str], 333 } 334 335 def __init__(self, model, return_layers): 336 if not set(return_layers).issubset( 337 [name for name, _ in model.named_children()] 338 ): 339 raise ValueError("return_layers are not present in model") 340 orig_return_layers = return_layers 341 return_layers = {str(k): str(v) for k, v in return_layers.items()} 342 layers = OrderedDict() 343 for name, module in model.named_children(): 344 layers[name] = module 345 if name in return_layers: 346 del return_layers[name] 347 if not return_layers: 348 break 349 350 super().__init__(layers) 351 self.return_layers = orig_return_layers 352 353 def forward(self, x): 354 out = OrderedDict() 355 for name, module in self.items(): 356 x = module(x) 357 if name in self.return_layers: 358 out_name = self.return_layers[name] 359 out[out_name] = x 360 return out 361 362 363class _SimpleSegmentationModel(nn.Module): 364 __constants__ = ["aux_classifier"] 365 366 def __init__(self, backbone, classifier, aux_classifier=None): 367 super().__init__() 368 self.backbone = backbone 369 self.classifier = classifier 370 self.aux_classifier = aux_classifier 371 372 def forward(self, x): 373 input_shape = x.shape[-2:] 374 # contract: features is a dict of tensors 375 features = self.backbone(x) 376 377 result = OrderedDict() 378 x = features["out"] 379 x = self.classifier(x) 380 x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False) 381 result["out"] = x 382 383 if self.aux_classifier is not None: 384 x = features["aux"] 385 x = self.aux_classifier(x) 386 x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False) 387 result["aux"] = x 388 389 return result 390 391 392class FCN(_SimpleSegmentationModel): 393 """ 394 Implements a Fully-Convolutional Network for semantic segmentation. 395 Args: 396 backbone (nn.Module): the network used to compute the features for the model. 397 The backbone should return an OrderedDict[Tensor], with the key being 398 "out" for the last feature map used, and "aux" if an auxiliary classifier 399 is used. 400 classifier (nn.Module): module that takes the "out" element returned from 401 the backbone and returns a dense prediction. 402 aux_classifier (nn.Module, optional): auxiliary classifier used during training 403 """ 404 405 406class FCNHead(nn.Sequential): 407 def __init__(self, in_channels, channels): 408 inter_channels = in_channels // 4 409 layers = [ 410 nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 411 nn.BatchNorm2d(inter_channels), 412 nn.ReLU(), 413 nn.Dropout(0.1), 414 nn.Conv2d(inter_channels, channels, 1), 415 ] 416 417 super().__init__(*layers) 418 419 420def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True): 421 # backbone = resnet.__dict__[backbone_name]( 422 # pretrained=pretrained_backbone, 423 # replace_stride_with_dilation=[False, True, True]) 424 # Hardcoded resnet 50 425 assert backbone_name == "resnet50" 426 backbone = resnet50( 427 pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True] 428 ) 429 430 return_layers = {"layer4": "out"} 431 if aux: 432 return_layers["layer3"] = "aux" 433 backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) 434 435 aux_classifier = None 436 if aux: 437 inplanes = 1024 438 aux_classifier = FCNHead(inplanes, num_classes) 439 440 model_map = { 441 # 'deeplabv3': (DeepLabHead, DeepLabV3), # Not used 442 "fcn": (FCNHead, FCN), 443 } 444 inplanes = 2048 445 classifier = model_map[name][0](inplanes, num_classes) 446 base_model = model_map[name][1] 447 448 model = base_model(backbone, classifier, aux_classifier) 449 return model 450 451 452def _load_model( 453 arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs 454): 455 if pretrained: 456 aux_loss = True 457 model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs) 458 # if pretrained: 459 # arch = arch_type + '_' + backbone + '_coco' 460 # model_url = model_urls[arch] 461 # if model_url is None: 462 # raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) 463 # else: 464 # state_dict = load_state_dict_from_url(model_url, progress=progress) 465 # model.load_state_dict(state_dict) 466 return model 467 468 469def fcn_resnet50( 470 pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs 471): 472 """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone. 473 Args: 474 pretrained (bool): If True, returns a model pre-trained on COCO train2017 which 475 contains the same classes as Pascal VOC 476 progress (bool): If True, displays a progress bar of the download to stderr 477 """ 478 return _load_model( 479 "fcn", "resnet50", pretrained, progress, num_classes, aux_loss, **kwargs 480 ) 481 482 483# Taken from @fmassa example slides and https://github.com/facebookresearch/detr 484class DETR(nn.Module): 485 """ 486 Demo DETR implementation. 487 488 Demo implementation of DETR in minimal number of lines, with the 489 following differences wrt DETR in the paper: 490 * learned positional encoding (instead of sine) 491 * positional encoding is passed at input (instead of attention) 492 * fc bbox predictor (instead of MLP) 493 The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100. 494 Only batch size 1 supported. 495 """ 496 497 def __init__( 498 self, 499 num_classes, 500 hidden_dim=256, 501 nheads=8, 502 num_encoder_layers=6, 503 num_decoder_layers=6, 504 ): 505 super().__init__() 506 507 # create ResNet-50 backbone 508 self.backbone = resnet50() 509 del self.backbone.fc 510 511 # create conversion layer 512 self.conv = nn.Conv2d(2048, hidden_dim, 1) 513 514 # create a default PyTorch transformer 515 self.transformer = nn.Transformer( 516 hidden_dim, nheads, num_encoder_layers, num_decoder_layers 517 ) 518 519 # prediction heads, one extra class for predicting non-empty slots 520 # note that in baseline DETR linear_bbox layer is 3-layer MLP 521 self.linear_class = nn.Linear(hidden_dim, num_classes + 1) 522 self.linear_bbox = nn.Linear(hidden_dim, 4) 523 524 # output positional encodings (object queries) 525 self.query_pos = nn.Parameter(torch.rand(100, hidden_dim)) 526 527 # spatial positional encodings 528 # note that in baseline DETR we use sine positional encodings 529 self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2)) 530 self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2)) 531 532 def forward(self, inputs): 533 # propagate inputs through ResNet-50 up to avg-pool layer 534 x = self.backbone.conv1(inputs) 535 x = self.backbone.bn1(x) 536 x = self.backbone.relu(x) 537 x = self.backbone.maxpool(x) 538 539 x = self.backbone.layer1(x) 540 x = self.backbone.layer2(x) 541 x = self.backbone.layer3(x) 542 x = self.backbone.layer4(x) 543 544 # convert from 2048 to 256 feature planes for the transformer 545 h = self.conv(x) 546 547 # construct positional encodings 548 H, W = h.shape[-2:] 549 pos = ( 550 torch.cat( 551 [ 552 self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1), 553 self.row_embed[:H].unsqueeze(1).repeat(1, W, 1), 554 ], 555 dim=-1, 556 ) 557 .flatten(0, 1) 558 .unsqueeze(1) 559 ) 560 561 # propagate through the transformer 562 # TODO (alband) Why this is not automatically broadcasted? (had to add the repeat) 563 f = pos + 0.1 * h.flatten(2).permute(2, 0, 1) 564 s = self.query_pos.unsqueeze(1) 565 s = s.expand(s.size(0), inputs.size(0), s.size(2)) 566 h = self.transformer(f, s).transpose(0, 1) 567 568 # finally project transformer outputs to class labels and bounding boxes 569 return { 570 "pred_logits": self.linear_class(h), 571 "pred_boxes": self.linear_bbox(h).sigmoid(), 572 } 573 574 575def generalized_box_iou(boxes1, boxes2): 576 """ 577 Generalized IoU from https://giou.stanford.edu/ 578 The boxes should be in [x0, y0, x1, y1] format 579 Returns a [N, M] pairwise matrix, where N = len(boxes1) 580 and M = len(boxes2) 581 """ 582 # degenerate boxes gives inf / nan results 583 # so do an early check 584 assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 585 assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 586 iou, union = box_iou(boxes1, boxes2) 587 588 lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 589 rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 590 591 wh = (rb - lt).clamp(min=0) # [N,M,2] 592 area = wh[:, :, 0] * wh[:, :, 1] 593 594 return iou - (area - union) / area 595 596 597def box_cxcywh_to_xyxy(x): 598 x_c, y_c, w, h = x.unbind(-1) 599 b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] 600 return torch.stack(b, dim=-1) 601 602 603def box_area(boxes): 604 """ 605 Computes the area of a set of bounding boxes, which are specified by its 606 (x1, y1, x2, y2) coordinates. 607 Args: 608 boxes (Tensor[N, 4]): boxes for which the area will be computed. They 609 are expected to be in (x1, y1, x2, y2) format 610 Returns: 611 area (Tensor[N]): area for each box 612 """ 613 return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 614 615 616# modified from torchvision to also return the union 617def box_iou(boxes1, boxes2): 618 area1 = box_area(boxes1) 619 area2 = box_area(boxes2) 620 621 lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 622 rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 623 624 wh = (rb - lt).clamp(min=0) # [N,M,2] 625 inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 626 627 union = area1[:, None] + area2 - inter 628 629 iou = inter / union 630 return iou, union 631 632 633def is_dist_avail_and_initialized(): 634 return False 635 636 637def get_world_size(): 638 if not is_dist_avail_and_initialized(): 639 return 1 640 641 642@torch.no_grad() 643def accuracy(output, target, topk=(1,)): 644 """Computes the precision@k for the specified values of k""" 645 if target.numel() == 0: 646 return [torch.zeros([], device=output.device)] 647 maxk = max(topk) 648 batch_size = target.size(0) 649 650 _, pred = output.topk(maxk, 1, True, True) 651 pred = pred.t() 652 correct = pred.eq(target.view(1, -1).expand_as(pred)) 653 654 res = [] 655 for k in topk: 656 correct_k = correct[:k].view(-1).float().sum(0) 657 res.append(correct_k.mul_(100.0 / batch_size)) 658 return res 659 660 661class SetCriterion(nn.Module): 662 """This class computes the loss for DETR. 663 The process happens in two steps: 664 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 665 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) 666 """ 667 668 def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses): 669 """Create the criterion. 670 Parameters: 671 num_classes: number of object categories, omitting the special no-object category 672 matcher: module able to compute a matching between targets and proposals 673 weight_dict: dict containing as key the names of the losses and as values their relative weight. 674 eos_coef: relative classification weight applied to the no-object category 675 losses: list of all the losses to be applied. See get_loss for list of available losses. 676 """ 677 super().__init__() 678 self.num_classes = num_classes 679 self.matcher = matcher 680 self.weight_dict = weight_dict 681 self.eos_coef = eos_coef 682 self.losses = losses 683 empty_weight = torch.ones(self.num_classes + 1) 684 empty_weight[-1] = self.eos_coef 685 self.register_buffer("empty_weight", empty_weight) 686 687 def loss_labels(self, outputs, targets, indices, num_boxes, log=True): 688 """Classification loss (NLL) 689 targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] 690 """ 691 assert "pred_logits" in outputs 692 src_logits = outputs["pred_logits"] 693 694 idx = self._get_src_permutation_idx(indices) 695 target_classes_o = torch.cat( 696 [t["labels"][J] for t, (_, J) in zip(targets, indices)] 697 ) 698 target_classes = torch.full( 699 src_logits.shape[:2], 700 self.num_classes, 701 dtype=torch.int64, 702 device=src_logits.device, 703 ) 704 target_classes[idx] = target_classes_o 705 706 loss_ce = F.cross_entropy( 707 src_logits.transpose(1, 2), target_classes, self.empty_weight 708 ) 709 losses = {"loss_ce": loss_ce} 710 711 if log: 712 # TODO this should probably be a separate loss, not hacked in this one here 713 losses["class_error"] = 100 - accuracy(src_logits[idx], target_classes_o)[0] 714 return losses 715 716 @torch.no_grad() 717 def loss_cardinality(self, outputs, targets, indices, num_boxes): 718 """Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes 719 This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients 720 """ 721 pred_logits = outputs["pred_logits"] 722 device = pred_logits.device 723 tgt_lengths = torch.as_tensor( 724 [len(v["labels"]) for v in targets], device=device 725 ) 726 # Count the number of predictions that are NOT "no-object" (which is the last class) 727 card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) 728 card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) 729 losses = {"cardinality_error": card_err} 730 return losses 731 732 def loss_boxes(self, outputs, targets, indices, num_boxes): 733 """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss 734 targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] 735 The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size. 736 """ 737 assert "pred_boxes" in outputs 738 idx = self._get_src_permutation_idx(indices) 739 src_boxes = outputs["pred_boxes"][idx] 740 target_boxes = torch.cat( 741 [t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0 742 ) 743 744 loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none") 745 746 losses = {} 747 losses["loss_bbox"] = loss_bbox.sum() / num_boxes 748 749 loss_giou = 1 - torch.diag( 750 generalized_box_iou( 751 box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes) 752 ) 753 ) 754 losses["loss_giou"] = loss_giou.sum() / num_boxes 755 return losses 756 757 def loss_masks(self, outputs, targets, indices, num_boxes): 758 """Compute the losses related to the masks: the focal loss and the dice loss. 759 targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] 760 """ 761 assert "pred_masks" in outputs 762 763 src_idx = self._get_src_permutation_idx(indices) 764 tgt_idx = self._get_tgt_permutation_idx(indices) 765 766 src_masks = outputs["pred_masks"] 767 768 # TODO use valid to mask invalid areas due to padding in loss 769 target_masks, valid = nested_tensor_from_tensor_list( # noqa: F821 770 [t["masks"] for t in targets] 771 ).decompose() 772 target_masks = target_masks.to(src_masks) 773 774 src_masks = src_masks[src_idx] 775 # upsample predictions to the target size 776 src_masks = interpolate( # noqa: F821 777 src_masks[:, None], 778 size=target_masks.shape[-2:], 779 mode="bilinear", 780 align_corners=False, 781 ) 782 src_masks = src_masks[:, 0].flatten(1) 783 784 target_masks = target_masks[tgt_idx].flatten(1) 785 786 losses = { 787 "loss_mask": sigmoid_focal_loss( # noqa: F821 788 src_masks, target_masks, num_boxes 789 ), # noqa: F821 790 "loss_dice": dice_loss(src_masks, target_masks, num_boxes), # noqa: F821 791 } 792 return losses 793 794 def _get_src_permutation_idx(self, indices): 795 # permute predictions following indices 796 batch_idx = torch.cat( 797 [torch.full_like(src, i) for i, (src, _) in enumerate(indices)] 798 ) 799 src_idx = torch.cat([src for (src, _) in indices]) 800 return batch_idx, src_idx 801 802 def _get_tgt_permutation_idx(self, indices): 803 # permute targets following indices 804 batch_idx = torch.cat( 805 [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)] 806 ) 807 tgt_idx = torch.cat([tgt for (_, tgt) in indices]) 808 return batch_idx, tgt_idx 809 810 def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): 811 loss_map = { 812 "labels": self.loss_labels, 813 "cardinality": self.loss_cardinality, 814 "boxes": self.loss_boxes, 815 "masks": self.loss_masks, 816 } 817 assert loss in loss_map, f"do you really want to compute {loss} loss?" 818 return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) 819 820 def forward(self, outputs, targets): 821 """This performs the loss computation. 822 Parameters: 823 outputs: dict of tensors, see the output specification of the model for the format 824 targets: list of dicts, such that len(targets) == batch_size. 825 The expected keys in each dict depends on the losses applied, see each loss' doc 826 """ 827 outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"} 828 829 # Retrieve the matching between the outputs of the last layer and the targets 830 indices = self.matcher(outputs_without_aux, targets) 831 832 # Compute the average number of target boxes across all nodes, for normalization purposes 833 num_boxes = sum(len(t["labels"]) for t in targets) 834 num_boxes = torch.as_tensor( 835 [num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device 836 ) 837 if is_dist_avail_and_initialized(): 838 torch.distributed.all_reduce(num_boxes) 839 num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() 840 841 # Compute all the requested losses 842 losses = {} 843 for loss in self.losses: 844 losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) 845 846 # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 847 if "aux_outputs" in outputs: 848 for i, aux_outputs in enumerate(outputs["aux_outputs"]): 849 indices = self.matcher(aux_outputs, targets) 850 for loss in self.losses: 851 if loss == "masks": 852 # Intermediate masks losses are too costly to compute, we ignore them. 853 continue 854 kwargs = {} 855 if loss == "labels": 856 # Logging is enabled only for the last layer 857 kwargs = {"log": False} 858 l_dict = self.get_loss( 859 loss, aux_outputs, targets, indices, num_boxes, **kwargs 860 ) 861 l_dict = {k + f"_{i}": v for k, v in l_dict.items()} 862 losses.update(l_dict) 863 864 return losses 865 866 867class HungarianMatcher(nn.Module): 868 """This class computes an assignment between the targets and the predictions of the network 869 For efficiency reasons, the targets don't include the no_object. Because of this, in general, 870 there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 871 while the others are un-matched (and thus treated as non-objects). 872 """ 873 874 def __init__( 875 self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1 876 ): 877 """Creates the matcher 878 Params: 879 cost_class: This is the relative weight of the classification error in the matching cost 880 cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 881 cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 882 """ 883 super().__init__() 884 self.cost_class = cost_class 885 self.cost_bbox = cost_bbox 886 self.cost_giou = cost_giou 887 assert ( 888 cost_class != 0 or cost_bbox != 0 or cost_giou != 0 889 ), "all costs cant be 0" 890 891 @torch.no_grad() 892 def forward(self, outputs, targets): 893 """Performs the matching 894 Params: 895 outputs: This is a dict that contains at least these entries: 896 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 897 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 898 targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 899 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 900 objects in the target) containing the class labels 901 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 902 Returns: 903 A list of size batch_size, containing tuples of (index_i, index_j) where: 904 - index_i is the indices of the selected predictions (in order) 905 - index_j is the indices of the corresponding selected targets (in order) 906 For each batch element, it holds: 907 len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 908 """ 909 bs, num_queries = outputs["pred_logits"].shape[:2] 910 911 # We flatten to compute the cost matrices in a batch 912 out_prob = ( 913 outputs["pred_logits"].flatten(0, 1).softmax(-1) 914 ) # [batch_size * num_queries, num_classes] 915 out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] 916 917 # Also concat the target labels and boxes 918 tgt_ids = torch.cat([v["labels"] for v in targets]) 919 tgt_bbox = torch.cat([v["boxes"] for v in targets]) 920 921 # Compute the classification cost. Contrary to the loss, we don't use the NLL, 922 # but approximate it in 1 - proba[target class]. 923 # The 1 is a constant that doesn't change the matching, it can be ommitted. 924 cost_class = -out_prob[:, tgt_ids] 925 926 # Compute the L1 cost between boxes 927 cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 928 929 # Compute the giou cost betwen boxes 930 cost_giou = -generalized_box_iou( 931 box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox) 932 ) 933 934 # Final cost matrix 935 C = ( 936 self.cost_bbox * cost_bbox 937 + self.cost_class * cost_class 938 + self.cost_giou * cost_giou 939 ) 940 C = C.view(bs, num_queries, -1).cpu() 941 942 sizes = [len(v["boxes"]) for v in targets] 943 if not scipy_available: 944 raise RuntimeError( 945 "The 'detr' model requires scipy to run. Please make sure you have it installed" 946 " if you enable the 'detr' model." 947 ) 948 indices = [ 949 linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1)) 950 ] 951 return [ 952 ( 953 torch.as_tensor(i, dtype=torch.int64), 954 torch.as_tensor(j, dtype=torch.int64), 955 ) 956 for i, j in indices 957 ] 958