xref: /aosp_15_r20/external/pytorch/test/mobile/test_upgraders.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: mobile"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport io
4*da0073e9SAndroid Build Coastguard Workerfrom itertools import product
5*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Workerimport torch.utils.bundled_inputs
9*da0073e9SAndroid Build Coastguard Workerfrom torch.jit.mobile import _load_for_lite_interpreter
10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, TestCase
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker
# Directory two levels above this file (the parent of the directory that
# contains it); model fixture paths in the tests are built relative to it.
pytorch_test_dir = Path(__file__).resolve().parents[1]
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Workerclass TestLiteScriptModule(TestCase):
17*da0073e9SAndroid Build Coastguard Worker    def _save_load_mobile_module(self, script_module: torch.jit.ScriptModule):
18*da0073e9SAndroid Build Coastguard Worker        buffer = io.BytesIO(
19*da0073e9SAndroid Build Coastguard Worker            script_module._save_to_buffer_for_lite_interpreter(
20*da0073e9SAndroid Build Coastguard Worker                _save_mobile_debug_info=True
21*da0073e9SAndroid Build Coastguard Worker            )
22*da0073e9SAndroid Build Coastguard Worker        )
23*da0073e9SAndroid Build Coastguard Worker        buffer.seek(0)
24*da0073e9SAndroid Build Coastguard Worker        mobile_module = _load_for_lite_interpreter(buffer)
25*da0073e9SAndroid Build Coastguard Worker        return mobile_module
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker    def _try_fn(self, fn, *args, **kwargs):
28*da0073e9SAndroid Build Coastguard Worker        try:
29*da0073e9SAndroid Build Coastguard Worker            return fn(*args, **kwargs)
30*da0073e9SAndroid Build Coastguard Worker        except Exception as e:
31*da0073e9SAndroid Build Coastguard Worker            return e
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker    def test_versioned_div_tensor(self):
34*da0073e9SAndroid Build Coastguard Worker        def div_tensor_0_3(self, other):
35*da0073e9SAndroid Build Coastguard Worker            if self.is_floating_point() or other.is_floating_point():
36*da0073e9SAndroid Build Coastguard Worker                return self.true_divide(other)
37*da0073e9SAndroid Build Coastguard Worker            return self.divide(other, rounding_mode="trunc")
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker        model_path = (
40*da0073e9SAndroid Build Coastguard Worker            pytorch_test_dir
41*da0073e9SAndroid Build Coastguard Worker            / "cpp"
42*da0073e9SAndroid Build Coastguard Worker            / "jit"
43*da0073e9SAndroid Build Coastguard Worker            / "upgrader_models"
44*da0073e9SAndroid Build Coastguard Worker            / "test_versioned_div_tensor_v2.ptl"
45*da0073e9SAndroid Build Coastguard Worker        )
46*da0073e9SAndroid Build Coastguard Worker        mobile_module_v2 = _load_for_lite_interpreter(str(model_path))
47*da0073e9SAndroid Build Coastguard Worker        jit_module_v2 = torch.jit.load(str(model_path))
48*da0073e9SAndroid Build Coastguard Worker        current_mobile_module = self._save_load_mobile_module(jit_module_v2)
49*da0073e9SAndroid Build Coastguard Worker        vals = (2.0, 3.0, 2, 3)
50*da0073e9SAndroid Build Coastguard Worker        for val_a, val_b in product(vals, vals):
51*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor((val_a,))
52*da0073e9SAndroid Build Coastguard Worker            b = torch.tensor((val_b,))
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker            def _helper(m, fn):
55*da0073e9SAndroid Build Coastguard Worker                m_results = self._try_fn(m, a, b)
56*da0073e9SAndroid Build Coastguard Worker                fn_result = self._try_fn(fn, a, b)
57*da0073e9SAndroid Build Coastguard Worker
58*da0073e9SAndroid Build Coastguard Worker                if isinstance(m_results, Exception):
59*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(isinstance(fn_result, Exception))
60*da0073e9SAndroid Build Coastguard Worker                else:
61*da0073e9SAndroid Build Coastguard Worker                    for result in m_results:
62*da0073e9SAndroid Build Coastguard Worker                        print("result: ", result)
63*da0073e9SAndroid Build Coastguard Worker                        print("fn_result: ", fn_result)
64*da0073e9SAndroid Build Coastguard Worker                        print(result == fn_result)
65*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(result.eq(fn_result))
66*da0073e9SAndroid Build Coastguard Worker                        # self.assertEqual(result, fn_result)
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker            # old operator should produce the same result as applying upgrader of torch.div op
69*da0073e9SAndroid Build Coastguard Worker            # _helper(mobile_module_v2, div_tensor_0_3)
70*da0073e9SAndroid Build Coastguard Worker            # latest operator should produce the same result as applying torch.div op
71*da0073e9SAndroid Build Coastguard Worker            # _helper(current_mobile_module, torch.div)
72*da0073e9SAndroid Build Coastguard Worker
73*da0073e9SAndroid Build Coastguard Worker
# Script entry point: hand control to the PyTorch test runner.
if __name__ == "__main__":
    run_tests()
76