# mypy: allow-untyped-defs
import torch
from torch.nn import (
    BatchNorm1d,
    BatchNorm2d,
    BatchNorm3d,
    Conv1d,
    Conv2d,
    Conv3d,
    Linear,
    ReLU,
)
from torch.nn.utils.parametrize import type_before_parametrizations


__all__ = [
    "ConvReLU1d",
    "ConvReLU2d",
    "ConvReLU3d",
    "LinearReLU",
    "ConvBn1d",
    "ConvBn2d",
    "ConvBnReLU1d",
    "ConvBnReLU2d",
    "ConvBn3d",
    "ConvBnReLU3d",
    "BNReLU2d",
    "BNReLU3d",
    "LinearBn1d",
    "LinearLeakyReLU",
    "LinearTanh",
    "ConvAdd2d",
    "ConvAddReLU2d",
]


def _check_input_types(modules, expected_types):
    r"""Validate that each module in ``modules`` has exactly the matching expected type.

    A module's type is taken with ``type_before_parametrizations``. A module that
    carries parametrizations (whose ``type()`` is a dynamically generated subclass)
    is therefore judged by its original class. Types must match exactly:
    subclasses of the expected classes are rejected.

    Args:
        modules: the modules passed to a fused container's constructor, in order.
        expected_types: the required class for each position of ``modules``.

    Raises:
        AssertionError: if any module's type differs from the expected one.
            The exception is raised explicitly rather than via ``assert`` so the
            check still runs under ``python -O``. ``AssertionError`` is kept as
            the type so existing callers that catch it keep working.
    """
    actual = tuple(type_before_parametrizations(m) for m in modules)
    wanted = tuple(expected_types)
    if actual != wanted:
        got = ", ".join(str(t) for t in actual)
        want = ", ".join(str(t) for t in wanted)
        raise AssertionError(
            f"Incorrect types for input modules: expected ({want}), got ({got})"
        )


# Used for identifying intrinsic modules used in quantization
class _FusedModule(torch.nn.Sequential):
    r"""Marker base class shared by all fused sequential containers in this file."""


class ConvReLU1d(_FusedModule):
    r"""This is a sequential container which calls the Conv1d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose pre-parametrization type is exactly ``Conv1d``.
        relu: module whose pre-parametrization type is exactly ``ReLU``.
    """

    def __init__(self, conv, relu):
        _check_input_types((conv, relu), (Conv1d, ReLU))
        super().__init__(conv, relu)


class ConvReLU2d(_FusedModule):
    r"""This is a sequential container which calls the Conv2d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose pre-parametrization type is exactly ``Conv2d``.
        relu: module whose pre-parametrization type is exactly ``ReLU``.
    """

    def __init__(self, conv, relu):
        _check_input_types((conv, relu), (Conv2d, ReLU))
        super().__init__(conv, relu)


class ConvReLU3d(_FusedModule):
    r"""This is a sequential container which calls the Conv3d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose pre-parametrization type is exactly ``Conv3d``.
        relu: module whose pre-parametrization type is exactly ``ReLU``.
    """

    def __init__(self, conv, relu):
        _check_input_types((conv, relu), (Conv3d, ReLU))
        super().__init__(conv, relu)


class LinearReLU(_FusedModule):
    r"""This is a sequential container which calls the Linear and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        linear: module whose pre-parametrization type is exactly ``Linear``.
        relu: module whose pre-parametrization type is exactly ``ReLU``.
    """

    def __init__(self, linear, relu):
        _check_input_types((linear, relu), (Linear, ReLU))
        super().__init__(linear, relu)


class ConvBn1d(_FusedModule):
    r"""This is a sequential container which calls the Conv 1d and Batch Norm 1d modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose pre-parametrization type is exactly ``Conv1d``.
        bn: module whose pre-parametrization type is exactly ``BatchNorm1d``.
    """

    def __init__(self, conv, bn):
        _check_input_types((conv, bn), (Conv1d, BatchNorm1d))
        super().__init__(conv, bn)


class ConvBn2d(_FusedModule):
    r"""This is a sequential container which calls the Conv 2d and Batch Norm 2d modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose pre-parametrization type is exactly ``Conv2d``.
        bn: module whose pre-parametrization type is exactly ``BatchNorm2d``.
    """

    def __init__(self, conv, bn):
        _check_input_types((conv, bn), (Conv2d, BatchNorm2d))
        super().__init__(conv, bn)


class ConvBnReLU1d(_FusedModule):
    r"""This is a sequential container which calls the Conv 1d, Batch Norm 1d, and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose pre-parametrization type is exactly ``Conv1d``.
        bn: module whose pre-parametrization type is exactly ``BatchNorm1d``.
        relu: module whose pre-parametrization type is exactly ``ReLU``.
    """

    def __init__(self, conv, bn, relu):
        _check_input_types((conv, bn, relu), (Conv1d, BatchNorm1d, ReLU))
        super().__init__(conv, bn, relu)


class ConvBnReLU2d(_FusedModule):
    r"""This is a sequential container which calls the Conv 2d, Batch Norm 2d, and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose pre-parametrization type is exactly ``Conv2d``.
        bn: module whose pre-parametrization type is exactly ``BatchNorm2d``.
        relu: module whose pre-parametrization type is exactly ``ReLU``.
    """

    def __init__(self, conv, bn, relu):
        _check_input_types((conv, bn, relu), (Conv2d, BatchNorm2d, ReLU))
        super().__init__(conv, bn, relu)


class ConvBn3d(_FusedModule):
    r"""This is a sequential container which calls the Conv 3d and Batch Norm 3d modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose pre-parametrization type is exactly ``Conv3d``.
        bn: module whose pre-parametrization type is exactly ``BatchNorm3d``.
    """

    def __init__(self, conv, bn):
        _check_input_types((conv, bn), (Conv3d, BatchNorm3d))
        super().__init__(conv, bn)


class ConvBnReLU3d(_FusedModule):
    r"""This is a sequential container which calls the Conv 3d, Batch Norm 3d, and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        conv: module whose pre-parametrization type is exactly ``Conv3d``.
        bn: module whose pre-parametrization type is exactly ``BatchNorm3d``.
        relu: module whose pre-parametrization type is exactly ``ReLU``.
    """

    def __init__(self, conv, bn, relu):
        _check_input_types((conv, bn, relu), (Conv3d, BatchNorm3d, ReLU))
        super().__init__(conv, bn, relu)


class BNReLU2d(_FusedModule):
    r"""This is a sequential container which calls the BatchNorm 2d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        batch_norm: module whose pre-parametrization type is exactly ``BatchNorm2d``.
        relu: module whose pre-parametrization type is exactly ``ReLU``.
    """

    def __init__(self, batch_norm, relu):
        _check_input_types((batch_norm, relu), (BatchNorm2d, ReLU))
        super().__init__(batch_norm, relu)


class BNReLU3d(_FusedModule):
    r"""This is a sequential container which calls the BatchNorm 3d and ReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        batch_norm: module whose pre-parametrization type is exactly ``BatchNorm3d``.
        relu: module whose pre-parametrization type is exactly ``ReLU``.
    """

    def __init__(self, batch_norm, relu):
        _check_input_types((batch_norm, relu), (BatchNorm3d, ReLU))
        super().__init__(batch_norm, relu)


class LinearBn1d(_FusedModule):
    r"""This is a sequential container which calls the Linear and BatchNorm1d modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        linear: module whose pre-parametrization type is exactly ``Linear``.
        bn: module whose pre-parametrization type is exactly ``BatchNorm1d``.
    """

    def __init__(self, linear, bn):
        _check_input_types((linear, bn), (Linear, BatchNorm1d))
        super().__init__(linear, bn)


class LinearLeakyReLU(_FusedModule):
    r"""This is a sequential container which calls the Linear and LeakyReLU modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        linear: module whose pre-parametrization type is exactly ``Linear``.
        leaky_relu: module whose pre-parametrization type is exactly
            ``torch.nn.LeakyReLU``.
    """

    def __init__(self, linear, leaky_relu):
        # Uses the same parametrization-aware check as the other containers; a
        # plain ``type()`` comparison would reject a parametrized ``Linear``.
        _check_input_types((linear, leaky_relu), (Linear, torch.nn.LeakyReLU))
        super().__init__(linear, leaky_relu)


class LinearTanh(_FusedModule):
    r"""This is a sequential container which calls the Linear and Tanh modules.
    During quantization this will be replaced with the corresponding fused module.

    Args:
        linear: module whose pre-parametrization type is exactly ``Linear``.
        tanh: module whose pre-parametrization type is exactly ``torch.nn.Tanh``.
    """

    def __init__(self, linear, tanh):
        # Uses the same parametrization-aware check as the other containers; a
        # plain ``type()`` comparison would reject a parametrized ``Linear``.
        _check_input_types((linear, tanh), (Linear, torch.nn.Tanh))
        super().__init__(linear, tanh)


class ConvAdd2d(_FusedModule):
    r"""This is a sequential container which calls a Conv2d module followed by an extra add.
    During quantization this will be replaced with the corresponding fused module.

    Only ``conv`` is passed to the underlying ``Sequential`` (index 0); ``add`` is
    kept on the ``add`` attribute. No type validation is performed on the inputs.

    Args:
        conv: module applied to the first forward input.
        add: callable taking two arguments, used to combine ``conv(x1)`` with ``x2``.
    """

    def __init__(self, conv, add):
        super().__init__(conv)
        self.add = add

    def forward(self, x1, x2):
        r"""Return ``add(conv(x1), x2)``."""
        return self.add(self[0](x1), x2)


class ConvAddReLU2d(_FusedModule):
    r"""This is a sequential container which calls a Conv2d module, an add, and a ReLU.
    During quantization this will be replaced with the corresponding fused module.

    Only ``conv`` is passed to the underlying ``Sequential`` (index 0); ``add`` and
    ``relu`` are kept on the ``add`` and ``relu`` attributes. No type validation
    is performed on the inputs.

    Args:
        conv: module applied to the first forward input.
        add: callable taking two arguments, used to combine ``conv(x1)`` with ``x2``.
        relu: callable applied to the result of ``add``.
    """

    def __init__(self, conv, add, relu):
        super().__init__(conv)
        self.add = add
        self.relu = relu

    def forward(self, x1, x2):
        r"""Return ``relu(add(conv(x1), x2))``."""
        return self.relu(self.add(self[0](x1), x2))