# Owner(s): ["oncall: jit"] import os import sys from typing import Any, List import torch from torch.testing._internal.common_utils import skipIfTorchDynamo from torch.testing._internal.jit_utils import JitTestCase, make_global # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) if __name__ == "__main__": raise RuntimeError( "This test file is not meant to be run directly, use:\n\n" "\tpython test/test_jit.py TESTNAME\n\n" "instead." ) class TestWith(JitTestCase): """ A suite of tests for with statements. """ def test_with_as(self): """ Check that with statements that use the 'as' keyword to bind expressions to targets work as expected. """ @torch.jit.script class Context: """ This class implements a basic context manager interface for use in the unit tests. Unlike Context, the stateful part of this class is a Tensor that is mutated in-place so that modifications made in the JIT interpreter are visible outside of it. """ def __init__(self, start: int): self.count = torch.tensor([start], dtype=torch.double) def __enter__(self): self.count.add_(0.3) return self.count def __exit__(self, type: Any, value: Any, tb: Any) -> bool: self.count.sub_(0.3) return True make_global(Context) def test_basic(x: torch.Tensor) -> torch.Tensor: """Basic test with one with-statement.""" c = Context(1) with c as mult: y = x + mult y *= c.count return y def test_pass(x: torch.Tensor) -> torch.Tensor: """ Test with a pass statement inside a with-statement. Although the body of the with is empty, __enter__ and __exit__ should still be called. """ c = Context(1) with c as mult: pass x *= c.count return x def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test that returning early from inside a with-statement works as expected. """ with c as mult: y = x + mult return y x = y + y return x def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test that conditionally returning early from inside a with-statement works as expected. """ with c as mult: y = x + mult if mult > 0: return y x = y + y return x def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: """ Test that breaking early from inside a with-statement works as expected. """ with c as mult: for a in l: if a == 0: break x += a * mult return x def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: """ Test that using continue inside a with-statement works as expected. """ with c as mult: for a in l: if a == 0: continue x += a * mult return x def test_serial(x: torch.Tensor) -> torch.Tensor: """ Test two with-statements in a row. """ c = Context(1) with c as mult: y = x + mult with c as mult: y *= mult return y def test_nested(x: torch.Tensor) -> torch.Tensor: """ Test nested with-statements. """ c = Context(1) with c as m: with c as n: y = x + n y *= m return y def test_combined(x: torch.Tensor) -> torch.Tensor: """ Test a with-statement with multiple with items. """ c = Context(1) d = Context(2) with c as m, d as n: y = x + (m + n) return y test_input = torch.randn(2, 2) test_context = Context(2) test_list = [2, 0, 1, 3, 0, 2] self.checkScript(test_basic, (test_input,)) self.checkScript(test_pass, (test_input,)) self.checkScript(test_early_return, (test_input, test_context)) self.checkScript(test_break, (test_input, test_context, test_list)) self.checkScript(test_continue, (test_input, test_context, test_list)) self.assertEqual(test_context.count, 2) self.checkScript(test_serial, (test_input,)) self.checkScript(test_nested, (test_input,)) self.checkScript(test_combined, (test_input,)) def test_with_no_as(self): """ Check that with statements that do not use the 'as' keyword to bind expressions to targets work as expected. """ @torch.jit.script class Context: """ This class implements a basic context manager interface for use in the unit tests. Unlike Context, the stateful part of this class is a Tensor that is mutated in-place so that modifications made in the JIT interpreter are visible outside of it. """ def __init__(self, start: int): self.count = torch.tensor([start], dtype=torch.double) def __enter__(self): self.count.add_(0.3) return self.count def __exit__(self, type: Any, value: Any, tb: Any): self.count.sub_(0.3) make_global(Context) def test_basic(x: torch.Tensor) -> torch.Tensor: """Basic test with one with-statement.""" c = Context(1) with c: y = x + c.count y *= c.count return y def test_pass(x: torch.Tensor) -> torch.Tensor: """ Test with a pass statement inside a with-statement. Although the body of the with is empty, __enter__ and __exit__ should still be called. """ c = Context(1) with c: pass x *= c.count return x def test_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test that returning early from inside a with-statement works as expected. """ with c: y = x + c.count return y x = y + y return x def test_conditional_early_return(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test that conditionally returning early from inside a with-statement works as expected. """ with c: y = x + c.count if c.count > 0: return y x = y + y return x def test_break(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: """ Test that breaking early from inside a with-statement works as expected. """ with c: for a in l: if a == 0: break x += a * c.count return x def test_continue(x: torch.Tensor, c: Context, l: List[int]) -> torch.Tensor: """ Test that using continue inside a with-statement works as expected. """ with c: for a in l: if a == 0: continue x += a * c.count return x def test_serial(x: torch.Tensor) -> torch.Tensor: """ Test two with-statements in a row. """ c = Context(1) with c: y = x + c.count with c: y *= c.count return y def test_nested(x: torch.Tensor) -> torch.Tensor: """ Test nested with-statements. """ c = Context(1) with c: with c: y = x + c.count y *= c.count return y def test_combined(x: torch.Tensor) -> torch.Tensor: """ Test a with-statement with multiple with items. """ c = Context(1) d = Context(2) with c, d: y = x + (c.count + d.count) return y test_input = torch.randn(2, 2) test_context = Context(2) test_list = [2, 0, 1, 3, 0, 2] self.checkScript(test_basic, (test_input,)) self.checkScript(test_pass, (test_input,)) self.checkScript(test_early_return, (test_input, test_context)) self.checkScript(test_break, (test_input, test_context, test_list)) self.checkScript(test_continue, (test_input, test_context, test_list)) self.assertEqual(test_context.count, 2) self.checkScript(test_serial, (test_input,)) self.checkScript(test_nested, (test_input,)) self.checkScript(test_combined, (test_input,)) def test_with_exceptions(self): """ Check that exceptions thrown in the bodies of with-statements are handled correctly. """ @torch.jit.script class Context: """ This class implements a basic context manager interface for use in the unit tests. Unlike Context, the stateful part of this class is a Tensor that is mutated in-place so that modifications made in the JIT interpreter are visible outside of it. """ def __init__(self, start: int): self.count = torch.tensor([start], dtype=torch.double) def __enter__(self): self.count.add_(0.3) return self.count def __exit__(self, type: Any, value: Any, tb: Any): self.count.sub_(0.3) make_global(Context) @torch.jit.script def method_that_raises() -> torch.Tensor: raise Exception("raised exception") # noqa: TRY002 @torch.jit.script def test_exception(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test the case in which an exception is thrown while executing the body of a with-statement. """ with c as _: x += method_that_raises() return x @torch.jit.script def test_exception_nested(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test the case in which an exception is thrown while executing the body of a nested with-statement. """ with c as _: with c as _: x += method_that_raises() return x @torch.jit.script def with_that_raises(c: Context) -> torch.Tensor: a = torch.tensor([1]) with c as _: a += method_that_raises() return a @torch.jit.script def test_exception_fn_call(x: torch.Tensor, c: Context) -> torch.Tensor: """ Test the case in which an exception is thrown while there are active with-statements in two different frames. """ with c as _: x += with_that_raises(c) return x c = Context(1) # checkScript and checkScriptRaisesRegex cannot be used because the string frontend will # not compile class types (of which Context, the context manager being used for this test # is one). with self.assertRaisesRegexWithHighlight( Exception, r"raised exception", 'raise Exception("raised exception' ): test_exception(torch.randn(2), c) self.assertEqual(c.count, 1) with self.assertRaisesRegexWithHighlight( Exception, r"raised exception", 'raise Exception("raised exception' ): test_exception_nested(torch.randn(2), c) self.assertEqual(c.count, 1) with self.assertRaisesRegexWithHighlight( Exception, r"raised exception", 'raise Exception("raised exception' ): test_exception_fn_call(torch.randn(2), c) self.assertEqual(c.count, 1) def test_with_errors(self): """ Check that errors related to with-statements are detected and reported correctly. """ @torch.jit.script class NoEnterNoExit: """ This class is missing __enter__ and __exit__ methods. """ def __init__(self) -> None: self.count = 1 @torch.jit.script class BadEnter: """ This class has an __enter__ method with an incorrect signature. """ def __init__(self) -> None: self.count = 1 def __enter__(self, incr: int): # noqa: PLE0302 self.count += incr def __exit__(self, type: Any, value: Any, tb: Any): pass @torch.jit.script class BadExit: """ This class has an __exit__ method with an incorrect signature. """ def __init__(self) -> None: self.count = 1 def __enter__(self): self.count += 1 def __exit__(self, type: Any, value: Any): # noqa: PLE0302 pass @torch.jit.script class ExitIncorrectTypes: """ This class has an __exit__ method with unsupported argument types. """ def __init__(self) -> None: self.count = 1 def __enter__(self): self.count += 1 def __exit__(self, type: Any, value: int, tb: int): pass def test_no_enter_no_exit(x: torch.Tensor, cm: NoEnterNoExit) -> torch.Tensor: with cm as _: pass return x def test_bad_enter(x: torch.Tensor, cm: BadEnter) -> torch.Tensor: with cm as _: pass return x def test_bad_exit(x: torch.Tensor, cm: BadExit) -> torch.Tensor: with cm as _: pass return x def test_exit_incorrect_types( x: torch.Tensor, cm: ExitIncorrectTypes ) -> torch.Tensor: with cm as _: pass return x def test_enter_without_object(): with "not_object" as obj: pass test_tensor = torch.randn(5, dtype=torch.double) with self.assertRaisesRegexWithHighlight( RuntimeError, r"does not define __enter__ and __exit__ methods", "cm" ): self.checkScript(test_no_enter_no_exit, (test_tensor, NoEnterNoExit())) with self.assertRaisesRegexWithHighlight( RuntimeError, r"__enter__ must have only one argument and one return value", "cm", ): self.checkScript(test_bad_enter, (test_tensor, BadEnter())) with self.assertRaisesRegexWithHighlight( RuntimeError, r"__exit__ must have four arguments", "cm" ): self.checkScript(test_bad_exit, (test_tensor, BadExit())) with self.assertRaisesRegexWithHighlight( RuntimeError, r"argument 2 of __exit__ must have Any type", "cm" ): self.checkScript( test_exit_incorrect_types, (test_tensor, ExitIncorrectTypes()) ) with self.assertRaisesRegexWithHighlight( RuntimeError, r"must return an object", '"not_object"' ): self.checkScript(test_enter_without_object, ()) def test_with_no_grad(self): """ Check that torch.no_grad() works. Most of these are adapted from corresponding tests for eager-mode no_grad. """ # Basic no_grad test. def test_no_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: with torch.no_grad(): w = x + y return w s = torch.jit.script(test_no_grad) x = torch.ones(5, 5, requires_grad=True) y = torch.ones(5, 5) * 4 w = s(x, y) self.assertFalse(w.requires_grad) self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5))) self.assertIsNone(w.grad_fn) # Test assignment of a grad-less Tensor to a Tensor with gradients # in a no_grad block. def test_no_grad_assignment(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: with torch.no_grad(): x[0] = y return x s = torch.jit.script(test_no_grad_assignment) z = torch.randn(5) w = s(x, z) self.assertTrue(w.requires_grad) self.assertIsNone(w.grad_fn) # Check that @torch.jit.ignored functions respect no_grad when it is # called in JIT mode. class NoGradModule(torch.nn.Module): @torch.jit.ignore def adder(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: w = x + y return w def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: with torch.no_grad(): w = self.adder(x, y) return w s = torch.jit.script(NoGradModule()) w = s(x, y) self.assertFalse(w.requires_grad) @skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls") def test_with_record_function(self): """ Check that torch.autograd.profiler.record_function context manager is torchscriptable. """ def with_rf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: with torch.autograd.profiler.record_function("foo"): # Nested record_function. with torch.autograd.profiler.record_function("nested"): a = x + y return a scripted = torch.jit.script(with_rf) x, y = torch.ones(2), torch.ones(2) with torch.autograd.profiler.profile() as p: scripted(x, y) # Need to call below to populate CPU children. p.key_averages() function_events = p.function_events # Event with name "foo" should be recorded. rf_events = [evt for evt in function_events if evt.name == "foo"] self.assertEqual(len(rf_events), 1) rf_event = rf_events[0] child_events = rf_event.cpu_children # Ensure we find nested record_function event self.assertTrue("nested" in (child.name for child in child_events)) nested_function_event = [ evt for evt in function_events if evt.name == "nested" ][0] # Nested record function should have child "aten::add" nested_child_events = nested_function_event.cpu_children self.assertTrue("aten::add" in (child.name for child in nested_child_events))