# mypy: allow-untyped-defs

import torch
import torch.ao.nn.intrinsic
import torch.ao.nn.intrinsic.qat
import torch.ao.nn.quantized as nnq


__all__ = ["BNReLU2d", "BNReLU3d"]


class BNReLU2d(nnq.BatchNorm2d):
    r"""
    A BNReLU2d module is a fused module of BatchNorm2d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm2d`.

    Attributes:
        Same as torch.ao.nn.quantized.BatchNorm2d

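    Example (an illustrative sketch, assuming a per-tensor ``quint8`` quantized
    input of shape ``(N, C, H, W)``; the concrete scale/zero_point values are
    arbitrary)::

        >>> m = BNReLU2d(3)
        >>> x = torch.quantize_per_tensor(
        ...     torch.randn(1, 3, 8, 8), scale=0.1, zero_point=0, dtype=torch.quint8
        ... )
        >>> out = m(x)  # fused BatchNorm2d + ReLU applied to the quantized input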
    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU2d

    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
        super().__init__(
            num_features, eps=eps, momentum=momentum, device=device, dtype=dtype
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        return torch.ops.quantized.batch_norm2d_relu(
            input,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            self.eps,
            self.scale,
            self.zero_point,
        )

    def _get_name(self):
        return "QuantizedBNReLU2d"

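    # Rough eager-mode conversion sketch (illustrative only; the qconfig/backend
    # choice below is an assumption, not the only option). A float model whose
    # BatchNorm2d + ReLU pair has been fused into torch.ao.nn.intrinsic.BNReLU2d
    # can be lowered to this module via the standard prepare/convert flow:
    #
    #     float_model.qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
    #     prepared = torch.ao.quantization.prepare(float_model)
    #     ...  # run representative calibration data through `prepared`
    #     quantized = torch.ao.quantization.convert(prepared)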
    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        # TODO: Add qat support for BNReLU2d
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, bn_relu, output_scale, output_zero_point):
        return super().from_reference(bn_relu[0], output_scale, output_zero_point)


class BNReLU3d(nnq.BatchNorm3d):
    r"""
    A BNReLU3d module is a fused module of BatchNorm3d and ReLU

    We adopt the same interface as :class:`torch.ao.nn.quantized.BatchNorm3d`.

    Attributes:
        Same as torch.ao.nn.quantized.BatchNorm3d

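    Example (an illustrative sketch, assuming a per-tensor ``quint8`` quantized
    input of shape ``(N, C, D, H, W)``; the concrete scale/zero_point values are
    arbitrary)::

        >>> m = BNReLU3d(3)
        >>> x = torch.quantize_per_tensor(
        ...     torch.randn(1, 3, 4, 8, 8), scale=0.1, zero_point=0, dtype=torch.quint8
        ... )
        >>> out = m(x)  # fused BatchNorm3d + ReLU applied to the quantized input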
    """
    _FLOAT_MODULE = torch.ao.nn.intrinsic.BNReLU3d

    def __init__(self, num_features, eps=1e-5, momentum=0.1, device=None, dtype=None):
        super().__init__(
            num_features, eps=eps, momentum=momentum, device=device, dtype=dtype
        )

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        return torch.ops.quantized.batch_norm3d_relu(
            input,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            self.eps,
            self.scale,
            self.zero_point,
        )

    def _get_name(self):
        return "QuantizedBNReLU3d"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        # TODO: Add qat support for BNReLU3d
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, bn_relu, output_scale, output_zero_point):
        return super().from_reference(bn_relu[0], output_scale, output_zero_point)