# Taken from https://github.com/pytorch/vision
# So that we don't need torchvision to be installed
from collections import OrderedDict

import torch
from torch import nn
from torch.jit.annotations import Dict
from torch.nn import functional as F


try:
    from scipy.optimize import linear_sum_assignment

    scipy_available = True
except Exception:
    scipy_available = False
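# NOTE: scipy is optional; it is only needed by HungarianMatcher.forward, which
# uses linear_sum_assignment to solve the bipartite matching between predictions
# and targets.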


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1),
    # according to "Deep Residual Learning for Image Recognition" (https://arxiv.org/abs/1512.03385).
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(
        self,
        block,
        layers,
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm_layer=None,
    ):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
        )
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
        )
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

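    # Builds one ResNet stage: `blocks` residual blocks of type `block` with
    # `planes` base channels. Only the first block in a stage may change the
    # spatial resolution (stride) or channel count; the rest preserve the shape.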
    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    # if pretrained:
    #     state_dict = load_state_dict_from_url(model_urls[arch],
    #                                           progress=progress)
    #     model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
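    Example (an illustrative sketch, not part of the upstream docstring; assumes
    a 224x224 RGB input and the default 1000 classes)::
        >>> model = resnet18()
        >>> out = model(torch.rand(1, 3, 224, 224))
        >>> out.shape  # torch.Size([1, 1000])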
289    """
290    return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
291
292
293def resnet50(pretrained=False, progress=True, **kwargs):
294    r"""ResNet-50 model from
295    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
296    Args:
297        pretrained (bool): If True, returns a model pre-trained on ImageNet
298        progress (bool): If True, displays a progress bar of the download to stderr
299    """
300    return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
301
302
303class IntermediateLayerGetter(nn.ModuleDict):
304    """
305    Module wrapper that returns intermediate layers from a model
306    It has a strong assumption that the modules have been registered
307    into the model in the same order as they are used.
308    This means that one should **not** reuse the same nn.Module
309    twice in the forward if you want this to work.
310    Additionally, it is only able to query submodules that are directly
311    assigned to the model. So if `model` is passed, `model.feature1` can
312    be returned, but not `model.feature1.layer2`.
313    Args:
314        model (nn.Module): model on which we will extract the features
315        return_layers (Dict[name, new_name]): a dict containing the names
316            of the modules for which the activations will be returned as
317            the key of the dict, and the value of the dict is the name
318            of the returned activation (which the user can specify).
319    Examples::
320        >>> m = torchvision.models.resnet18(pretrained=True)
321        >>> # extract layer1 and layer3, giving as names `feat1` and feat2`
322        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
323        >>>     {'layer1': 'feat1', 'layer3': 'feat2'})
324        >>> out = new_m(torch.rand(1, 3, 224, 224))
325        >>> print([(k, v.shape) for k, v in out.items()])
326        >>>     [('feat1', torch.Size([1, 64, 56, 56])),
327        >>>      ('feat2', torch.Size([1, 256, 14, 14]))]
328    """

    _version = 2
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model, return_layers):
        if not set(return_layers).issubset(
            [name for name, _ in model.named_children()]
        ):
            raise ValueError("return_layers are not present in model")
        orig_return_layers = return_layers
        return_layers = {str(k): str(v) for k, v in return_layers.items()}
        layers = OrderedDict()
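        # Copy the children in registration order and stop once every requested
        # layer has been collected, so trailing (unused) modules are dropped.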
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            if not return_layers:
                break

        super().__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out


class _SimpleSegmentationModel(nn.Module):
    __constants__ = ["aux_classifier"]

    def __init__(self, backbone, classifier, aux_classifier=None):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.aux_classifier = aux_classifier

    def forward(self, x):
        input_shape = x.shape[-2:]
        # contract: features is a dict of tensors
        features = self.backbone(x)

        result = OrderedDict()
        x = features["out"]
        x = self.classifier(x)
        x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
        result["out"] = x

        if self.aux_classifier is not None:
            x = features["aux"]
            x = self.aux_classifier(x)
            x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False)
            result["aux"] = x

        return result


class FCN(_SimpleSegmentationModel):
    """
    Implements a Fully-Convolutional Network for semantic segmentation.
    Args:
        backbone (nn.Module): the network used to compute the features for the model.
            The backbone should return an OrderedDict[Tensor], with the key being
            "out" for the last feature map used, and "aux" if an auxiliary classifier
            is used.
        classifier (nn.Module): module that takes the "out" element returned from
            the backbone and returns a dense prediction.
        aux_classifier (nn.Module, optional): auxiliary classifier used during training
    """


class FCNHead(nn.Sequential):
    def __init__(self, in_channels, channels):
        inter_channels = in_channels // 4
        layers = [
            nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(inter_channels),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(inter_channels, channels, 1),
        ]

        super().__init__(*layers)


def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
    # backbone = resnet.__dict__[backbone_name](
    #     pretrained=pretrained_backbone,
    #     replace_stride_with_dilation=[False, True, True])
    # Hardcoded resnet 50
    assert backbone_name == "resnet50"
    backbone = resnet50(
        pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True]
    )

    return_layers = {"layer4": "out"}
    if aux:
        return_layers["layer3"] = "aux"
    backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

    aux_classifier = None
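    # ResNet-50 channel counts: layer3 outputs 1024 channels (fed to the aux head
    # when enabled) and layer4 outputs 2048 channels (fed to the main classifier).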
    if aux:
        inplanes = 1024
        aux_classifier = FCNHead(inplanes, num_classes)

    model_map = {
        # 'deeplabv3': (DeepLabHead, DeepLabV3), # Not used
        "fcn": (FCNHead, FCN),
    }
    inplanes = 2048
    classifier = model_map[name][0](inplanes, num_classes)
    base_model = model_map[name][1]

    model = base_model(backbone, classifier, aux_classifier)
    return model


def _load_model(
    arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs
):
    if pretrained:
        aux_loss = True
    model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)
    # if pretrained:
    #     arch = arch_type + '_' + backbone + '_coco'
    #     model_url = model_urls[arch]
    #     if model_url is None:
    #         raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
    #     else:
    #         state_dict = load_state_dict_from_url(model_url, progress=progress)
    #         model.load_state_dict(state_dict)
    return model


def fcn_resnet50(
    pretrained=False, progress=True, num_classes=21, aux_loss=None, **kwargs
):
    """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone.
    Args:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017 which
            contains the same classes as Pascal VOC
        progress (bool): If True, displays a progress bar of the download to stderr
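    Example (an illustrative sketch, not part of the upstream docstring; assumes
    a 224x224 RGB input and the default 21 classes)::
        >>> model = fcn_resnet50(aux_loss=False)
        >>> out = model(torch.rand(1, 3, 224, 224))
        >>> out["out"].shape  # torch.Size([1, 21, 224, 224])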
477    """
478    return _load_model(
479        "fcn", "resnet50", pretrained, progress, num_classes, aux_loss, **kwargs
480    )
481
482
483# Taken from @fmassa example slides and https://github.com/facebookresearch/detr
484class DETR(nn.Module):
485    """
486    Demo DETR implementation.
487
488    Demo implementation of DETR in minimal number of lines, with the
489    following differences wrt DETR in the paper:
490    * learned positional encoding (instead of sine)
491    * positional encoding is passed at input (instead of attention)
492    * fc bbox predictor (instead of MLP)
493    The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.
494    Only batch size 1 supported.
495    """

    def __init__(
        self,
        num_classes,
        hidden_dim=256,
        nheads=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
    ):
        super().__init__()

        # create ResNet-50 backbone
        self.backbone = resnet50()
        del self.backbone.fc

        # create conversion layer
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers
        )

        # prediction heads, one extra class for predicting non-empty slots
        # note that in baseline DETR linear_bbox layer is 3-layer MLP
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)

        # output positional encodings (object queries)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))

        # spatial positional encodings
        # note that in baseline DETR we use sine positional encodings
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        # propagate inputs through ResNet-50 up to avg-pool layer
        x = self.backbone.conv1(inputs)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)

        # convert from 2048 to 256 feature planes for the transformer
        h = self.conv(x)

        # construct positional encodings
        H, W = h.shape[-2:]
        pos = (
            torch.cat(
                [
                    self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
                    self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
                ],
                dim=-1,
            )
            .flatten(0, 1)
            .unsqueeze(1)
        )
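        # pos has shape [H*W, 1, hidden_dim]: one learned positional embedding per
        # spatial location of the backbone feature map.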

        # propagate through the transformer
        # TODO (alband) Why this is not automatically broadcasted? (had to add the repeat)
        f = pos + 0.1 * h.flatten(2).permute(2, 0, 1)
        s = self.query_pos.unsqueeze(1)
        s = s.expand(s.size(0), inputs.size(0), s.size(2))
        h = self.transformer(f, s).transpose(0, 1)

        # finally project transformer outputs to class labels and bounding boxes
        return {
            "pred_logits": self.linear_class(h),
            "pred_boxes": self.linear_bbox(h).sigmoid(),
        }


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/
    The boxes should be in [x0, y0, x1, y1] format
    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # degenerate boxes give inf / nan results,
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - (area - union) / area


def box_cxcywh_to_xyxy(x):
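    # Convert boxes from (center_x, center_y, width, height) to (x0, y0, x1, y1) corners.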
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_area(boxes):
    """
    Computes the area of a set of bounding boxes, which are specified by their
    (x1, y1, x2, y2) coordinates.
    Args:
        boxes (Tensor[N, 4]): boxes for which the area will be computed. They
            are expected to be in (x1, y1, x2, y2) format
    Returns:
        area (Tensor[N]): area for each box
    """
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union


def is_dist_avail_and_initialized():
    return False


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    # Unreachable in this standalone copy (distributed is never initialized), but
    # return the real world size instead of implicitly falling through to None.
    return torch.distributed.get_world_size()


@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class SetCriterion(nn.Module):
    """This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute Hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        """Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
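        # Cross-entropy class weights: real classes keep weight 1, while the extra
        # "no-object" class (index num_classes) is down-weighted by eos_coef.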
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)

    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert "pred_logits" in outputs
        src_logits = outputs["pred_logits"]

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [t["labels"][J] for t, (_, J) in zip(targets, indices)]
        )
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device,
        )
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(
            src_logits.transpose(1, 2), target_classes, self.empty_weight
        )
        losses = {"loss_ce": loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses["class_error"] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
718        """Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
719        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
720        """
        pred_logits = outputs["pred_logits"]
        device = pred_logits.device
        tgt_lengths = torch.as_tensor(
            [len(v["labels"]) for v in targets], device=device
        )
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {"cardinality_error": card_err}
        return losses

    def loss_boxes(self, outputs, targets, indices, num_boxes):
733        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
734        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
735        The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size.
736        """
737        assert "pred_boxes" in outputs
738        idx = self._get_src_permutation_idx(indices)
739        src_boxes = outputs["pred_boxes"][idx]
740        target_boxes = torch.cat(
741            [t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0
742        )
743
744        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")
745
746        losses = {}
747        losses["loss_bbox"] = loss_bbox.sum() / num_boxes
748
749        loss_giou = 1 - torch.diag(
750            generalized_box_iou(
751                box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes)
752            )
753        )
754        losses["loss_giou"] = loss_giou.sum() / num_boxes
755        return losses
756
757    def loss_masks(self, outputs, targets, indices, num_boxes):
758        """Compute the losses related to the masks: the focal loss and the dice loss.
759        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
760        """
761        assert "pred_masks" in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)

        src_masks = outputs["pred_masks"]

        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list(  # noqa: F821
            [t["masks"] for t in targets]
        ).decompose()
        target_masks = target_masks.to(src_masks)

        src_masks = src_masks[src_idx]
        # upsample predictions to the target size
        src_masks = interpolate(  # noqa: F821
            src_masks[:, None],
            size=target_masks.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        src_masks = src_masks[:, 0].flatten(1)

        target_masks = target_masks[tgt_idx].flatten(1)

        losses = {
            "loss_mask": sigmoid_focal_loss(  # noqa: F821
                src_masks, target_masks, num_boxes
            ),  # noqa: F821
            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),  # noqa: F821
        }
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
        )
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat(
            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]
        )
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            "labels": self.loss_labels,
            "cardinality": self.loss_cardinality,
            "boxes": self.loss_boxes,
            "masks": self.loss_masks,
        }
        assert loss in loss_map, f"do you really want to compute {loss} loss?"
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    def forward(self, outputs, targets):
821        """This performs the loss computation.
822        Parameters:
823             outputs: dict of tensors, see the output specification of the model for the format
824             targets: list of dicts, such that len(targets) == batch_size.
825                      The expected keys in each dict depends on the losses applied, see each loss' doc
826        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor(
            [num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if "aux_outputs" in outputs:
            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == "masks":
853                        continue
854                    kwargs = {}
855                    if loss == "labels":
856                        # Logging is enabled only for the last layer
857                        kwargs = {"log": False}
858                    l_dict = self.get_loss(
859                        loss, aux_outputs, targets, indices, num_boxes, **kwargs
860                    )
861                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
862                    losses.update(l_dict)
863
864        return losses
865
866

class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network
    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(
        self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1
    ):
        """Creates the matcher
        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
888            cost_class != 0 or cost_bbox != 0 or cost_giou != 0
889        ), "all costs cant be 0"
890

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Performs the matching
        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = (
            outputs["pred_logits"].flatten(0, 1).softmax(-1)
        )  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes"] for v in targets])
        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it by 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, so it can be omitted.
        cost_class = -out_prob[:, tgt_ids]

        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

        # Compute the GIoU cost between boxes
        cost_giou = -generalized_box_iou(
            box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)
        )

        # Final cost matrix
        C = (
            self.cost_bbox * cost_bbox
            + self.cost_class * cost_class
            + self.cost_giou * cost_giou
        )
        C = C.view(bs, num_queries, -1).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        if not scipy_available:
            raise RuntimeError(
                "The 'detr' model requires scipy to run. Please make sure you have it installed"
                " if you enable the 'detr' model."
            )
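        # Solve one linear-sum-assignment problem per image: the cost matrix is split
        # along the target dimension so each image is matched only against its own boxes.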
        indices = [
            linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))
        ]
        return [
            (
                torch.as_tensor(i, dtype=torch.int64),
                torch.as_tensor(j, dtype=torch.int64),
            )
            for i, j in indices
        ]
