1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: mobile"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport io 4*da0073e9SAndroid Build Coastguard Workerfrom itertools import product 5*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport torch 8*da0073e9SAndroid Build Coastguard Workerimport torch.utils.bundled_inputs 9*da0073e9SAndroid Build Coastguard Workerfrom torch.jit.mobile import _load_for_lite_interpreter 10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase 11*da0073e9SAndroid Build Coastguard Worker 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = Path(__file__).resolve().parents[1] 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Workerclass TestLiteScriptModule(TestCase): 17*da0073e9SAndroid Build Coastguard Worker def _save_load_mobile_module(self, script_module: torch.jit.ScriptModule): 18*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO( 19*da0073e9SAndroid Build Coastguard Worker script_module._save_to_buffer_for_lite_interpreter( 20*da0073e9SAndroid Build Coastguard Worker _save_mobile_debug_info=True 21*da0073e9SAndroid Build Coastguard Worker ) 22*da0073e9SAndroid Build Coastguard Worker ) 23*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 24*da0073e9SAndroid Build Coastguard Worker mobile_module = _load_for_lite_interpreter(buffer) 25*da0073e9SAndroid Build Coastguard Worker return mobile_module 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker def _try_fn(self, fn, *args, **kwargs): 28*da0073e9SAndroid Build Coastguard Worker try: 29*da0073e9SAndroid Build Coastguard Worker return fn(*args, **kwargs) 30*da0073e9SAndroid Build Coastguard Worker except Exception as e: 31*da0073e9SAndroid Build Coastguard Worker return e 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker def test_versioned_div_tensor(self): 34*da0073e9SAndroid Build Coastguard Worker def div_tensor_0_3(self, other): 35*da0073e9SAndroid Build Coastguard Worker if self.is_floating_point() or other.is_floating_point(): 36*da0073e9SAndroid Build Coastguard Worker return self.true_divide(other) 37*da0073e9SAndroid Build Coastguard Worker return self.divide(other, rounding_mode="trunc") 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker model_path = ( 40*da0073e9SAndroid Build Coastguard Worker pytorch_test_dir 41*da0073e9SAndroid Build Coastguard Worker / "cpp" 42*da0073e9SAndroid Build Coastguard Worker / "jit" 43*da0073e9SAndroid Build Coastguard Worker / "upgrader_models" 44*da0073e9SAndroid Build Coastguard Worker / "test_versioned_div_tensor_v2.ptl" 45*da0073e9SAndroid Build Coastguard Worker ) 46*da0073e9SAndroid Build Coastguard Worker mobile_module_v2 = _load_for_lite_interpreter(str(model_path)) 47*da0073e9SAndroid Build Coastguard Worker jit_module_v2 = torch.jit.load(str(model_path)) 48*da0073e9SAndroid Build Coastguard Worker current_mobile_module = self._save_load_mobile_module(jit_module_v2) 49*da0073e9SAndroid Build Coastguard Worker vals = (2.0, 3.0, 2, 3) 50*da0073e9SAndroid Build Coastguard Worker for val_a, val_b in product(vals, vals): 51*da0073e9SAndroid Build Coastguard Worker a = torch.tensor((val_a,)) 52*da0073e9SAndroid Build Coastguard Worker b = torch.tensor((val_b,)) 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Worker def _helper(m, fn): 55*da0073e9SAndroid Build Coastguard Worker m_results = self._try_fn(m, a, b) 56*da0073e9SAndroid Build Coastguard Worker fn_result = self._try_fn(fn, a, b) 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Worker if isinstance(m_results, Exception): 59*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(fn_result, Exception)) 60*da0073e9SAndroid Build Coastguard Worker else: 61*da0073e9SAndroid Build Coastguard Worker for result in m_results: 62*da0073e9SAndroid Build Coastguard Worker print("result: ", result) 63*da0073e9SAndroid Build Coastguard Worker print("fn_result: ", fn_result) 64*da0073e9SAndroid Build Coastguard Worker print(result == fn_result) 65*da0073e9SAndroid Build Coastguard Worker self.assertTrue(result.eq(fn_result)) 66*da0073e9SAndroid Build Coastguard Worker # self.assertEqual(result, fn_result) 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker # old operator should produce the same result as applying upgrader of torch.div op 69*da0073e9SAndroid Build Coastguard Worker # _helper(mobile_module_v2, div_tensor_0_3) 70*da0073e9SAndroid Build Coastguard Worker # latest operator should produce the same result as applying torch.div op 71*da0073e9SAndroid Build Coastguard Worker # _helper(current_mobile_module, torch.div) 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 75*da0073e9SAndroid Build Coastguard Worker run_tests() 76