# Owner(s): ["oncall: jit"]
# flake8: noqa

import sys
import unittest
from enum import Enum
from typing import List, Optional

import torch
from jit.myfunction_a import my_function_a
from torch.testing._internal.jit_utils import JitTestCase


class TestDecorator(JitTestCase):
    """TorchScript tests for functions that are wrapped by a Python decorator."""

    def test_decorator(self):
        """Check that scripting a decorated function preserves its eager result.

        The scripted version of ``my_function_a`` is called with ``1.0``, and its
        output is compared against the plain (eager) Python call with the same
        argument.
        """
        # Why not JitTestCase.checkScript(my_function_a, (1.0,))?  That helper
        # re-parses the function's source text, and the leading decorator line
        # makes it fail with:
        #     RuntimeError: expected def but found '@' here:
        #     @my_decorator
        #     ~ <--- HERE
        #     def my_function_a(x: float) -> float:
        # So torch.jit.script() is invoked directly and compared by hand.
        scripted = torch.jit.script(my_function_a)

        expected = my_function_a(1.0)  # eager reference value
        actual = scripted(1.0)  # TorchScript result for the same input
        self.assertEqual(expected, actual)