1# Owner(s): ["oncall: mobile"] 2 3import io 4from itertools import product 5from pathlib import Path 6 7import torch 8import torch.utils.bundled_inputs 9from torch.jit.mobile import _load_for_lite_interpreter 10from torch.testing._internal.common_utils import run_tests, TestCase 11 12 13pytorch_test_dir = Path(__file__).resolve().parents[1] 14 15 16class TestLiteScriptModule(TestCase): 17 def _save_load_mobile_module(self, script_module: torch.jit.ScriptModule): 18 buffer = io.BytesIO( 19 script_module._save_to_buffer_for_lite_interpreter( 20 _save_mobile_debug_info=True 21 ) 22 ) 23 buffer.seek(0) 24 mobile_module = _load_for_lite_interpreter(buffer) 25 return mobile_module 26 27 def _try_fn(self, fn, *args, **kwargs): 28 try: 29 return fn(*args, **kwargs) 30 except Exception as e: 31 return e 32 33 def test_versioned_div_tensor(self): 34 def div_tensor_0_3(self, other): 35 if self.is_floating_point() or other.is_floating_point(): 36 return self.true_divide(other) 37 return self.divide(other, rounding_mode="trunc") 38 39 model_path = ( 40 pytorch_test_dir 41 / "cpp" 42 / "jit" 43 / "upgrader_models" 44 / "test_versioned_div_tensor_v2.ptl" 45 ) 46 mobile_module_v2 = _load_for_lite_interpreter(str(model_path)) 47 jit_module_v2 = torch.jit.load(str(model_path)) 48 current_mobile_module = self._save_load_mobile_module(jit_module_v2) 49 vals = (2.0, 3.0, 2, 3) 50 for val_a, val_b in product(vals, vals): 51 a = torch.tensor((val_a,)) 52 b = torch.tensor((val_b,)) 53 54 def _helper(m, fn): 55 m_results = self._try_fn(m, a, b) 56 fn_result = self._try_fn(fn, a, b) 57 58 if isinstance(m_results, Exception): 59 self.assertTrue(isinstance(fn_result, Exception)) 60 else: 61 for result in m_results: 62 print("result: ", result) 63 print("fn_result: ", fn_result) 64 print(result == fn_result) 65 self.assertTrue(result.eq(fn_result)) 66 # self.assertEqual(result, fn_result) 67 68 # old operator should produce the same result as applying upgrader of torch.div op 69 # _helper(mobile_module_v2, div_tensor_0_3) 70 # latest operator should produce the same result as applying torch.div op 71 # _helper(current_mobile_module, torch.div) 72 73 74if __name__ == "__main__": 75 run_tests() 76