xref: /aosp_15_r20/external/pytorch/test/jit/test_torchbind.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import copy
4import io
5import os
6import sys
7import unittest
8from typing import Optional
9
10import torch
11from torch.testing._internal.common_utils import skipIfTorchDynamo
12
13
# Make the helper files in test/ importable. This module lives in test/jit/,
# so the test/ root is two directory levels above its resolved location.
_jit_test_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(_jit_test_dir)
sys.path.append(pytorch_test_dir)
17from torch.testing import FileCheck
18from torch.testing._internal.common_utils import (
19    find_library_location,
20    IS_FBCODE,
21    IS_MACOS,
22    IS_SANDCASTLE,
23    IS_WINDOWS,
24)
25from torch.testing._internal.jit_utils import JitTestCase
26
27
if __name__ == "__main__":
    # This file is only meant to be collected through test/test_jit.py;
    # running it directly is rejected with the supported invocation.
    _usage_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage_message)
34
35
@skipIfTorchDynamo("skipping as a precaution")
class TestTorchbind(JitTestCase):
    """TorchScript tests for TorchBind custom classes and custom ops.

    Every test uses ``torch.classes._TorchScriptTesting`` classes and/or
    ``torch.ops._TorchScriptTesting`` ops, which are only available after
    ``setUp`` has loaded the ``torchbind_test`` shared library.
    """

    def setUp(self):
        """Load the torchbind_test library, or skip where that is unsupported."""
        if IS_SANDCASTLE or IS_MACOS or IS_FBCODE:
            raise unittest.SkipTest("non-portable load_library call used in test")
        # Only look up the artifact name used on the current platform.
        lib_name = "torchbind_test.dll" if IS_WINDOWS else "libtorchbind_test.so"
        lib_file_path = find_library_location(lib_name)
        torch.ops.load_library(str(lib_file_path))

    def test_torchbind(self):
        """Eager and scripted TorchBind code must produce matching results."""

        def assert_eager_matches_script(f, cmp_key):
            # Run ``f`` eagerly and through torch.jit.script, then check that
            # the projections (via ``cmp_key``) of both results are equal.
            eager_result = f()
            scripted_result = torch.jit.script(f)()
            self.assertEqual(cmp_key(eager_result), cmp_key(scripted_result))

        def f():
            val = torch.classes._TorchScriptTesting._Foo(5, 3)
            val.increment(1)
            return val

        # The eager and scripted runs return separate _Foo objects, so compare
        # the values they report through info() rather than the objects.
        assert_eager_matches_script(f, lambda x: x.info())

        with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
            val = torch.classes._TorchScriptTesting._Foo(5, 3)
            val.increment("foo")

        def f():
            ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
            return ss.pop()

        assert_eager_matches_script(f, lambda x: x)

        def f():
            ss1 = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
            ss1.push(ss2.pop())
            return ss1.pop() + ss2.pop()

        assert_eager_matches_script(f, lambda x: x)

        # test nn module with prepare_scriptable function
        class NonJitableClass:
            def __init__(self, int1, int2):
                self.int1 = int1
                self.int2 = int2

            def return_vals(self):
                return self.int1, self.int2

        class CustomWrapper(torch.nn.Module):
            def __init__(self, foo):
                super().__init__()
                self.foo = foo

            def forward(self) -> None:
                self.foo.increment(1)
                return

            def __prepare_scriptable__(self):
                # Swap the plain-Python holder for a scriptable _Foo before
                # torch.jit.script compiles the module.
                int1, int2 = self.foo.return_vals()
                foo = torch.classes._TorchScriptTesting._Foo(int1, int2)
                return CustomWrapper(foo)

        foo = CustomWrapper(NonJitableClass(1, 2))
        jit_foo = torch.jit.script(foo)
        # Exercise the scripted forward, which calls into the _Foo produced
        # by __prepare_scriptable__.
        jit_foo()

    def test_torchbind_take_as_arg(self):
        """A scripted function can take a TorchBind object as an argument."""
        global StackString  # see [local resolution in python]
        StackString = torch.classes._TorchScriptTesting._StackString

        def foo(stackstring):
            # type: (StackString)
            stackstring.push("lel")
            return stackstring

        script_input = torch.classes._TorchScriptTesting._StackString([])
        scripted = torch.jit.script(foo)
        script_output = scripted(script_input)
        self.assertEqual(script_output.pop(), "lel")

    def test_torchbind_return_instance(self):
        """A scripted function can construct and return a TorchBind object."""

        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
            return ss

        scripted = torch.jit.script(foo)
        # Ensure we are creating the object and calling __init__
        # rather than calling the __init__wrapper nonsense
        fc = (
            FileCheck()
            .check("prim::CreateObject()")
            .check('prim::CallMethod[name="__init__"]')
        )
        fc.run(str(scripted.graph))
        out = scripted()
        self.assertEqual(out.pop(), "mom")
        self.assertEqual(out.pop(), "hi")

    def test_torchbind_return_instance_from_method(self):
        """A TorchBind method (clone) can return a new TorchBind object."""

        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
            clone = ss.clone()
            ss.pop()
            return ss, clone

        scripted = torch.jit.script(foo)
        out = scripted()
        self.assertEqual(out[0].pop(), "hi")
        self.assertEqual(out[1].pop(), "mom")
        self.assertEqual(out[1].pop(), "hi")

    def test_torchbind_def_property_getter_setter(self):
        """Properties with both a getter and a setter work in TorchScript."""

        def foo_getter_setter_full():
            fooGetterSetter = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
            # getX method intentionally adds 2 to x
            old = fooGetterSetter.x
            # setX method intentionally adds 2 to x
            fooGetterSetter.x = old + 4
            new = fooGetterSetter.x
            return old, new

        self.checkScript(foo_getter_setter_full, ())

        def foo_getter_setter_lambda():
            foo = torch.classes._TorchScriptTesting._FooGetterSetterLambda(5)
            old = foo.x
            foo.x = old + 4
            new = foo.x
            return old, new

        self.checkScript(foo_getter_setter_lambda, ())

    def test_torchbind_def_property_just_getter(self):
        """A getter-only property is readable but cannot be assigned."""

        def foo_just_getter():
            fooGetterSetter = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
            # getY method intentionally adds 4 to y (6 + 4 == 10 below)
            return fooGetterSetter, fooGetterSetter.y

        scripted = torch.jit.script(foo_just_getter)
        out, result = scripted()
        self.assertEqual(result, 10)

        with self.assertRaisesRegex(RuntimeError, "can't set attribute"):
            out.y = 5

        def foo_not_setter():
            fooGetterSetter = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
            old = fooGetterSetter.y
            fooGetterSetter.y = old + 4
            # getY method intentionally adds 4 to y
            return fooGetterSetter.y

        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            "Tried to set read-only attribute: y",
            "fooGetterSetter.y = old + 4",
        ):
            scripted = torch.jit.script(foo_not_setter)

    def test_torchbind_def_property_readwrite(self):
        """Readwrite fields are assignable; readonly fields are rejected."""

        def foo_readwrite():
            fooReadWrite = torch.classes._TorchScriptTesting._FooReadWrite(5, 6)
            old = fooReadWrite.x
            fooReadWrite.x = old + 4
            return fooReadWrite.x, fooReadWrite.y

        self.checkScript(foo_readwrite, ())

        def foo_readwrite_error():
            fooReadWrite = torch.classes._TorchScriptTesting._FooReadWrite(5, 6)
            fooReadWrite.y = 5
            return fooReadWrite

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Tried to set read-only attribute: y", "fooReadWrite.y = 5"
        ):
            scripted = torch.jit.script(foo_readwrite_error)

    def test_torchbind_take_instance_as_method_arg(self):
        """A TorchBind method can take another TorchBind object as argument."""

        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
            ss.merge(ss2)
            return ss

        scripted = torch.jit.script(foo)
        out = scripted()
        self.assertEqual(out.pop(), "hi")
        self.assertEqual(out.pop(), "mom")

    def test_torchbind_return_tuple(self):
        """A TorchBind method returning a tuple is usable from TorchScript."""

        def f():
            val = torch.classes._TorchScriptTesting._StackString(["3", "5"])
            return val.return_a_tuple()

        scripted = torch.jit.script(f)
        tup = scripted()
        self.assertEqual(tup, (1337.0, 123))

    def test_torchbind_save_load(self):
        """A scripted function using TorchBind objects survives export/import."""

        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
            ss.merge(ss2)
            return ss

        scripted = torch.jit.script(foo)
        self.getExportImportCopy(scripted)

    def test_torchbind_lambda_method(self):
        """A TorchBind method bound from a lambda is callable in TorchScript."""

        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            return ss.top()

        scripted = torch.jit.script(foo)
        self.assertEqual(scripted(), "mom")

    def test_torchbind_class_attr_recursive(self):
        """A module attribute holding a TorchBind object is scripted correctly."""

        class FooBar(torch.nn.Module):
            def __init__(self, foo_model):
                super().__init__()
                self.foo_mod = foo_model

            def forward(self) -> int:
                return self.foo_mod.info()

            def to_ivalue(self):
                torchbind_model = torch.classes._TorchScriptTesting._Foo(
                    self.foo_mod.info(), 1
                )
                return FooBar(torchbind_model)

        inst = FooBar(torch.classes._TorchScriptTesting._Foo(2, 3))
        scripted = torch.jit.script(inst.to_ivalue())
        self.assertEqual(scripted(), 6)

    def test_torchbind_class_attribute(self):
        """A TorchBind module attribute round-trips through export/import."""

        class FooBar1234(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])

            def forward(self):
                return self.f.top()

        inst = FooBar1234()
        scripted = torch.jit.script(inst)
        eic = self.getExportImportCopy(scripted)
        # Use assertEqual rather than bare assert so the checks still run
        # under ``python -O``.
        self.assertEqual(eic(), "deserialized")
        for expected in ["deserialized", "was", "i"]:
            self.assertEqual(eic.f.pop(), expected)

    def test_torchbind_getstate(self):
        """__getstate__/__setstate__ are used when serializing the attribute."""

        class FooBar4321(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return self.f.top()

        inst = FooBar4321()
        scripted = torch.jit.script(inst)
        eic = self.getExportImportCopy(scripted)
        # NB: we expect the values {7, 3, 3, 1} as __getstate__ is defined to
        # return {1, 3, 3, 7}. I tried to make this actually depend on the
        # values at instantiation in the test with some transformation, but
        # because it seems we serialize/deserialize multiple times, that
        # transformation isn't what you would expect it to be.
        self.assertEqual(eic(), 7)
        for expected in [7, 3, 3, 1]:
            self.assertEqual(eic.f.pop(), expected)

    def test_torchbind_deepcopy(self):
        """copy.deepcopy of a scripted module deep-copies its TorchBind attr."""

        class FooBar4321(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return self.f.top()

        inst = FooBar4321()
        scripted = torch.jit.script(inst)
        copied = copy.deepcopy(scripted)
        self.assertEqual(copied.forward(), 7)
        for expected in [7, 3, 3, 1]:
            self.assertEqual(copied.f.pop(), expected)

    def test_torchbind_python_deepcopy(self):
        """copy.deepcopy of an eager module deep-copies its TorchBind attr."""

        class FooBar4321(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return self.f.top()

        inst = FooBar4321()
        copied = copy.deepcopy(inst)
        self.assertEqual(copied(), 7)
        for expected in [7, 3, 3, 1]:
            self.assertEqual(copied.f.pop(), expected)

    def test_torchbind_tracing(self):
        """Tracing a module that passes a TorchBind attr to a custom op."""

        class TryTracing(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return torch.ops._TorchScriptTesting.take_an_instance(self.f)

        traced = torch.jit.trace(TryTracing(), ())
        self.assertEqual(torch.zeros(4, 4), traced())

    def test_torchbind_pass_wrong_type(self):
        """Passing a Tensor where a TorchBind object is expected fails."""
        with self.assertRaisesRegex(RuntimeError, "but instead found type 'Tensor'"):
            torch.ops._TorchScriptTesting.take_an_instance(torch.rand(3, 4))

    def test_torchbind_tracing_nested(self):
        """Tracing works when the TorchBind attr lives on a submodule."""

        class TryTracingNest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        class TryTracing123(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nest = TryTracingNest()

            def forward(self):
                return torch.ops._TorchScriptTesting.take_an_instance(self.nest.f)

        traced = torch.jit.trace(TryTracing123(), ())
        self.assertEqual(torch.zeros(4, 4), traced())

    def test_torchbind_pickle_serialization(self):
        """A TorchBind object round-trips through torch.save/torch.load."""
        nt = torch.classes._TorchScriptTesting._PickleTester([3, 4])
        b = io.BytesIO()
        torch.save(nt, b)
        b.seek(0)
        # weights_only=False as trying to load ScriptObject
        nt_loaded = torch.load(b, weights_only=False)
        for exp in [7, 3, 3, 1]:
            self.assertEqual(nt_loaded.pop(), exp)

    def test_torchbind_instantiate_missing_class(self):
        """Instantiating an unregistered class raises a clear error."""
        with self.assertRaisesRegex(
            RuntimeError,
            "Tried to instantiate class 'foo.IDontExist', but it does not exist!",
        ):
            torch.classes.foo.IDontExist(3, 4, 5)

    def test_torchbind_optional_explicit_attr(self):
        """An explicitly Optional-annotated TorchBind attribute is scriptable."""

        class TorchBindOptionalExplicitAttr(torch.nn.Module):
            foo: Optional[torch.classes._TorchScriptTesting._StackString]

            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.classes._TorchScriptTesting._StackString(["test"])

            def forward(self) -> str:
                foo_obj = self.foo
                if foo_obj is not None:
                    return foo_obj.pop()
                else:
                    return "<None>"

        mod = TorchBindOptionalExplicitAttr()
        scripted = torch.jit.script(mod)
        # The attribute is set, so the non-None branch must be taken.
        self.assertEqual(scripted(), "test")

    def test_torchbind_no_init(self):
        """Constructing a class registered without torch::init fails."""
        with self.assertRaisesRegex(RuntimeError, "torch::init"):
            x = torch.classes._TorchScriptTesting._NoInit()

    def test_profiler_custom_op(self):
        """Custom op calls taking a TorchBind object show up in the profiler."""
        inst = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        with torch.autograd.profiler.profile() as prof:
            torch.ops._TorchScriptTesting.take_an_instance(inst)

        found_event = any(
            e.name == "_TorchScriptTesting::take_an_instance"
            for e in prof.function_events
        )
        self.assertTrue(found_event)

    def test_torchbind_getattr(self):
        """getattr with a default works for a missing TorchBind field."""
        foo = torch.classes._TorchScriptTesting._StackString(["test"])
        self.assertEqual(None, getattr(foo, "bar", None))

    def test_torchbind_attr_exception(self):
        """Accessing a missing TorchBind field raises AttributeError."""
        foo = torch.classes._TorchScriptTesting._StackString(["test"])
        with self.assertRaisesRegex(AttributeError, "does not have a field"):
            foo.bar

    def test_lambda_as_constructor(self):
        """A class whose constructor is registered from a lambda."""
        obj_no_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, False)
        self.assertEqual(obj_no_swap.diff(), 1)

        obj_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, True)
        self.assertEqual(obj_swap.diff(), -1)

    def test_staticmethod(self):
        """TorchBind static methods are callable from TorchScript."""

        def fn(inp: int) -> int:
            return torch.classes._TorchScriptTesting._StaticMethod.staticMethod(inp)

        self.checkScript(fn, (1,))

    def test_default_args(self):
        """TorchBind constructors/methods honor registered default arguments."""

        def fn() -> int:
            obj = torch.classes._TorchScriptTesting._DefaultArgs()
            obj.increment(5)
            obj.decrement()
            obj.decrement(2)
            obj.divide()
            obj.scale_add(5)
            obj.scale_add(3, 2)
            obj.divide(3)
            return obj.increment()

        self.checkScript(fn, ())

        def gn() -> int:
            obj = torch.classes._TorchScriptTesting._DefaultArgs(5)
            obj.increment(3)
            obj.increment()
            obj.decrement(2)
            obj.divide()
            obj.scale_add(3)
            obj.scale_add(3, 2)
            obj.divide(2)
            return obj.decrement()

        self.checkScript(gn, ())
473