# Owner(s): ["oncall: jit"]

import copy
import io
import os
import sys
import unittest
from typing import Optional

import torch
from torch.testing._internal.common_utils import skipIfTorchDynamo


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    find_library_location,
    IS_FBCODE,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
)
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


@skipIfTorchDynamo("skipping as a precaution")
class TestTorchbind(JitTestCase):
    """Tests for the TorchBind classes and ops registered under ``_TorchScriptTesting``.

    The classes (``torch.classes._TorchScriptTesting.*``) and ops
    (``torch.ops._TorchScriptTesting.*``) come from the ``torchbind_test``
    shared library that ``setUp`` loads before every test.
    """

    def setUp(self):
        """Load the torchbind test library, skipping where that is not portable."""
        # NOTE(review): super().setUp() is not called here; confirm that
        # skipping the base-class setup is intentional.
        if IS_SANDCASTLE or IS_MACOS or IS_FBCODE:
            raise unittest.SkipTest("non-portable load_library call used in test")
        lib_file_path = find_library_location("libtorchbind_test.so")
        if IS_WINDOWS:
            lib_file_path = find_library_location("torchbind_test.dll")
        torch.ops.load_library(str(lib_file_path))

    def test_torchbind(self):
        """Eager and scripted code using TorchBind objects produce the same results."""

        def test_equality(f, cmp_key):
            # Run ``f`` eagerly and through TorchScript, then check that the
            # ``cmp_key`` projections of the two results agree.
            eager_key = cmp_key(f())
            scripted_key = cmp_key(torch.jit.script(f)())
            self.assertEqual(eager_key, scripted_key)

        def f():
            val = torch.classes._TorchScriptTesting._Foo(5, 3)
            val.increment(1)
            return val

        # Eager and scripted runs construct distinct _Foo objects, so compare
        # their observable state via info() rather than by identity.
        test_equality(f, lambda x: x.info())

        # Passing a str where the bound method expects an int must be rejected.
        with self.assertRaisesRegex(RuntimeError, "Expected a value of type 'int'"):
            val = torch.classes._TorchScriptTesting._Foo(5, 3)
            val.increment("foo")

        def f():
            ss = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
            return ss.pop()

        test_equality(f, lambda x: x)

        def f():
            ss1 = torch.classes._TorchScriptTesting._StackString(["asdf", "bruh"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["111", "222"])
            ss1.push(ss2.pop())
            return ss1.pop() + ss2.pop()

        test_equality(f, lambda x: x)

        # test nn module with prepare_scriptable function
        class NonJitableClass:
            def __init__(self, int1, int2):
                self.int1 = int1
                self.int2 = int2

            def return_vals(self):
                return self.int1, self.int2

        class CustomWrapper(torch.nn.Module):
            def __init__(self, foo):
                super().__init__()
                self.foo = foo

            def forward(self) -> None:
                self.foo.increment(1)
                return

            def __prepare_scriptable__(self):
                int1, int2 = self.foo.return_vals()
                foo = torch.classes._TorchScriptTesting._Foo(int1, int2)
                return CustomWrapper(foo)

        foo = CustomWrapper(NonJitableClass(1, 2))
        # __prepare_scriptable__ replaces the plain-Python NonJitableClass
        # attribute with a TorchBind _Foo; the check is that scripting succeeds.
        torch.jit.script(foo)

    def test_torchbind_take_as_arg(self):
        """A scripted function can take a TorchBind object as an argument and return it."""
        global StackString  # see [local resolution in python]
        StackString = torch.classes._TorchScriptTesting._StackString

        def foo(stackstring):
            # type: (StackString)
            stackstring.push("lel")
            return stackstring

        script_input = torch.classes._TorchScriptTesting._StackString([])
        scripted = torch.jit.script(foo)
        script_output = scripted(script_input)
        self.assertEqual(script_output.pop(), "lel")

    def test_torchbind_return_instance(self):
        """A scripted function can construct and return a TorchBind object."""

        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
            return ss

        scripted = torch.jit.script(foo)
        # Ensure we are creating the object and calling __init__
        # rather than calling the __init__wrapper nonsense
        fc = (
            FileCheck()
            .check("prim::CreateObject()")
            .check('prim::CallMethod[name="__init__"]')
        )
        fc.run(str(scripted.graph))
        out = scripted()
        self.assertEqual(out.pop(), "mom")
        self.assertEqual(out.pop(), "hi")

    def test_torchbind_return_instance_from_method(self):
        """``clone`` returns an instance unaffected by later mutation of the original."""

        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["hi", "mom"])
            clone = ss.clone()
            ss.pop()
            return ss, clone

        scripted = torch.jit.script(foo)
        out = scripted()
        self.assertEqual(out[0].pop(), "hi")
        self.assertEqual(out[1].pop(), "mom")
        self.assertEqual(out[1].pop(), "hi")

    def test_torchbind_def_property_getter_setter(self):
        """Properties defined with a getter and setter work eagerly and in script."""

        def foo_getter_setter_full():
            fooGetterSetter = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
            # getX method intentionally adds 2 to x
            old = fooGetterSetter.x
            # setX method intentionally adds 2 to x
            fooGetterSetter.x = old + 4
            new = fooGetterSetter.x
            return old, new

        self.checkScript(foo_getter_setter_full, ())

        def foo_getter_setter_lambda():
            foo = torch.classes._TorchScriptTesting._FooGetterSetterLambda(5)
            old = foo.x
            foo.x = old + 4
            new = foo.x
            return old, new

        self.checkScript(foo_getter_setter_lambda, ())

    def test_torchbind_def_property_just_getter(self):
        """A getter-only property is read-only, both eagerly and in script."""

        def foo_just_getter():
            fooGetterSetter = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
            # getY method intentionally adds 4 to x
            return fooGetterSetter, fooGetterSetter.y

        scripted = torch.jit.script(foo_just_getter)
        out, result = scripted()
        self.assertEqual(result, 10)

        with self.assertRaisesRegex(RuntimeError, "can't set attribute"):
            out.y = 5

        def foo_not_setter():
            fooGetterSetter = torch.classes._TorchScriptTesting._FooGetterSetter(5, 6)
            old = fooGetterSetter.y
            fooGetterSetter.y = old + 4
            # getY method intentionally adds 4 to x
            return fooGetterSetter.y

        # Compilation itself must fail and highlight the offending assignment.
        with self.assertRaisesRegexWithHighlight(
            RuntimeError,
            "Tried to set read-only attribute: y",
            "fooGetterSetter.y = old + 4",
        ):
            torch.jit.script(foo_not_setter)

    def test_torchbind_def_property_readwrite(self):
        """Attribute ``x`` is assignable in script, while read-only ``y`` is rejected."""

        def foo_readwrite():
            fooReadWrite = torch.classes._TorchScriptTesting._FooReadWrite(5, 6)
            old = fooReadWrite.x
            fooReadWrite.x = old + 4
            return fooReadWrite.x, fooReadWrite.y

        self.checkScript(foo_readwrite, ())

        def foo_readwrite_error():
            fooReadWrite = torch.classes._TorchScriptTesting._FooReadWrite(5, 6)
            fooReadWrite.y = 5
            return fooReadWrite

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Tried to set read-only attribute: y", "fooReadWrite.y = 5"
        ):
            torch.jit.script(foo_readwrite_error)

    def test_torchbind_take_instance_as_method_arg(self):
        """A TorchBind method (``merge``) can take another TorchBind instance."""

        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
            ss.merge(ss2)
            return ss

        scripted = torch.jit.script(foo)
        out = scripted()
        self.assertEqual(out.pop(), "hi")
        self.assertEqual(out.pop(), "mom")

    def test_torchbind_return_tuple(self):
        """A TorchBind method returning a tuple is usable from script."""

        def f():
            val = torch.classes._TorchScriptTesting._StackString(["3", "5"])
            return val.return_a_tuple()

        scripted = torch.jit.script(f)
        tup = scripted()
        self.assertEqual(tup, (1337.0, 123))

    def test_torchbind_save_load(self):
        """A scripted function that uses TorchBind objects survives export/import."""

        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            ss2 = torch.classes._TorchScriptTesting._StackString(["hi"])
            ss.merge(ss2)
            return ss

        scripted = torch.jit.script(foo)
        self.getExportImportCopy(scripted)

    def test_torchbind_lambda_method(self):
        """``_StackString.top`` is callable from script (per the test name, a lambda-bound method)."""

        def foo():
            ss = torch.classes._TorchScriptTesting._StackString(["mom"])
            return ss.top()

        scripted = torch.jit.script(foo)
        self.assertEqual(scripted(), "mom")

    def test_torchbind_class_attr_recursive(self):
        """A module whose ``to_ivalue`` rebuilds it around a new _Foo can be scripted."""

        class FooBar(torch.nn.Module):
            def __init__(self, foo_model):
                super().__init__()
                self.foo_mod = foo_model

            def forward(self) -> int:
                return self.foo_mod.info()

            def to_ivalue(self):
                torchbind_model = torch.classes._TorchScriptTesting._Foo(
                    self.foo_mod.info(), 1
                )
                return FooBar(torchbind_model)

        inst = FooBar(torch.classes._TorchScriptTesting._Foo(2, 3))
        scripted = torch.jit.script(inst.to_ivalue())
        self.assertEqual(scripted(), 6)

    def test_torchbind_class_attribute(self):
        """A TorchBind module attribute survives export/import of the scripted module."""

        class FooBar1234(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])

            def forward(self):
                return self.f.top()

        inst = FooBar1234()
        scripted = torch.jit.script(inst)
        eic = self.getExportImportCopy(scripted)
        # After the round trip the stack no longer holds the constructor's
        # ["3", "4"]; the expected values reflect the class's serialization.
        self.assertEqual(eic(), "deserialized")
        for expected in ["deserialized", "was", "i"]:
            self.assertEqual(eic.f.pop(), expected)

    def test_torchbind_getstate(self):
        """Export/import of a TorchBind attribute goes through its ``__getstate__``."""

        class FooBar4321(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return self.f.top()

        inst = FooBar4321()
        scripted = torch.jit.script(inst)
        eic = self.getExportImportCopy(scripted)
        # NB: we expect the values {7, 3, 3, 1} as __getstate__ is defined to
        # return {1, 3, 3, 7}. I tried to make this actually depend on the
        # values at instantiation in the test with some transformation, but
        # because it seems we serialize/deserialize multiple times, that
        # transformation isn't what you would expect it to be.
        self.assertEqual(eic(), 7)
        for expected in [7, 3, 3, 1]:
            self.assertEqual(eic.f.pop(), expected)

    def test_torchbind_deepcopy(self):
        """``copy.deepcopy`` of a scripted module copies its TorchBind attribute."""

        class FooBar4321(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return self.f.top()

        inst = FooBar4321()
        scripted = torch.jit.script(inst)
        copied = copy.deepcopy(scripted)
        self.assertEqual(copied.forward(), 7)
        for expected in [7, 3, 3, 1]:
            self.assertEqual(copied.f.pop(), expected)

    def test_torchbind_python_deepcopy(self):
        """``copy.deepcopy`` of an eager ``nn.Module`` copies its TorchBind attribute."""

        class FooBar4321(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return self.f.top()

        inst = FooBar4321()
        copied = copy.deepcopy(inst)
        self.assertEqual(copied(), 7)
        for expected in [7, 3, 3, 1]:
            self.assertEqual(copied.f.pop(), expected)

    def test_torchbind_tracing(self):
        """A custom op that takes a TorchBind object can be traced."""

        class TryTracing(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

            def forward(self):
                return torch.ops._TorchScriptTesting.take_an_instance(self.f)

        traced = torch.jit.trace(TryTracing(), ())
        self.assertEqual(torch.zeros(4, 4), traced())

    def test_torchbind_pass_wrong_type(self):
        """Passing a Tensor where a TorchBind object is expected raises."""
        with self.assertRaisesRegex(RuntimeError, "but instead found type 'Tensor'"):
            torch.ops._TorchScriptTesting.take_an_instance(torch.rand(3, 4))

    def test_torchbind_tracing_nested(self):
        """Tracing works when the TorchBind object lives on a submodule."""

        class TryTracingNest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.f = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        class TryTracing123(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nest = TryTracingNest()

            def forward(self):
                return torch.ops._TorchScriptTesting.take_an_instance(self.nest.f)

        traced = torch.jit.trace(TryTracing123(), ())
        self.assertEqual(torch.zeros(4, 4), traced())

    def test_torchbind_pickle_serialization(self):
        """``torch.save``/``torch.load`` round-trip a TorchBind object."""
        nt = torch.classes._TorchScriptTesting._PickleTester([3, 4])
        b = io.BytesIO()
        torch.save(nt, b)
        b.seek(0)
        # weights_only=False as trying to load ScriptObject
        nt_loaded = torch.load(b, weights_only=False)
        for exp in [7, 3, 3, 1]:
            self.assertEqual(nt_loaded.pop(), exp)

    def test_torchbind_instantiate_missing_class(self):
        """Instantiating an unregistered class raises a descriptive error."""
        with self.assertRaisesRegex(
            RuntimeError,
            "Tried to instantiate class 'foo.IDontExist', but it does not exist!",
        ):
            torch.classes.foo.IDontExist(3, 4, 5)

    def test_torchbind_optional_explicit_attr(self):
        """A module attribute annotated ``Optional[<TorchBind class>]`` can be scripted."""

        class TorchBindOptionalExplicitAttr(torch.nn.Module):
            foo: Optional[torch.classes._TorchScriptTesting._StackString]

            def __init__(self) -> None:
                super().__init__()
                self.foo = torch.classes._TorchScriptTesting._StackString(["test"])

            def forward(self) -> str:
                foo_obj = self.foo
                if foo_obj is not None:
                    return foo_obj.pop()
                else:
                    return "<None>"

        mod = TorchBindOptionalExplicitAttr()
        torch.jit.script(mod)

    def test_torchbind_no_init(self):
        """Constructing ``_NoInit`` fails with an error mentioning ``torch::init``."""
        with self.assertRaisesRegex(RuntimeError, "torch::init"):
            torch.classes._TorchScriptTesting._NoInit()

    def test_profiler_custom_op(self):
        """The autograd profiler records a call to a TorchBind-taking custom op."""
        inst = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        with torch.autograd.profiler.profile() as prof:
            torch.ops._TorchScriptTesting.take_an_instance(inst)

        found_event = any(
            e.name == "_TorchScriptTesting::take_an_instance"
            for e in prof.function_events
        )
        self.assertTrue(found_event)

    def test_torchbind_getattr(self):
        """``getattr`` with a default returns the default for a missing attribute."""
        foo = torch.classes._TorchScriptTesting._StackString(["test"])
        self.assertEqual(None, getattr(foo, "bar", None))

    def test_torchbind_attr_exception(self):
        """Accessing a missing attribute raises ``AttributeError``."""
        foo = torch.classes._TorchScriptTesting._StackString(["test"])
        with self.assertRaisesRegex(AttributeError, "does not have a field"):
            foo.bar

    def test_lambda_as_constructor(self):
        """``_LambdaInit``'s constructor arguments (including the swap flag) are honoured."""
        obj_no_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, False)
        self.assertEqual(obj_no_swap.diff(), 1)

        obj_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, True)
        self.assertEqual(obj_swap.diff(), -1)

    def test_staticmethod(self):
        """A TorchBind static method is callable and agrees between eager and script."""

        def fn(inp: int) -> int:
            return torch.classes._TorchScriptTesting._StaticMethod.staticMethod(inp)

        self.checkScript(fn, (1,))

    def test_default_args(self):
        """TorchBind constructors/methods with default arguments agree between eager and script."""

        def fn() -> int:
            obj = torch.classes._TorchScriptTesting._DefaultArgs()
            obj.increment(5)
            obj.decrement()
            obj.decrement(2)
            obj.divide()
            obj.scale_add(5)
            obj.scale_add(3, 2)
            obj.divide(3)
            return obj.increment()

        self.checkScript(fn, ())

        def gn() -> int:
            obj = torch.classes._TorchScriptTesting._DefaultArgs(5)
            obj.increment(3)
            obj.increment()
            obj.decrement(2)
            obj.divide()
            obj.scale_add(3)
            obj.scale_add(3, 2)
            obj.divide(2)
            return obj.decrement()

        self.checkScript(gn, ())