# Owner(s): ["oncall: quantization"]

import torch
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.jit_utils import JitTestCase


class TestDeprecatedJitQuantized(JitTestCase):
    @skipIfNoFBGEMM
    def test_rnn_cell_quantized(self):
        d_in, d_hid = 2, 2

        for cell in [
            torch.nn.LSTMCell(d_in, d_hid).float(),
            torch.nn.GRUCell(d_in, d_hid).float(),
            torch.nn.RNNCell(d_in, d_hid).float(),
        ]:
            if isinstance(cell, torch.nn.LSTMCell):
                num_chunks = 4
            elif isinstance(cell, torch.nn.GRUCell):
                num_chunks = 3
            elif isinstance(cell, torch.nn.RNNCell):
                num_chunks = 1

            # Replace parameter values s.t. the range of values is exactly
            # 255, thus we will have 0 quantization error in the quantized
            # GEMM call. This is for testing purposes.
            #
            # Note that the current implementation does not support
            # accumulation values outside of the range representable by a
            # 16-bit integer, instead resulting in a saturated value. We
            # must take care that in our test we do not end up with a dot
            # product that overflows the int16 range, e.g.
            # (255*127+255*127) = 64770. So, we hardcode the test values
            # here and ensure a mix of signedness.
            vals = [
                [100, -155],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155],
            ]
            vals = vals[: d_hid * num_chunks]
            cell.weight_ih = torch.nn.Parameter(
                torch.tensor(vals, dtype=torch.float), requires_grad=False
            )
            cell.weight_hh = torch.nn.Parameter(
                torch.tensor(vals, dtype=torch.float), requires_grad=False
            )

            with self.assertRaisesRegex(
                RuntimeError,
                "quantize_rnn_cell_modules function is no longer supported",
            ):
                cell = torch.jit.quantized.quantize_rnn_cell_modules(cell)

    @skipIfNoFBGEMM
    def test_rnn_quantized(self):
        d_in, d_hid = 2, 2

        for cell in [
            torch.nn.LSTM(d_in, d_hid).float(),
            torch.nn.GRU(d_in, d_hid).float(),
        ]:
            # Replace parameter values s.t. the range of values is exactly
            # 255, thus we will have 0 quantization error in the quantized
            # GEMM call. This is for testing purposes.
            #
            # Note that the current implementation does not support
            # accumulation values outside of the range representable by a
            # 16-bit integer, instead resulting in a saturated value. We
            # must take care that in our test we do not end up with a dot
            # product that overflows the int16 range, e.g.
            # (255*127+255*127) = 64770. So, we hardcode the test values
            # here and ensure a mix of signedness.
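            # Concretely, because each row mixes a positive and a negative
            # value, at most one entry per row lands near the top of the
            # 255-wide quantized range, so a two-element accumulation stays
            # around 255 * 127 = 32385, within the int16 maximum of 32767.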
            vals = [
                [100, -155],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155],
                [-155, 100],
                [-155, 100],
                [100, -155],
            ]
            if isinstance(cell, torch.nn.LSTM):
                num_chunks = 4
            elif isinstance(cell, torch.nn.GRU):
                num_chunks = 3
            vals = vals[: d_hid * num_chunks]
            cell.weight_ih_l0 = torch.nn.Parameter(
                torch.tensor(vals, dtype=torch.float), requires_grad=False
            )
            cell.weight_hh_l0 = torch.nn.Parameter(
                torch.tensor(vals, dtype=torch.float), requires_grad=False
            )

            with self.assertRaisesRegex(
                RuntimeError, "quantize_rnn_modules function is no longer supported"
            ):
                cell_int8 = torch.jit.quantized.quantize_rnn_modules(
                    cell, dtype=torch.int8
                )

            with self.assertRaisesRegex(
                RuntimeError, "quantize_rnn_modules function is no longer supported"
            ):
                cell_fp16 = torch.jit.quantized.quantize_rnn_modules(
                    cell, dtype=torch.float16
                )

    if "fbgemm" in torch.backends.quantized.supported_engines:

        def test_quantization_modules(self):
            K1, N1 = 2, 2

            class FooBar(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear1 = torch.nn.Linear(K1, N1).float()

                def forward(self, x):
                    x = self.linear1(x)
                    return x

            fb = FooBar()
            fb.linear1.weight = torch.nn.Parameter(
                torch.tensor([[-150, 100], [100, -150]], dtype=torch.float),
                requires_grad=False,
            )
            fb.linear1.bias = torch.nn.Parameter(
                torch.zeros_like(fb.linear1.bias), requires_grad=False
            )

            x = (torch.rand(1, K1).float() - 0.5) / 10.0
            value = torch.tensor([[100, -150]], dtype=torch.float)

            y_ref = fb(value)

            with self.assertRaisesRegex(
                RuntimeError, "quantize_linear_modules function is no longer supported"
            ):
                fb_int8 = torch.jit.quantized.quantize_linear_modules(fb)

            with self.assertRaisesRegex(
                RuntimeError, "quantize_linear_modules function is no longer supported"
            ):
                fb_fp16 = torch.jit.quantized.quantize_linear_modules(fb, torch.float16)

    @skipIfNoFBGEMM
    def test_erase_class_tensor_shapes(self):
        class Linear(torch.nn.Module):
            def __init__(self, in_features, out_features):
                super().__init__()
                qweight = torch._empty_affine_quantized(
                    [out_features, in_features],
                    scale=1,
                    zero_point=0,
                    dtype=torch.qint8,
                )
                self._packed_weight = torch.ops.quantized.linear_prepack(qweight)

            @torch.jit.export
            def __getstate__(self):
                return (
                    torch.ops.quantized.linear_unpack(self._packed_weight)[0],
                    self.training,
                )

            def forward(self):
                return self._packed_weight

            @torch.jit.export
            def __setstate__(self, state):
                self._packed_weight = torch.ops.quantized.linear_prepack(state[0])
                self.training = state[1]

            @property
            def weight(self):
                return torch.ops.quantized.linear_unpack(self._packed_weight)[0]

            @weight.setter
            def weight(self, w):
                self._packed_weight = torch.ops.quantized.linear_prepack(w)

        with torch._jit_internal._disable_emit_hooks():
            x = torch.jit.script(Linear(10, 10))
            torch._C._jit_pass_erase_shape_information(x.graph)


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_quantization.py TESTNAME\n\n"
        "instead."
    )