xref: /aosp_15_r20/external/pytorch/test/jit/test_decorator.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2# flake8: noqa
3
4import sys
5import unittest
6from enum import Enum
7from typing import List, Optional
8
9import torch
10from jit.myfunction_a import my_function_a
11from torch.testing._internal.jit_utils import JitTestCase
12
13
class TestDecorator(JitTestCase):
    """TorchScript tests for functions that carry a Python decorator."""

    def test_decorator(self):
        # JitTestCase.checkScript() cannot be used for decorated functions;
        # calling self.checkScript(my_function_a, (1.0,)) fails with:
        #   RuntimeError: expected def but found '@' here:
        #   @my_decorator
        #   ~ <--- HERE
        #   def my_function_a(x: float) -> float:
        # Instead, compile the function directly with torch.jit.script and
        # check that eager and scripted execution agree on the same input.
        scripted_fn = torch.jit.script(my_function_a)
        eager_result = my_function_a(1.0)
        scripted_result = scripted_fn(1.0)
        self.assertEqual(eager_result, scripted_result)
27