# Owner(s): ["module: custom-operators"]

import collections
import itertools
import os
import re
import subprocess
import sys
import typing
import unittest
from typing import *  # noqa: F403

import numpy as np

import torch._custom_ops as custom_ops
import torch.testing._internal.optests as optests
import torch.utils._pytree as pytree
import torch.utils.cpp_extension
from functorch import make_fx
from torch import Tensor
from torch._custom_op.impl import CustomOp, infer_schema
from torch._library.infer_schema import tuple_to_list
from torch._utils_internal import get_file_path_2
from torch.testing._internal import custom_op_db
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    OpDTypes,
    ops,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_WINDOWS,
    parametrize,
    run_tests,
    skipIfTorchDynamo,
    subtest,
    TestCase,
)
from torch.testing._internal.custom_op_db import numpy_nonzero


# Shadowed by `torch.testing._internal.common_utils.custom_op`
from torch._custom_op.impl import custom_op  # usort: skip


def requires_compile(fun):
    fun = unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")(fun)
    return fun


class CustomOpTestCaseBase(TestCase):
    test_ns = "_test_custom_op"

    def setUp(self):
        super().setUp()
        self.libraries = []

    def tearDown(self):
        super().tearDown()
        import torch._custom_op

        keys = list(torch._custom_op.impl.global_registry.keys())
        for key in keys:
            if not key.startswith(f"{self.test_ns}::"):
                continue
            torch._custom_op.impl.global_registry[key]._destroy()
        if hasattr(torch.ops, self.test_ns):
            delattr(torch.ops, self.test_ns)
        for lib in self.libraries:
            lib._destroy()
        del self.libraries

    def ns(self):
        return getattr(torch.ops, self.test_ns)

    def lib(self):
        result = torch.library.Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        self.libraries.append(result)
        return result

    def get_op(self, qualname):
        return torch._custom_op.impl.get_op(qualname)


@requires_compile
class TestCustomOpTesting(CustomOpTestCaseBase):
    @parametrize("check_gradients", (False, "auto"))
    @parametrize("dynamic", (True, False))
    def test_aot_autograd_check_degenerate_cases(
        self, device, dynamic, check_gradients
    ):
        def simple(x):
            return x.clone()

        # Should not raise
        x = torch.randn(3, device=device)
        optests.aot_autograd_check(
            simple, (x,), {}, dynamic=dynamic, check_gradients=check_gradients
        )

        def outputs_dont_require_grad(x):
            return x.detach()

        # Should not raise
        y = torch.randn(3, device=device, requires_grad=True)
        optests.aot_autograd_check(
            outputs_dont_require_grad,
            (y,),
            {},
            dynamic=dynamic,
            check_gradients=check_gradients,
        )

        def no_outputs(x):
            return x.detach()

        # Should not raise
        x = torch.randn(3, device=device, requires_grad=True)
        y = torch.randn(3, device=device, requires_grad=False)
        optests.aot_autograd_check(
            no_outputs, (x,), {}, dynamic=dynamic, check_gradients=check_gradients
        )
        optests.aot_autograd_check(
            no_outputs, (y,), {}, dynamic=dynamic, check_gradients=check_gradients
        )

    def test_incorrect_schema_mutation(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                guard = torch._C._AutoDispatchBelowAutograd()
                try:
                    return op(x)
                finally:
                    del guard

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            x.sin_()
            return x.clone()

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        x = torch.tensor(3.14159 / 3, requires_grad=True, device=device)
        with self.assertRaisesRegex(
            optests.OpCheckError, "Argument x is not defined as mutable but was mutated"
        ):
            torch.library.opcheck(op, (x,), {})

    def test_incorrect_schema_view(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                with torch._C._AutoDispatchBelowAutograd():
                    with torch._C._ExcludeDispatchKeyGuard(
                        torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
                    ):
                        return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x.view_as(x)

        def foo_meta(x):
            return x.view_as(x)

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_meta, "Meta")

        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "Argument x is not defined to alias output but was aliasing",
        ):
            torch.library.opcheck(op, (x,), {})

    def test_missing_abstract_impl(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return 2 * gx

        def foo_impl(x):
            return torch.tensor(x.cpu().numpy() ** 2, device=x.device)

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")

        x = torch.tensor([0, 1.0], requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "_test_custom_op.foo.default",
        ):
            torch.library.opcheck(op, (x,), {})

    @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug")
    def test_incorrect_abstract_impl(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python
                guard = torch._C._AutoDispatchBelowAutograd()
                guard2 = torch._C._ExcludeDispatchKeyGuard(
                    torch._C.DispatchKeySet(torch._C.DispatchKey.ADInplaceOrView)
                )
                try:
                    return op(x)
                finally:
                    del guard
                    del guard2

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x**2

        def foo_meta(x):
            return x.unsqueeze(1) ** 2

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        x = torch.tensor([0, 1.0], requires_grad=True)
        with self.assertRaisesRegex(optests.OpCheckError, "Shapes .* are not equal"):
            torch.library.opcheck(op, (x,), {})

    def test_missing_functionalization(self, device):
        lib = self.lib()
        lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.mark_dirty(x)
                with torch._C._AutoDispatchBelowAutograd():
                    return op(x)

            @staticmethod
            def backward(ctx, gx):
                return gx

        def foo_impl(x):
            return x.sin_()

        def foo_meta(x):
            return x

        lib.impl("foo", Foo.apply, "Autograd")
        lib.impl("foo", foo_impl, "CPU")
        lib.impl("foo", foo_impl, "CUDA")
        lib.impl("foo", foo_meta, "Meta")

        x = torch.tensor([0, 1.0])
        y = x.clone()
        with self.assertRaisesRegex(
            optests.OpCheckError,
            "We only support functionalizing operators whose outputs do not have alias annotations",
        ):
            torch.library.opcheck(op, (y,), {})

    def test_autograd_registered_at_backend(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, gx):
                return gx * 0.5

        lib.impl("foo", Foo.apply, "CPU")
        lib.impl("foo", Foo.apply, "CUDA")
        lib.impl("foo", lambda x: x.clone(), "Meta")

        x = torch.randn([], requires_grad=True)

        with self.assertRaisesRegex(
            torch.testing._internal.optests.OpCheckError,
            "does not have an autograd kernel",
        ):
            torch.library.opcheck(op, (x,), {})

        # I'm not sure why this is necessary
        del lib

    def test_global_state_mutation(self, device):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        op = self.ns().foo.default

        class Foo(torch.autograd.Function):
            invoked = 0

            @staticmethod
            def forward(ctx, x):
                Foo.invoked += 1
                return x.clone() * Foo.invoked

            @staticmethod
            def backward(ctx, gx):
                return gx

        lib.impl("foo", Foo.apply, "CompositeImplicitAutograd")

        x = torch.tensor(3.14159 / 3, requires_grad=True)
        with self.assertRaisesRegex(
            optests.OpCheckError, "eager-mode PyTorch vs AOTAutograd"
        ):
            torch.library.opcheck(op, (x,), {})

    @ops(custom_op_db.custom_op_db, dtypes=OpDTypes.any_one)
    def test_opcheck_opinfo(self, device, dtype, op):
        for sample_input in op.sample_inputs(
            device, dtype, requires_grad=op.supports_autograd
        ):
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            torch.library.opcheck(op.op, args, kwargs)

    def test_opcheck_fails_basic(self, device):
        @custom_op(f"{self.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor: ...
@foo.impl(["cpu", "cuda"]) def foo_impl(x): return x.sum() x = torch.randn(3, device=device, requires_grad=True) # Triggers the CustomOp autograd NYI error with self.assertRaisesRegex( optests.OpCheckError, "Autograd has not been implemented for operator" ): torch.library.opcheck(self.get_op(f"{self.test_ns}::foo"), (x,), {}) def test_autograd_registration_check_autograd_kernel(self, device): lib = self.lib() lib.define("foo(Tensor x) -> Tensor") op = self.ns().foo.default class Foo(torch.autograd.Function): @staticmethod def forward(ctx, x): with torch._C._AutoDispatchBelowAutograd(): return op(x) @staticmethod def backward(ctx, gx): return gx def foo_impl(x): return x.sin() lib.impl("foo", Foo.apply, "Autograd") lib.impl("foo", foo_impl, "CPU") lib.impl("foo", foo_impl, "CUDA") x = torch.randn(3, requires_grad=True, device=device) # Should not raise optests.autograd_registration_check(op, (x,), {}) def test_autograd_registration_check_compositeimplicitautograd(self, device): lib = self.lib() lib.define("foo(Tensor x) -> Tensor") op = self.ns().foo.default def foo_impl(x): return x.sin().cos() lib.impl("foo", foo_impl, "CompositeImplicitAutograd") x = torch.randn(3, requires_grad=True, device=device) # Should not raise optests.autograd_registration_check(op, (x,), {}) def test_autograd_registration_check_incorrect_composite(self, device): lib = self.lib() lib.define("foo(Tensor x) -> Tensor") op = self.ns().foo.default def foo_impl(x): return x.sin().cos() lib.impl("foo", foo_impl, "CompositeExplicitAutograd") x = torch.randn(3, requires_grad=True, device=device) with self.assertRaisesRegex(AssertionError, "incorrectly registered"): optests.autograd_registration_check(op, (x,), {}) def test_autograd_registration_check_incorrect(self, device): lib = self.lib() lib.define("foo(Tensor x) -> Tensor") op = self.ns().foo.default class Foo(torch.autograd.Function): @staticmethod def forward(ctx, x): return torch.sin(x) @staticmethod def backward(ctx, gx): return gx lib.impl("foo", Foo.apply, "CPU") lib.impl("foo", Foo.apply, "CUDA") x = torch.randn(3, requires_grad=True, device=device) with self.assertRaisesRegex(AssertionError, "incorrectly registered"): optests.autograd_registration_check(op, (x,), {}) def test_assert_raises_regex(self, device): from torch.testing._internal.optests.aot_autograd import assert_raises_regex with assert_raises_regex(RuntimeError, "c"): raise RuntimeError("abcd") with assert_raises_regex(RuntimeError, "c.*"): raise RuntimeError("abcd") with self.assertRaisesRegex(AssertionError, "instead got"): with assert_raises_regex(RuntimeError, "c.*"): raise ValueError("abcd") with self.assertRaisesRegex(AssertionError, "Expected exception"): with assert_raises_regex(RuntimeError, "c.*"): pass with self.assertRaisesRegex(AssertionError, "to match regex"): with assert_raises_regex(RuntimeError, "f"): raise RuntimeError("abcd") class TestCustomOp(CustomOpTestCaseBase): test_ns = "_test_custom_op" @requires_compile def test_functionalize_error(self): with torch.library._scoped_library(self.test_ns, "FRAGMENT") as lib: lib.define("foo(Tensor(a!) 
            def foo(x):
                return x.sin_()

            lib.impl("foo", foo, "CompositeExplicitAutograd")
            foo_op = self.get_op(f"{self.test_ns}::foo")

            lib.define("bar(Tensor(a) x) -> Tensor(a)")

            def bar(x):
                return x.view(-1)

            lib.impl("bar", bar, "CompositeExplicitAutograd")
            bar_op = self.get_op(f"{self.test_ns}::bar")

            msg = r".*We only support functionalizing operators whose outputs do not have alias annotations"

            x = torch.randn(3)

            @torch.compile(backend="aot_eager", fullgraph=True)
            def f(x):
                return foo_op(x)

            @torch.compile(backend="aot_eager", fullgraph=True)
            def g(x):
                return bar_op(x)

            with self.assertRaisesRegex(RuntimeError, msg):
                f(x)
            with self.assertRaisesRegex(RuntimeError, msg):
                g(x)

    def test_invalid_schemas(self):
        # function schema validation goes through torchgen, so this is just a
        # basic test.
        with self.assertRaisesRegex(AssertionError, "Invalid function schema: foo"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(")

    def test_invalid_qualname(self):
        with self.assertRaisesRegex(ValueError, "overload"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo.Tensor", "() -> ()")

    def test_name_must_match(self):
        with self.assertRaisesRegex(ValueError, "to have name"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def baz(x: Tensor) -> Tensor:
                raise NotImplementedError

    def test_unsupported_schemas(self):
        def foo(x):
            raise NotImplementedError

        with self.assertRaisesRegex(ValueError, "only supports functional"):
            custom_ops.custom_op(
                f"{TestCustomOp.test_ns}::foo", "(Tensor(a!) x) -> Tensor(a)"
            )(foo)
        with self.assertRaisesRegex(ValueError, "only supports functional"):
            custom_ops.custom_op(
                f"{TestCustomOp.test_ns}::foo", "(Tensor(a) x) -> Tensor(a)"
            )(foo)
        with self.assertRaisesRegex(ValueError, "only supports functional"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor x) -> ()")(
                foo
            )
        with self.assertRaisesRegex(ValueError, "self"):
            custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", "(Tensor self) -> ()")(
                foo
            )

    # Tests for the older custom_op API
    def test_schema_matches_signature(self):
        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(f"{TestCustomOp.test_ns}::blah", "(Tensor y) -> Tensor")
            def blah(x):
                pass

        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah2", "(Tensor x, *, Tensor y) -> Tensor"
            )
            def blah2(x, y):
                pass

        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah3",
                "(Tensor x, *, Tensor w, Tensor z) -> Tensor",
            )
            def blah3(x, *, y, z):
                pass

        with self.assertRaisesRegex(ValueError, "signature to match"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah4",
                "(Tensor x, *, Tensor z, Tensor y) -> Tensor",
            )
            def blah4(x, *, y, z):
                pass

        with self.assertRaisesRegex(ValueError, "not supported"):

            @custom_op(f"{TestCustomOp.test_ns}::blah5", "(Tensor x) -> Tensor")
            def blah5(*args):
                pass

        with self.assertRaisesRegex(ValueError, "not supported"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah6", "(*, Tensor z, Tensor y) -> Tensor"
            )
            def blah6(**kwargs):
                pass

        with self.assertRaisesRegex(ValueError, "default arguments"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah7", "(Tensor x, *, Tensor y) -> Tensor"
            )
            def blah7(x=1, *, y):
                pass

        with self.assertRaisesRegex(ValueError, "default arguments"):

            @custom_op(
                f"{TestCustomOp.test_ns}::blah8", "(Tensor x, *, Tensor y) -> Tensor"
            )
            def blah8(x, *, y=1):
                pass

        # kwonly-arg works
        @custom_op(
            f"{TestCustomOp.test_ns}::blah9", "(Tensor x, *, Tensor y) -> Tensor"
        )
        def blah9(x, *, y):
            pass

    def test_infer_schema_no_return(self):
        with self.assertRaisesRegex(
            ValueError, "No return type annotation was provided. Please add one."
        ):

            @torch.library.custom_op("mylib::foo", mutates_args={})
            def foo(x: torch.Tensor, y: int):
                return x * y

    def test_infer_schema_supported(self):
        def a(x: Tensor) -> Tensor:
            return torch.empty([])

        self.assertExpectedInline(
            infer_schema(a, mutates_args=()), """(Tensor x) -> Tensor"""
        )

        def kwonly1(x: Tensor, *, y: int, z: float) -> Tensor:
            return torch.empty([])

        self.assertExpectedInline(
            infer_schema(kwonly1, mutates_args=()),
            """(Tensor x, *, SymInt y, float z) -> Tensor""",
        )

        def kwonly2(*, y: Tensor) -> Tensor:
            return torch.empty([])

        self.assertExpectedInline(
            infer_schema(kwonly2, mutates_args=()), """(*, Tensor y) -> Tensor"""
        )

        def b(
            x: Tensor,
            y: int,
            z: bool,
            a: float,
            b: torch.dtype,
            c: torch.device,
            d: torch.types.Number,
        ) -> Tuple[Tensor, int, float, bool]:
            return torch.empty([]), 1, 0.1, True

        self.assertExpectedInline(
            infer_schema(b, mutates_args=()),
            """(Tensor x, SymInt y, bool z, float a, ScalarType b, Device c, Scalar d) -> (Tensor, SymInt, float, bool)""",
        )

        def c(
            x: Tensor,
            y: Sequence[Tensor],
            z: Optional[Tensor],
            w: Sequence[Optional[Tensor]],
        ) -> List[Tensor]:
            return [torch.empty([])]

        self.assertExpectedInline(
            infer_schema(c, mutates_args=()),
            """(Tensor x, Tensor[] y, Tensor? z, Tensor?[] w) -> Tensor[]""",
        )

        def d(x: Tensor) -> Tuple[List[Tensor], Tensor]:
            return [torch.empty([])], torch.empty([])

        self.assertExpectedInline(
            infer_schema(d, mutates_args=()), """(Tensor x) -> (Tensor[], Tensor)"""
        )

        def e() -> Tensor:
            return torch.empty([])

        self.assertExpectedInline(infer_schema(e, mutates_args=()), """() -> Tensor""")

        def f(x: Tensor) -> None:
            pass

        self.assertExpectedInline(
            infer_schema(f, mutates_args=()), """(Tensor x) -> ()"""
        )

        def g(
            x: Tensor, y: List[Tensor], z: List[Tensor], w: List[Optional[Tensor]]
        ) -> None:
            pass

        self.assertExpectedInline(
            infer_schema(g, mutates_args=()),
            """(Tensor x, Tensor[] y, Tensor[] z, Tensor?[] w) -> ()""",
        )

        self.assertExpectedInline(
            infer_schema(g, mutates_args={"x", "w", "z"}),
            """(Tensor(a0!) x, Tensor[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""",
        )

        self.assertExpectedInline(
            infer_schema(g, mutates_args="unknown"),
            """(Tensor(a0!) x, Tensor(a1!)[] y, Tensor(a2!)[] z, Tensor(a3!)?[] w) -> ()""",
        )

        def h(
            x: Tensor,
            a: Optional[int] = None,
            b: float = 3.14,
            c: bool = True,
            d: int = 3,
            e: str = "foo",
            f: torch.dtype = torch.float,
            g: torch.dtype = torch.float32,
            h: torch.dtype = torch.int,
            i: torch.device = torch.device("cpu:0"),
            j: torch.device = "cpu",
        ) -> None:
            pass
        self.assertExpectedInline(
            infer_schema(h, mutates_args=()),
            (
                """(Tensor x, SymInt? a=None, float b=3.14, bool c=True, SymInt d=3, str e="foo", """
                """ScalarType f=float32, ScalarType g=float32, ScalarType h=int32, Device i="cpu:0", Device j="cpu") -> ()"""
            ),
        )

        def foo_impl(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={})
        self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor")

    def test_infer_schema_unsupported(self):
        with self.assertRaisesRegex(ValueError, "varargs"):

            def foo(*args):
                raise NotImplementedError

            infer_schema(foo, mutates_args=())

        with self.assertRaisesRegex(ValueError, "varkwargs"):

            def foo(**kwargs):
                raise NotImplementedError

            infer_schema(foo, mutates_args=())

        with self.assertRaisesRegex(ValueError, "must have a type annotation"):

            def foo(x):
                raise NotImplementedError

            infer_schema(foo, mutates_args=())

        with self.assertRaisesRegex(ValueError, "unsupported"):

            def foo(x: Tensor) -> Tuple[Tensor, ...]:
                raise NotImplementedError

            infer_schema(foo, mutates_args=())

        with self.assertRaisesRegex(ValueError, "can be mutated"):

            def foo(x: Tensor, y: int) -> Tensor:
                raise NotImplementedError

            infer_schema(foo, mutates_args={"y"})

    def _generate_examples(self, typ):
        if typ is int:
            return [17]
        if typ is float:
            return [3.14]
        if typ is bool:
            return [True]
        if typ is str:
            return ["foo"]
        if typ is torch.dtype:
            return [torch.float32]
        if typ is torch.device:
            return [torch.device("cpu")]
        if typ == torch.types.Number:
            return [2.718]
        if typ is torch.Tensor:
            return [torch.tensor(3)]
        if typ == Optional[torch.types.Number]:
            return [None, 2.718]
        origin = typing.get_origin(typ)
        if origin is Union:
            args = typing.get_args(typ)
            assert len(args) == 2 and (args[0] is type(None) or args[1] is type(None))
            elt = args[0] if args[1] is type(None) else args[1]
            return self._generate_examples(elt) + [None]
        if origin is list:
            args = typing.get_args(typ)
            assert len(args) == 1
            elt = args[0]
            return [
                self._generate_examples(elt),
                self._generate_examples(elt),
                self._generate_examples(elt),
            ]
        if origin is collections.abc.Sequence:
            args = typing.get_args(typ)
            assert len(args) == 1
            examples = self._generate_examples(args[0])
            return list(itertools.product(examples, examples)) + []
        raise NotImplementedError(
            f"testrunner cannot generate instance of type {typ}"
        )

    def test_supported_return_types_single_return(self):
        for typ in torch._library.infer_schema.SUPPORTED_RETURN_TYPES:
            for example in self._generate_examples(typ):
                try:

                    @custom_ops.custom_op(f"{self.test_ns}::foo")
                    def foo(x: Tensor) -> typ:
                        raise NotImplementedError

                    @custom_ops.impl(f"{self.test_ns}::foo")
                    def foo_impl(x: Tensor) -> typ:
                        return example

                    op = self.get_op(f"{self.test_ns}::foo")
                    result = op(torch.randn([]))
                    self.assertEqual(result, example, msg=f"{typ} {example}")
                finally:
                    custom_ops._destroy(f"{self.test_ns}::foo")

    def test_supported_return_types_multi_return(self):
        for typ in torch._library.infer_schema.SUPPORTED_RETURN_TYPES:
            for example in self._generate_examples(typ):
                try:

                    @custom_ops.custom_op(f"{self.test_ns}::foo")
                    def foo(x: Tensor) -> Tuple[typ, typ]:
                        raise NotImplementedError

                    @custom_ops.impl(f"{self.test_ns}::foo")
                    def foo_impl(x: Tensor) -> Tuple[typ, typ]:
                        return (example, example)

                    op = self.get_op(f"{self.test_ns}::foo")
                    result = op(torch.randn([]))
                    expected = (example, example)
                    self.assertEqual(result, expected, msg=f"{typ} {example}")
                finally:
                    custom_ops._destroy(f"{self.test_ns}::foo")

    def test_supported_param_types(self):
        for typ in torch._library.infer_schema.SUPPORTED_PARAM_TYPES:
            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: typ) -> Tensor:
                raise NotImplementedError

            yeet = None

            @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=["cpu"])
            def foo_cpu(x, y):
                nonlocal yeet
                yeet = y
                return x.clone()

            try:
                for example in self._generate_examples(typ):
                    op = self.get_op(f"{self.test_ns}::foo")
                    op(torch.randn([]), example)
                    self.assertEqual(yeet, example, msg=f"{typ} {example}")
                    yeet = None
            finally:
                custom_ops._destroy(f"{TestCustomOp.test_ns}::foo")

    def test_sequences(self):
        # Sequence[int] gets automagically turned into int[] in the schema.
        # This test checks that we actually do support arbitrary sequence types.
        class MySequence(collections.abc.Sequence):
            def __init__(self) -> None:
                self._container = [1, 2, 3]

            def __getitem__(self, idx):
                return self._container[idx]

            def __len__(self):
                return len(self._container)

        @custom_ops.custom_op(f"{self.test_ns}::foo")
        def foo(x: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor:
            raise NotImplementedError

        called = 0

        @custom_ops.impl(f"{self.test_ns}::foo", device_types="cpu")
        def foo_cpu(x, sizes):
            nonlocal called
            called += 1
            # Dispatcher will normalize the sequence type into a List
            self.assertEqual(sizes, [1, 2, 3])
            return x.clone()

        x = torch.randn([])
        seq = MySequence()
        op = self.get_op(f"{self.test_ns}::foo")
        op(x, seq)
        self.assertEqual(called, 1)

    def test_unsupported_param_types(self):
        # Not comprehensive (it doesn't need to be), just a check that our mechanism works
        with self.assertRaisesRegex(ValueError, "unsupported type"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: List[Optional[int]]) -> Tensor:
                raise NotImplementedError

            del foo

        with self.assertRaisesRegex(ValueError, "unsupported type"):
            # int[N] in Dispatcher is a bit wild, so we don't try to support it.
            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: Tuple[int, int]) -> Tensor:
                raise NotImplementedError

            del foo

        with self.assertRaisesRegex(ValueError, r"For example, typing.List\[int\]"):
            # test that we propose a correct and supported type.
            @torch.library.custom_op(f"{TestCustomOp.test_ns}::foo", mutates_args={})
            def foo(x: Tensor, y: Tuple[int, int]) -> Tensor:
                raise NotImplementedError

            del foo

        with self.assertRaises(ValueError) as cm:

            @torch.library.custom_op(f"{TestCustomOp.test_ns}::foo", mutates_args={})
            def foo(x: Tensor, y: Tuple[int, float]) -> Tensor:
                raise NotImplementedError

            del foo

        self.assertNotIn("example", str(cm.exception), "")

        with self.assertRaisesRegex(ValueError, "unsupported type"):

            @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
            def foo(x: Tensor, y: Callable) -> Tensor:
                raise NotImplementedError

            del foo

    def test_supported_schemas(self):
        # All of these should already be tested by PyTorch codegen
        # (we share the same mechanism), but here's a sanity check.
        schemas = [
            "(Tensor x) -> Tensor",
            "(Tensor x) -> Tensor y",
            "(Tensor[] x) -> Tensor y",
            "(Tensor x) -> (Tensor, Tensor)",
            "(Tensor x) -> (Tensor y, Tensor z)",
            "(Tensor x) -> (Tensor y, Tensor z)",
        ]
        other_schemas = [
            "(Tensor x, Tensor w) -> (Tensor y, Tensor z)",
            "(Tensor x, Tensor w) -> (Tensor, Tensor)",
            "(Tensor x, Tensor w) -> Tensor",
            "(Tensor? x, Tensor w) -> Tensor",
            "(Tensor? x, Tensor[] w) -> Tensor",
            "(Tensor x, int[] w) -> Tensor",
            "(Tensor x, SymInt[] w) -> Tensor",
            "(Tensor x, Scalar w) -> Tensor",
            "(Tensor x, float w) -> Tensor",
w) -> Tensor", "(Tensor x, bool[] w) -> Tensor", ] for schema in schemas: custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo", schema) custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") for schema in other_schemas: custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar", schema) custom_ops._destroy(f"{TestCustomOp.test_ns}::bar") def test_reserved_ns(self): from torch._custom_op.impl import RESERVED_NS for ns in RESERVED_NS: with self.assertRaisesRegex(ValueError, "is a reserved namespace"): custom_ops.custom_op(f"{ns}::foo", "(Tensor x) -> Tensor") with self.assertRaisesRegex(ValueError, "is a reserved namespace"): @custom_ops.custom_op(f"{ns}::foo2") def foo2(x: torch.Tensor) -> torch.Tensor: raise NotImplementedError def test_private_ctor(self): with self.assertRaisesRegex(RuntimeError, "CustomOp constructor is private"): CustomOp(None, None, None, None, None) def test_lifetime(self): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(x: torch.Tensor) -> torch.Tensor: raise NotImplementedError custom_op = torch._custom_op.impl.get_op(f"{TestCustomOp.test_ns}::foo") # We can't define an op multiple times, with self.assertRaisesRegex(RuntimeError, "multiple times"): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811 raise NotImplementedError # Unless we delete the original op. custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") # Smoke test @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811 raise NotImplementedError custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") def test_autograd_notimplemented(self): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811 raise NotImplementedError x = torch.randn(3, requires_grad=True) op = self.get_op(f"{self.test_ns}::foo") with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"): op(x) custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") del foo @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(x: Sequence[torch.Tensor]) -> torch.Tensor: raise NotImplementedError x = torch.randn(3, requires_grad=True) y = torch.randn(3) op = self.get_op(f"{self.test_ns}::foo") with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"): op([y, x]) custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") del foo @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: raise NotImplementedError x = torch.randn(3, requires_grad=True) y = torch.randn(3) op = self.get_op(f"{self.test_ns}::foo") with self.assertRaisesRegex(RuntimeError, "Autograd has not been implemented"): op(y, x) custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") def test_autograd_notimplemented_gradmode(self): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: raise NotImplementedError @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") def foo_impl(x, y): return x * y x = torch.randn(3, requires_grad=True) y = torch.randn(3) op = self.get_op(f"{self.test_ns}::foo") with torch.no_grad(): # Shouldn't raise, because we are in no_grad op(y, x) def test_impl_cpu(self): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(x: torch.Tensor) -> torch.Tensor: raise NotImplementedError @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu") def foo_cpu(x): return x.sin() x = torch.randn(3) op = self.get_op(f"{self.test_ns}::foo") result = op(x) 
        self.assertEqual(result, foo_cpu(x))

    def test_impl_invalid_devices(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        def foo_impl(x):
            return x.sin()

        from torch._custom_op.impl import SUPPORTED_DEVICE_TYPE_TO_KEY

        for device_type in SUPPORTED_DEVICE_TYPE_TO_KEY.keys():
            # Smoke test: should not raise error
            custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=device_type)(
                foo_impl
            )

        # Not supported by this API: we can either support them in the future
        # or provide some other CustomOp.def_* function. This depends on how
        # common the use cases are.
        for invalid_type in ["hip", "xla", "mkldnn", ["cpu", "hip"]]:
            with self.assertRaisesRegex(ValueError, "we only support device_type"):
                custom_ops.impl(
                    f"{TestCustomOp.test_ns}::foo", device_types=invalid_type
                )(foo_impl)

    def test_backward_partially_registered(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return grad * saved.cos()

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(
            RuntimeError, "unable to find a 'save_for_backward'"
        ):
            y = op(x)
            y.backward()

    def test_save_for_backward_inputs_are_namedtuple(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        hit = 0

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            nonlocal hit
            hit += 1
            self.assertTrue(isinstance(inputs, tuple))
            self.assertEqual(list(inputs._asdict().keys()), ["x"])
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        self.assertEqual(hit, 1)
        y.backward()
        self.assertEqual(hit, 1)

    def test_backward_returns_dict(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return grad * saved.cos()

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        with self.assertRaisesRegex(RuntimeError, "to be a dict"):
            y.backward()

    def test_backward_dict_invalid_keys(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos(), "y": None}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        with self.assertRaisesRegex(RuntimeError, "to have keys {'x'}"):
            y.backward()
    def test_backward_dict_grad_for_nontensor(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, dim: int) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, dim):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos(), "dim": None}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x, 32)
        with self.assertRaisesRegex(RuntimeError, "non-Tensor-like types"):
            y.backward()

    def test_backward_dict_requires_keys_for_input_tensors(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, y):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x, x)
        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
            y.backward()

    def test_backward_dict_requires_keys_for_input_optional_tensors(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor, y: Optional[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x, y):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": grad * saved.cos()}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x, None)
        with self.assertRaisesRegex(RuntimeError, r"to have keys {.*'y'.*}"):
            y.backward()

    def test_backward_grads_are_tensor_or_none(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(x):
            return x.sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.x

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"x": (grad * saved.cos(),)}

        x = torch.randn([], requires_grad=True)
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(x)
        with self.assertRaisesRegex(RuntimeError, "either None or a Tensor"):
            y.backward()

    def test_backward_tensorlist_input_requires_list_grads_with_same_numel(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor:
            raise NotImplementedError

        @custom_ops.impl(f"{TestCustomOp.test_ns}::foo")
        def foo_impl(xs):
            return xs[0].sin()

        @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_save_for_backward(inputs, output):
            return inputs.xs[0]

        @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo")
        def foo_backward(ctx, saved, grad):
            return {"xs": [grad * saved.cos(), None]}

        xs = [torch.randn([], requires_grad=True) for _ in range(3)]
        op = self.get_op(f"{self.test_ns}::foo")
        y = op(xs)
self.assertRaisesRegex(RuntimeError, "3 gradients but got 2"): y.backward() def test_backward_tensorlist_input_requires_list_grads_none_or_Tensor(self): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor: raise NotImplementedError @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") def foo_impl(xs): return xs[0].sin() @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") def foo_save_for_backward(inputs, output): return inputs.xs[0] @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") def foo_backward(ctx, saved, grad): return {"xs": [grad * saved.cos(), None, (None,)]} xs = [torch.randn([], requires_grad=True) for _ in range(3)] op = self.get_op(f"{self.test_ns}::foo") y = op(xs) with self.assertRaisesRegex(RuntimeError, "None or Tensor"): y.backward() def test_backward_tensorlist_input_requires_list_grads(self): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor: raise NotImplementedError @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") def foo_impl(xs): return xs[0].sin() @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") def foo_save_for_backward(inputs, output): return inputs.xs[0] @custom_ops.impl_backward(f"{TestCustomOp.test_ns}::foo") def foo_backward(ctx, saved, grad): return {"xs": None} xs = [torch.randn([], requires_grad=True) for _ in range(3)] op = self.get_op(f"{self.test_ns}::foo") y = op(xs) with self.assertRaisesRegex(RuntimeError, "list of gradients"): y.backward() def test_backward_output_differentiability_type(self): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(xs: Sequence[torch.Tensor]) -> torch.Tensor: raise NotImplementedError with self.assertRaisesRegex(RuntimeError, "output_differentiability"): @custom_ops.impl_backward( f"{TestCustomOp.test_ns}::foo", output_differentiability=True ) def foo_backward(ctx, saved, grad): return {"xs": None} def test_backward_output_differentiability_numel(self): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(xs: Sequence[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError with self.assertRaisesRegex(RuntimeError, "output_differentiability"): @custom_ops.impl_backward( f"{TestCustomOp.test_ns}::foo", output_differentiability=[True] ) def foo_backward(ctx, saved, grad): return {"xs": None} def test_backward_output_differentiability_tensorlist(self): @custom_ops.custom_op(f"{self.test_ns}::foo") def foo(x: Tensor) -> Tuple[List[Tensor], Tensor]: raise NotImplementedError @custom_ops.impl(f"{self.test_ns}::foo") def foo_impl(x): return [x.clone(), x.clone()], x.clone() @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") def foo_save_for_backward(inputs, output): return [] @custom_ops.impl_backward( f"{TestCustomOp.test_ns}::foo", output_differentiability=[False, True] ) def foo_backward(ctx, saved, grad_lst, grad): return {"x": grad} op = self.get_op(f"{self.test_ns}::foo") x = torch.randn(3, requires_grad=True) [a, b], c = op(x) self.assertFalse(a.requires_grad) self.assertFalse(b.requires_grad) self.assertTrue(c.requires_grad) def test_backward_output_differentiability_non_tensor(self): @custom_ops.custom_op(f"{self.test_ns}::foo") def foo(x: Tensor) -> Tuple[Tensor, int]: raise NotImplementedError @custom_ops.impl(f"{self.test_ns}::foo") def foo_impl(x): return x.clone(), 3 @custom_ops.impl_save_for_backward(f"{TestCustomOp.test_ns}::foo") def foo_save_for_backward(inputs, output): return [] @custom_ops.impl_backward( 
f"{TestCustomOp.test_ns}::foo", output_differentiability=[True, True] ) def foo_backward(ctx, saved, grad0, grad1): return {"x": grad0} op = self.get_op(f"{self.test_ns}::foo") x = torch.randn(3, requires_grad=True) with self.assertRaisesRegex(RuntimeError, "is not a Tensor"): op(x) @unittest.skipIf(not TEST_CUDA, "requires CUDA") def test_impl_separate(self): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(x: torch.Tensor) -> torch.Tensor: raise NotImplementedError @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cpu") def foo_cpu(x): return x.sin() @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cuda") def foo_cuda(x): return x.cos() x = torch.randn(3) op = self.get_op(f"{self.test_ns}::foo") result = op(x) self.assertEqual(result, foo_cpu(x)) x_cuda = x.cuda() op = self.get_op(f"{self.test_ns}::foo") result = op(x_cuda) self.assertEqual(result, foo_cuda(x_cuda)) @unittest.skipIf(not TEST_CUDA, "requires CUDA") def test_impl_multiple(self): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(x: torch.Tensor) -> torch.Tensor: raise NotImplementedError @custom_ops.impl(f"{TestCustomOp.test_ns}::foo") def foo_impl(x): return x.cos() op = self.get_op(f"{self.test_ns}::foo") x = torch.randn(3) result = op(x) self.assertEqual(result, foo_impl(x)) x_cuda = x.cuda() result = op(x_cuda) self.assertEqual(result, foo_impl(x_cuda)) def test_impl_abstract_overload(self): lib = self.lib() lib.define("sin.blah(Tensor x) -> Tensor") torch.library.impl_abstract( f"{self.test_ns}::sin.blah", torch.empty_like, lib=lib ) op = self.ns().sin.blah x = torch.randn(3, device="meta") op(x) def test_impl_meta(self): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(x: torch.Tensor, dim: int) -> torch.Tensor: raise NotImplementedError @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) def foo_meta(x, dim): output_shape = list(x.shape) del output_shape[dim] return x.new_empty(output_shape) x = torch.randn(2, 3, device="meta") op = self.get_op(f"{self.test_ns}::foo") result = op(x, 1) self.assertEqual(result.shape, foo_meta(x, 1).shape) def test_duplicate_impl(self): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(x: torch.Tensor, dim: int) -> torch.Tensor: raise NotImplementedError @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) def foo_meta(x, dim): output_shape = list(x.shape) del output_shape[dim] return x.new_empty(output_shape) with self.assertRaisesRegex(RuntimeError, r"test_custom_ops.py:\d+"): @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) def foo_meta2(x, dim): output_shape = list(x.shape) del output_shape[dim] return x.new_empty(output_shape) def test_new_data_dependent_symint(self): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(x: torch.Tensor) -> torch.Tensor: raise NotImplementedError @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib()) def foo_meta(x): ctx = torch.library.get_ctx() r = ctx.new_dynamic_size(min=1) with self.assertRaisesRegex(ValueError, "greater than or equal to 0"): ctx.new_dynamic_size(min=-1) with self.assertRaisesRegex(ValueError, "SymInt"): ctx.new_dynamic_size(max=x.numel()) # NB: You must return dynamic sizes! 
            return x.new_empty(r)

        x = torch.randn(2, 3, device="cpu")
        op = self.get_op(f"{self.test_ns}::foo")
        make_fx(op, tracing_mode="symbolic")(x)

    def test_meta_for_data_dependent_shape_operation(self):
        x = torch.randn(10, device="meta")
        with self.assertRaisesRegex(RuntimeError, "data-dependent output shape"):
            numpy_nonzero(x)

    def test_basic_make_fx(self):
        # More serious tests are in our CustomOp opinfo db,
        # this one is just a sanity check.
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        @torch.library.impl_abstract(f"{TestCustomOp.test_ns}::foo", lib=self.lib())
        def foo_meta(x):
            return x.sum()

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        gm = make_fx(op, tracing_mode="symbolic")(x)
        self.assertTrue(f"{TestCustomOp.test_ns}.foo" in gm.code)

    def test_not_implemented_error(self):
        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo")
        def foo(x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError

        x = torch.randn(3)
        op = self.get_op(f"{self.test_ns}::foo")
        with self.assertRaisesRegex(NotImplementedError, "cpu impl registered"):
            op(x)

        x = torch.randn(3, device="meta")
        with self.assertRaisesRegex(NotImplementedError, "no fake impl or Meta kernel"):
            op(x)

        @custom_ops.custom_op(f"{TestCustomOp.test_ns}::bar")
        def bar(sizes: Sequence[int]) -> torch.Tensor:
            raise NotImplementedError

        op = self.get_op(f"{self.test_ns}::bar")
        with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"):
            op((1, 2, 3))

    def test_data_dependent_basic(self):
        x = torch.randn(5, 5)
        gm = make_fx(numpy_nonzero, tracing_mode="symbolic")(x)
        self.assertTrue("nonzero" in gm.code)

    def test_data_dependent_fake_tracing(self):
        x = torch.randn(5, 5)
        # We've updated to attempt to use unbacked symints even for fake
        # tracing
        make_fx(numpy_nonzero, tracing_mode="fake")(x)

    def test_symints(self):
        def f(x):
            return torch.ops._torch_testing.numpy_view_copy(x, x.shape)

        x = torch.randn(2, 3, 4)
        gm = make_fx(f, tracing_mode="symbolic")(x)
        result = gm(x)
        self.assertEqual(result, f(x))
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
    sym_size_int_2 = torch.ops.aten.sym_size.int(x_1, 2)
    numpy_view_copy = torch.ops._torch_testing.numpy_view_copy.default(x_1, [sym_size_int, sym_size_int_1, sym_size_int_2]);  x_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = None
    return numpy_view_copy""",  # noqa: B950
        )

    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows")
    def test_data_dependent_compile(self):
        import torch._dynamo.testing
        from torch._dynamo.utils import counters

        counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def f(x):
            return numpy_nonzero(x.clone()).clone()

        f(torch.randn(10))

        self.assertEqual(len(counters["graph_break"]), 1)
        self.assertEqual(next(iter(counters["graph_break"].values())), 1)
        self.assertExpectedInline(
            next(iter(counters["graph_break"].keys())).replace(";", "\n"),
            """\
dynamic shape operator: _torch_testing.numpy_nonzero.default
to enable, set torch._dynamo.config.capture_dynamic_output_shape_ops = True""",
        )

    # pre-existing problem: torch.compile(dynamic=True) will, by default,
    # graph break on data-dependent operations. Eventually we'll make it so
    # that it never graph breaks on data-dependent operations.
    @unittest.expectedFailure
    def test_data_dependent_nms_dynamic_compile(self):
        import torch._dynamo.testing
        from torch._dynamo.utils import counters

        counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, dynamic=True)
        def f(x, s, i):
            return torch.ops._torch_testing.numpy_nms(x.clone(), s, i).clone()

        f(torch.randn(20, 4), torch.randn(20), 0.1)

        self.assertEqual(len(counters["graph_break"]), 0)

    def test_impl_on_existing_op(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        @torch._custom_ops.impl(qualname)
        def foo_impl(x):
            return x.sin()

        op = self.get_op(qualname)
        x = torch.randn(3)
        result = op(x)
        self.assertEqual(result, x.sin())

    @parametrize(
        "key", ["CPU", "CUDA", "CompositeImplicitAutograd", "CompositeExplicitAutograd"]
    )
    def test_impl_on_existing_op_with_cpu_registration(self, key):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, key)
        op = self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, "already has an implementation"):
            custom_ops.impl(qualname, func=foo_impl)

    def test_abstract_impl_on_existing_op(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        @torch.library.impl_abstract(qualname, lib=self.lib())
        def foo_impl(x):
            return x.sin()

        op = self.get_op(qualname)
        with torch._subclasses.FakeTensorMode():
            x = torch.randn(3)
            result = op(x)
            self.assertEqual(result.shape, x.shape)
            self.assertEqual(result.stride(), x.stride())

    def test_abstract_impl_on_existing_op_with_meta(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, "Meta")
        op = self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, r"already has .*Meta implementation"):
            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())

    def test_abstract_impl_on_existing_op_with_CompositeImplicitAutograd(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, "CompositeImplicitAutograd")
        op = self.get_op(qualname)

        with self.assertRaisesRegex(RuntimeError, "CompositeImplicitAutograd"):
            torch.library.impl_abstract(qualname, func=foo_impl, lib=self.lib())

    def test_abstract_impl_on_existing_op_with_CompositeExplicitAutograd(self):
        lib = self.lib()
        lib.define("foo(Tensor x) -> Tensor")
        qualname = f"{self.test_ns}::foo"

        def foo_impl(x):
            return x.sin()

        lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
        op = self.get_op(qualname)

        torch.library.impl_abstract(qualname, func=lambda x: x.sum(), lib=self.lib())
        with torch._subclasses.FakeTensorMode():
            x = torch.randn(10)
            result = op(x)
            self.assertEqual(result.shape, ())

    def _test_backward_impl_raises(self, qualname, err_regex):
        with self.assertRaisesRegex(RuntimeError, err_regex):

            @custom_ops.impl_save_for_backward(qualname)
            def foo2(x):
                return

        with self.assertRaisesRegex(RuntimeError, err_regex):

            @custom_ops.impl_backward(qualname)
            def foo3(x):
                return

    def test_backward_impl_on_existing_op_incorrect_schema_views(self):
        lib = self.lib()
        lib.define("foo(Tensor(a) x) -> Tensor(a)")
        qualname = f"{self.test_ns}::foo"
        self._test_backward_impl_raises(qualname, "operator that returns views")

    def test_backward_impl_on_existing_op_incorrect_schema_mutable(self):
        lib = self.lib()
        lib.define("foo(Tensor(a!) x) -> Tensor")
x) -> Tensor") qualname = f"{self.test_ns}::foo" self._test_backward_impl_raises(qualname, "non-functional") def test_backward_impl_on_existing_op_incorrect_schema_no_output(self): lib = self.lib() lib.define("foo(Tensor x) -> ()") qualname = f"{self.test_ns}::foo" self._test_backward_impl_raises(qualname, "no returns") def test_backward_impl_on_existing_op_CompositeImplicitAutograd(self): lib = self.lib() lib.define("foo(Tensor x) -> Tensor") qualname = f"{self.test_ns}::foo" lib.impl("foo", lambda x: x.sin().cos(), "CompositeImplicitAutograd") self._test_backward_impl_raises(qualname, "CompositeImplicitAutograd") @parametrize("key", ["Autograd", "AutogradCPU", "AutogradCUDA"]) def test_backward_impl_on_existing_op_with_key(self, key): lib = self.lib() lib.define("foo(Tensor x) -> Tensor") qualname = f"{self.test_ns}::foo" lib.impl("foo", lambda x: x.sin().cos(), key) self._test_backward_impl_raises(qualname, key) def test_is_functional_schema(self): tests = { "foo(Tensor x) -> Tensor": True, "foo(Tensor(a) x) -> Tensor": True, "foo(Tensor(a!) x) -> Tensor": False, "foo(Tensor(a) x) -> Tensor(a)": False, "foo(Tensor x) -> ()": False, } for schema_str, expected in tests.items(): res = torch._library.utils.is_functional_schema(schema_str) self.assertEqual(res, expected) from torchgen.model import FunctionSchema schema = FunctionSchema.parse(schema_str) res = torch._library.utils.is_functional_schema(schema) self.assertEqual(res, expected) schema = torch._C.parse_schema(schema_str) res = torch._library.utils.is_functional_schema(schema) self.assertEqual(res, expected) def test_incorrect_schema_types(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: with self.assertRaisesRegex(RuntimeError, "unknown type specifier"): lib.define("foo12(Tensor a) -> asdfasdf") with self.assertRaisesRegex(RuntimeError, "unknown type specifier"): lib.define("foo12(asdf a) -> Tensor") with self.assertRaisesRegex(RuntimeError, "Use `SymInt` or `int`"): lib.define("foo12(int64_t a) -> Tensor") with self.assertRaisesRegex(RuntimeError, "Use `float`"): lib.define("foo12(double a) -> Tensor") def test_is_tensorlist_like_type(self): tensorlists = [ # Tensor[] torch.ops.aten.where.default._schema.returns[0].type, # Tensor?[] torch.ops.aten.index.Tensor._schema.arguments[1].type, # Tensor[]? torch._C.parse_schema("foo(Tensor[]? x) -> ()").arguments[0].type, # Tensor?[]? torch._C.parse_schema("foo(Tensor?[]? 
x) -> ()").arguments[0].type, ] non_tensorlists = [ # Tensor torch.ops.aten.sin.default._schema.arguments[0].type, # IntList torch.ops.aten.sum.dim_IntList._schema.arguments[1].type, ] for a in tensorlists: self.assertTrue(torch._library.utils.is_tensorlist_like_type(a)) for a in non_tensorlists: self.assertFalse(torch._library.utils.is_tensorlist_like_type(a)) def test_backward_impl_on_existing_op(self): lib = self.lib() lib.define("foo(Tensor x) -> Tensor") qualname = f"{self.test_ns}::foo" @custom_ops.impl(qualname) def foo_impl(x): with torch.no_grad(): return x.sin() @custom_ops.impl_save_for_backward(qualname) def foo_save_for_backward(inputs, output): return inputs.x @custom_ops.impl_backward(qualname) def foo_backward(ctx, saved, grad_out): return {"x": grad_out * saved.cos()} op = self.get_op(qualname) x = torch.randn([], requires_grad=True) y = op(x) (gx,) = torch.autograd.grad(y, x) self.assertEqual(gx, x.cos()) @parametrize( "tags", [ subtest(torch.Tag.pointwise, "single"), subtest((torch.Tag.pointwise,), "tuple"), subtest([torch.Tag.pointwise], "list"), ], ) def test_define_with_tags(self, tags): lib = self.lib() tags = (torch.Tag.pointwise,) torch.library.define( f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib, tags=tags ) actual = self.ns().foo.default.tags self.assertTrue(isinstance(actual, list)) self.assertEqual(actual, list(tags)) def test_builtin_aten_ops_are_pt2_compliant(self): for op in [torch.ops.aten.sin.default, torch.ops.aten.sum.dim_IntList]: self.assertIn(torch.Tag.pt2_compliant_tag, op.tags) def test_builtin_torchscript_ops(self): for op in [torch.ops.aten.sub.complex, torch.ops.aten.mul.complex]: self.assertIn(torch.Tag.pt2_compliant_tag, op.tags) def test_autogen_aten_ops_are_pt2_compliant(self): for op in [torch.ops.aten.fill.Tensor_out]: self.assertIn(torch.Tag.generated, op.tags) self.assertIn(torch.Tag.pt2_compliant_tag, op.tags) def test_resolve_packet(self): x = torch.randn(3) result = torch._C._jit_resolve_packet("aten::sum", x) self.assertEqual(result, "default") result = torch._C._jit_resolve_packet("aten::sum", x, dim=1) self.assertEqual(result, "dim_IntList") with self.assertRaisesRegex(RuntimeError, "failed to match any schema"): result = torch._C._jit_resolve_packet("aten::sum", x, x, x) def test_define_bad_schema(self): lib = self.lib() with self.assertRaisesRegex(ValueError, "expected schema to look like"): torch.library.define(f"{self.test_ns}::foo", "foo(Tensor x) -> Tensor") def test_define_and_impl(self): lib = self.lib() torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) @torch.library.impl(f"{self.test_ns}::foo", "CPU", lib=lib) def f(x): return torch.from_numpy(np.sin(x.numpy())) x = torch.randn(3) y = self.ns().foo(x) assert torch.allclose(y, x.sin()) def test_define_validation(self): with self.assertRaisesRegex(ValueError, "namespace"): torch.library.define("foo", "(Tensor x) -> Tensor") def test_legacy_define(self): lib = self.lib() @torch.library.define(lib, "foo(Tensor x) -> Tensor") def f(x): return torch.from_numpy(np.sin(x.numpy())) x = torch.randn(3) y = self.ns().foo(x) assert torch.allclose(y, x.sin()) def test_impl_function(self): lib = self.lib() torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) def f(x): return torch.from_numpy(np.sin(x.numpy())) torch.library.impl(f"{self.test_ns}::foo", "CPU", f, lib=lib) x = torch.randn(3) y = self.ns().foo(x) assert torch.allclose(y, x.sin()) def test_legacy_impl(self): lib = self.lib() 
torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) @torch.library.impl(lib, "foo", "CPU") def f(x): return torch.from_numpy(np.sin(x.numpy())) x = torch.randn(3) y = self.ns().foo(x) assert torch.allclose(y, x.sin()) def test_defined_in_python(self): self.assertFalse(torch.ops.aten.sin.default._defined_in_python) self.assertFalse(torch.ops.aten.sum.dim_IntList._defined_in_python) lib = self.lib() torch.library.define("{self._test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) ns = self.ns() self.assertTrue(ns.foo.default._defined_in_python) torch.library.define( "{self._test_ns}::bar.overload", "(Tensor x) -> Tensor", lib=lib ) self.assertTrue(ns.bar.overload._defined_in_python) def _test_impl_device(self, name, types, device): lib = self.lib() torch.library.define(f"{self.test_ns}::{name}", "(Tensor x) -> Tensor", lib=lib) @torch.library.impl(f"{self.test_ns}::{name}", types) def f(x): x_np = x.cpu().numpy() y = torch.from_numpy(np.sin(x_np)) return y.to(device=x.device) x = torch.randn(3, device=device) y = getattr(self.ns(), name)(x) assert torch.allclose(y, x.sin()) def test_impl_device_cpu(self): self._test_impl_device("foo1", "default", "cpu") self._test_impl_device("foo2", ["cpu"], "cpu") self._test_impl_device("foo3", ["cpu", "cuda"], "cpu") @unittest.skipIf(not TEST_CUDA, "requires cuda") def test_impl_device_cuda(self): self._test_impl_device("foo4", "default", "cuda") self._test_impl_device("foo5", ["cuda"], "cuda") self._test_impl_device("foo6", ["cpu", "cuda"], "cuda") def test_impl_device_function(self): lib = self.lib() torch.library.define(f"{self.test_ns}::foo", "(Tensor x) -> Tensor", lib=lib) def f(x): x_np = x.cpu().numpy() y = torch.from_numpy(np.sin(x_np)) return y.to(device=x.device) torch.library.impl(f"{self.test_ns}::foo", "default", f, lib=lib) x = torch.randn(3) y = self.ns().foo(x) assert torch.allclose(y, x.sin()) def test_impl_device_invalid(self): with self.assertRaisesRegex(RuntimeError, "Expected one of cpu, cuda"): torch.library.impl("blah::blah", "somethingsomething") def test_autograd_function_backed_op(self): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { static constexpr bool is_traceable = true; static torch::Tensor forward( torch::autograd::AutogradContext* ctx, const torch::Tensor& x) { return x; } static torch::autograd::variable_list backward( torch::autograd::AutogradContext *ctx, torch::autograd::variable_list grad_output) { return grad_output; } }; torch::Tensor custom_op_backed_by_autograd_fn(const torch::Tensor& x) { return CustomOpAutogradFunction::apply(x); } TORCH_LIBRARY(mylib, m) { m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); } """ module = torch.utils.cpp_extension.load_inline( name="mylib", cpp_sources=cpp_source, functions="custom_op_backed_by_autograd_fn", verbose=True, ) x = torch.ones(2, 2, requires_grad=True) temp = x.clone().detach() out = torch.ops.mylib.custom_op_backed_by_autograd_fn(x) loss = out.sum() loss.backward() self.assertEqual(x.grad, temp) def op_with_incorrect_schema(testcase, name): lib = testcase.lib() lib.define(f"{name}(Tensor x) -> Tensor") qualname = f"{testcase.test_ns}::{name}" lib.impl(name, lambda x: x[:], "CompositeExplicitAutograd") return testcase.get_op(qualname) class MiniOpTest(CustomOpTestCaseBase): test_ns = "mini_op_test" def _init_op_delayed_backward_error(self): name = "delayed_error" qualname = f"{self.test_ns}::{name}" lib = self.lib() lib.define(f"{name}(Tensor x) -> Tensor") lib.impl(name, lambda x: 
x.clone(), "CompositeExplicitAutograd") op = self.get_op(qualname) class Op(torch.autograd.Function): @staticmethod def forward(ctx, x): with torch._C._AutoDispatchBelowAutograd(): return op(x) @staticmethod def backward(ctx, grad): raise NotImplementedError def autograd_impl(x): return Op.apply(x) lib.impl(name, autograd_impl, "Autograd") return op def _init_op_with_no_abstract_impl(self): name = "no_abstract" qualname = f"{self.test_ns}::{name}" lib = self.lib() lib.define(f"{name}(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,)) lib.impl(name, lambda x: x.clone(), "CPU") return torch._library.utils.lookup_op(qualname) def setUp(self): super().setUp() self._op_with_no_abstract_impl = self._init_op_with_no_abstract_impl() self._op_delayed_backward_error = self._init_op_delayed_backward_error() @optests.dontGenerateOpCheckTests("Testing this API") def test_dont_generate(self): op = op_with_incorrect_schema(self, "incorrect_schema") x = torch.randn(3) op(x) def test_mm(self): x = torch.randn(2, 3, requires_grad=True) y = torch.randn(3, 5) result = torch.ops.aten.mm.default(x, y) self.assertEqual(result, x @ y) def test_mm_meta(self): x = torch.randn(2, 3, requires_grad=True, device="meta") y = torch.randn(3, 5, device="meta") result = torch.ops.aten.mm.default(x, y) self.assertEqual(result.shape, (x @ y).shape) def test_mm_fake(self): with torch._subclasses.fake_tensor.FakeTensorMode(): x = torch.randn(2, 3, requires_grad=True, device="cpu") y = torch.randn(3, 5, device="cpu") result = torch.ops.aten.mm.default(x, y) self.assertEqual(result.shape, (x @ y).shape) def test_mm_errors(self): x = torch.randn(2, 3, requires_grad=True) y = torch.randn(4, 5) with self.assertRaisesRegex(RuntimeError, "cannot be multiplied"): result = torch.ops.aten.mm.default(x, y) def test_nonzero(self): x = torch.tensor([0, 1, 2, 0, 0]) y = torch.ops.aten.nonzero.default(x) self.assertEqual(y, torch.tensor([[1], [2]])) def test_inplace(self): x = torch.randn(3) x_clone = x.clone() y = torch.ops.aten.sin_(x) self.assertEqual(x, x_clone.sin()) def test_incorrect_schema(self): op = op_with_incorrect_schema(self, "incorrect_schema") x = torch.randn(3) op(x) def test_no_abstract(self): op = self._op_with_no_abstract_impl x = torch.randn(3) op(x) def test_delayed_error(self): op = self._op_delayed_backward_error x = torch.randn([], requires_grad=True) y = op(x) with self.assertRaises(NotImplementedError): y.sum().backward() def test_delayed_error_no_requires_grad(self): op = self._op_delayed_backward_error x = torch.randn([]) y = op(x) class TestCustomOpAPI(TestCase): @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_basic(self): @torch.library.custom_op("_torch_testing::add", mutates_args=()) def add(x: Tensor, y: float) -> Tensor: x_np = x.numpy(force=True) out_np = x_np + y return torch.from_numpy(out_np).to(x.device) x = torch.randn(3) y = 3.14 z = add(x, y) self.assertEqual(z, x + y) cpu_called = False @add.register_kernel("cpu") def _(x, y): nonlocal cpu_called cpu_called = True x_np = x.numpy() out_np = x_np + y return torch.from_numpy(out_np) z = add(x, y) self.assertEqual(z, x + y) self.assertTrue(cpu_called) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_no_grad_skips_autograd(self): @torch.library.custom_op("_torch_testing::add", mutates_args=()) def add(x: Tensor, y: float) -> Tensor: x_np = x.numpy(force=True) out_np = x_np + y return torch.from_numpy(out_np).to(x.device) called = 0 def setup_context(ctx, inputs, 
output): nonlocal called called += 1 def backward(ctx, grad): raise AssertionError("should not be reached") add.register_autograd(backward, setup_context=setup_context) x = torch.randn(3, requires_grad=True) with torch.no_grad(): y = add(x, 2.0) self.assertEqual(called, 0) self.assertEqual(y, x + 2.0) x.requires_grad_(False) y = add(x, 2.0) self.assertEqual(called, 0) self.assertEqual(y, x + 2.0) x = torch.randn(3, requires_grad=True) y = add(x, 2.0) self.assertEqual(called, 1) self.assertEqual(y, x + 2.0) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_manual_schema(self): @torch.library.custom_op( "_torch_testing::add", mutates_args=(), schema="(Tensor x, float y) -> Tensor", ) def add(x, y): x_np = x.numpy(force=True) out_np = x_np + y return torch.from_numpy(out_np).to(x.device) x = torch.randn(3) y = 3.14 z = add(x, y) self.assertEqual(z, x + y) @torch.library.custom_op( "_torch_testing::sin_", mutates_args=["x"], schema="(Tensor(a!) x) -> ()", ) def sin_(x): x_np = x.numpy() np.sin(x_np, out=x_np) x = torch.randn(3) expected = x.sin() sin_(x) self.assertEqual(x, expected) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_kwarg_only_tensors(self): with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): @torch.library.custom_op("_torch_testing::foo", mutates_args=()) def foo(x: Tensor, *, y: int, z: Tensor) -> Tensor: pass with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): @torch.library.custom_op("_torch_testing::foo", mutates_args=()) def foo2(x: Tensor, *, y: int, z: Optional[Tensor]) -> Tensor: pass with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): @torch.library.custom_op("_torch_testing::foo", mutates_args=()) def foo3(x: Tensor, *, y: int, z: List[Tensor]) -> Tensor: pass with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: lib.define("foo(Tensor x, *, Tensor y) -> Tensor") with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): torch.library.register_autograd( "_torch_testing::foo", lambda grad: grad, setup_context=lambda ctx, inputs, keyword_only_inputs, output: None, ) with self.assertRaisesRegex(NotImplementedError, "kwarg-only Tensor args"): torch.library.register_vmap( "_torch_testing::foo", lambda info, in_dims, x, *, y: (x, 0), ) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_register_autograd_kwargonly_low_level(self): with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: lib.define("foo(Tensor x, *, float y) -> Tensor") called = False def foo_impl(x, *, y): return x * y lib.impl("foo", foo_impl, "CPU") def backward(ctx, grad): nonlocal called called = True return grad * ctx.y def setup_context(ctx, inputs, keyword_only_inputs, output): assert tuple(keyword_only_inputs.keys()) == ("y",) ctx.y = keyword_only_inputs["y"] torch.library.register_autograd( "_torch_testing::foo", backward, setup_context=setup_context, lib=lib ) x = torch.randn(3, requires_grad=True) torch.ops._torch_testing.foo(x, y=3.14).sum().backward() self.assertTrue(called) self.assertEqual(x.grad, torch.tensor([3.14, 3.14, 3.14])) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_register_autograd_defaults(self): with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor") def foo_impl(w, x=2, *, y=3, z): return w * x * y * z lib.impl("foo", 
foo_impl, "CPU") called = False def backward(ctx, grad): nonlocal called called = True return grad * ctx.c def setup_context(ctx, inputs, keyword_only_inputs, output): assert len(inputs) == 2 assert inputs[1] == 2 assert keyword_only_inputs == {"y": 3, "z": 42} ctx.c = keyword_only_inputs["y"] * keyword_only_inputs["z"] * inputs[1] torch.library.register_autograd( "_torch_testing::foo", backward, setup_context=setup_context, lib=lib ) w = torch.randn(3, requires_grad=True) torch.ops._torch_testing.foo(w, z=42).sum().backward() self.assertTrue(called) self.assertEqual(w.grad, torch.full_like(w, 2 * 3 * 42)) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_manual_schema_error(self): with self.assertRaisesRegex(ValueError, "the op mutates {'x'}"): @torch.library.custom_op( "_torch_testing::sin_", mutates_args=(), schema="(Tensor(a!) x) -> ()", ) def sin_(x): x_np = x.numpy() np.sin(x_np, out=x_np) def test_supports_tensorlist(self): @torch._library.autograd.supports_tensorlist class Stack(torch.autograd.Function): @staticmethod def forward(ctx, xs): ctx.num_xs = len(xs) return torch.stack(xs) @staticmethod def backward(ctx, grad): expected = ([True] * ctx.num_xs,) self.assertEqual(ctx.needs_input_grad, expected) return list(grad.unbind(0)) # call two applys, do a backward on the first def t(): return torch.randn([], requires_grad=True) xs0 = [t(), t(), t()] xs1 = [t(), t(), t(), t()] y0 = Stack.apply(xs0) y1 = Stack.apply(xs1) grads = torch.autograd.grad(y0.sum(), xs0) self.assertEqual(grads, [torch.tensor(1.0) for _ in range(3)]) # call one apply, do multiple backwards xs = [t(), t(), t()] y = Stack.apply(xs) _ = torch.autograd.grad(y.sum(), xs, retain_graph=True) _ = torch.autograd.grad(y.sum(), xs, retain_graph=True) grads = torch.autograd.grad(y.sum(), xs, retain_graph=True) self.assertEqual(grads, [torch.tensor(1.0) for _ in range(3)]) # error: on access forward, backward directly with self.assertRaisesRegex(NotImplementedError, "Function.forward directly"): Stack.forward(None, xs) with self.assertRaisesRegex(NotImplementedError, "Function.backward directly"): Stack.backward(None, xs) # the recursive case @torch._library.autograd.supports_tensorlist class Foo(torch.autograd.Function): @staticmethod def forward(ctx, xs): if len(xs) > 1: return Foo.apply(xs[1:]) ctx.len_xs = len(xs) return xs[0].sin() @staticmethod def backward(ctx, grad): result = [None] * ctx.len_xs result[-1] = grad.cos() return result # should work result = Foo.apply(xs) expected = xs[-1].sin() self.assertEqual(result, expected) # recursive on backward @torch._library.autograd.supports_tensorlist class Bar(torch.autograd.Function): @staticmethod def forward(ctx, xs): return [xs[i] + i for i in range(len(xs))] @staticmethod def backward(ctx, grads): f1 = Bar.apply(grads[:2]) f2 = Bar.apply(grads[2:]) return f1 + f2 xs = [torch.tensor(0.0, requires_grad=True) for _ in range(5)] ys = Bar.apply(xs) sum(ys).backward() result = [xi.grad for xi in xs] self.assertEqual(result, torch.tensor([1.0, 2, 1, 2, 3]).unbind(0)) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_default_values(self): defaults = [] @torch.library.custom_op("_torch_testing::f", mutates_args=()) def f( x: Tensor, a: Optional[int] = None, b: float = 3.14, c: bool = True, d: int = 3, e: str = "foo", f: torch.dtype = torch.float, g: torch.dtype = torch.float32, h: torch.dtype = torch.int, i: torch.device = torch.device("cpu:0"), j: torch.device = "cpu", ) -> Tensor: 
defaults.extend([a, b, c, d, e, f, g, h, i, j]) return x.clone() x = torch.randn(3) f(x) self.assertEqual( defaults, [ None, 3.14, True, 3, "foo", torch.float, torch.float32, torch.int, torch.device("cpu:0"), "cpu", ], ) default_values = [ arg.default_value for arg in torch.ops._torch_testing.f.default._schema.arguments ] # enum values taken from c10/core/ScalarType.h type_enum = { "float": 6, "int": 3, } self.assertEqual( default_values, [ None, None, 3.14, True, 3, "foo", type_enum["float"], type_enum["float"], type_enum["int"], torch.device("cpu:0"), torch.device("cpu"), ], ) def test_mutated_error(self): with self.assertRaisesRegex( ValueError, r".*{'y'} in mutates_args were not found" ): @torch.library.custom_op( "_torch_testing::numpy_sin_inplace", mutates_args={"y"}, device_types="cpu", ) def numpy_sin_inplace(x: Tensor) -> None: x_np = x.numpy() np.sin(x_np, out=x_np) def test_mutated(self): @torch.library.custom_op( "_torch_testing::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu" ) def numpy_sin_inplace(x: Tensor) -> None: x_np = x.numpy() np.sin(x_np, out=x_np) x = torch.randn(3) version = x._version expected = x.sin() numpy_sin_inplace(x) self.assertEqual(x, expected) self.assertGreater(x._version, version) @torch.library.custom_op("_torch_testing::f", mutates_args={"y", "z", "w"}) def f( x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]] ) -> None: return x = torch.randn(3) y = torch.randn(3) z = [torch.randn(3), torch.randn(3)] w = [torch.randn(3), None, torch.randn(3)] initial_versions = pytree.tree_map_only( torch.Tensor, lambda x: x._version, (x, y, z, w) ) f(x, y, z, w) new_versions = pytree.tree_map_only( torch.Tensor, lambda x: x._version, (x, y, z, w) ) self.assertEqual(initial_versions[0], new_versions[0]) initial_versions, _ = pytree.tree_flatten(initial_versions[1:]) new_versions, _ = pytree.tree_flatten(new_versions[1:]) for prev, after in zip(initial_versions, new_versions): if prev is None and after is None: continue self.assertGreater(after, prev) def test_mutated_unknown(self): @torch.library.custom_op( "_torch_testing::f", mutates_args="unknown", device_types="cpu" ) def f(x: Tensor) -> None: x_np = x.numpy() np.sin(x_np, out=x_np) x = torch.randn(3) version = x._version expected = x.sin() f(x) self.assertEqual(x, expected) self.assertGreater(x._version, version) @torch.library.custom_op("_torch_testing::f2", mutates_args="unknown") def f2( x: Tensor, y: Optional[Tensor], z: List[Tensor], w: List[Optional[Tensor]] ) -> None: return x = torch.randn(3) y = torch.randn(3) z = [torch.randn(3), torch.randn(3)] w = [torch.randn(3), None, torch.randn(3)] initial_versions = pytree.tree_map_only( torch.Tensor, lambda x: x._version, (x, y, z, w) ) f2(x, y, z, w) new_versions = pytree.tree_map_only( torch.Tensor, lambda x: x._version, (x, y, z, w) ) initial_versions, _ = pytree.tree_flatten(initial_versions) new_versions, _ = pytree.tree_flatten(new_versions) for prev, after in zip(initial_versions, new_versions): if prev is None and after is None: continue self.assertGreater(after, prev) with self.assertRaisesRegex(ValueError, "string"): @torch.library.custom_op("_torch_testing::f3", mutates_args="x") def f3(x: Tensor) -> None: return @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_library_register_torch_dispatch_rule_subclass(self): from torch.testing._internal.two_tensor import TwoTensor @torch.library.custom_op("mylib::foo", mutates_args={}) def f(x: torch.Tensor) -> torch.Tensor: return 
x.sin() x = torch.randn(3) y = torch.randn(3) z = TwoTensor(x, y) with torch.library._scoped_library("mylib", "FRAGMENT") as m: called = 0 def TwoTensor_foo(cls, func, types, args, kwargs): nonlocal called assert cls is TwoTensor called += 1 return x.sin() m._register_torch_dispatch_rule("foo", TwoTensor, TwoTensor_foo) out = f(z) out2 = z.cos() self.assertEqual(called, 1) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_library_register_torch_dispatch_rule_mode(self): from torch.testing._internal.two_tensor import TwoTensorMode @torch.library.custom_op("mylib::foo", mutates_args={}) def f(x: torch.Tensor) -> torch.Tensor: return x.sin() x = torch.randn(3) with torch.library._scoped_library("mylib", "FRAGMENT") as m: called = 0 def TwoTensor_foo(mode, func, types, args, kwargs): nonlocal called called += 1 return x.sin() m._register_torch_dispatch_rule("foo", TwoTensorMode, TwoTensor_foo) with TwoTensorMode(): out = f(x) out2 = x.cos() self.assertEqual(called, 1) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") @parametrize("idx", [0, 1, 2, 3, 4, 5]) def test_library_register_fake_source(self, idx): opname = f"source{idx}" op = getattr(torch.ops._torch_testing, opname).default entry = torch._library.simple_registry.singleton.find(op._name) source = entry.fake_impl.kernel.source assert source is not None self.assertTrue("custom_op_db.py" in source) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_library_register_fake(self): for mode in ["function", "qualname", "opoverload"]: @torch.library.custom_op("_torch_testing::add", mutates_args=()) def add(x: Tensor, y: float) -> Tensor: x_np = x.cpu().numpy() out_np = x_np + y return torch.from_numpy(out_np).to(x.device) called = False if mode == "function": dec = torch.library.register_fake(add) self.assertIsNotNone(dec) elif mode == "qualname": dec = torch.library.register_fake("_torch_testing::add") self.assertIsNotNone(dec) elif mode == "opoverload": dec = torch.library.register_fake(torch.ops._torch_testing.add.default) self.assertIsNotNone(dec) else: raise AssertionError("should not get here") @dec def _(x, y): nonlocal called called = True return torch.empty_like(x) with torch._subclasses.fake_tensor.FakeTensorMode(): x = torch.randn(3) y = 3.14 z = add(x, y) self.assertEqual(z.shape, x.shape) self.assertTrue(called) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_library_register_torch_dispatch(self): for mode in ["function", "qualname", "opoverload"]: class MyMode(torch.utils._python_dispatch.TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): return func(*args, **kwargs) @torch.library.custom_op("_torch_testing::add", mutates_args=()) def add(x: Tensor, y: float) -> Tensor: x_np = x.cpu().numpy() out_np = x_np + y return torch.from_numpy(out_np).to(x.device) called = False if mode == "function": dec = torch.library.register_torch_dispatch(add, MyMode) self.assertIsNotNone(dec) elif mode == "qualname": dec = torch.library.register_torch_dispatch( "_torch_testing::add", MyMode ) self.assertIsNotNone(dec) elif mode == "opoverload": dec = torch.library.register_torch_dispatch( torch.ops._torch_testing.add.default, MyMode ) self.assertIsNotNone(dec) else: raise AssertionError("should not get here") @dec def _(mode, func, types, args, kwargs): nonlocal called called = True return func(*args, **kwargs) with MyMode(): x = torch.randn(3) y = 3.14 z = add(x, y) 
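                # With MyMode active, this call should be serviced by the handler
                # registered via register_torch_dispatch (which flips `called`),
                # not by the op's regular CPU kernel.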
self.assertEqual(z.shape, x.shape) self.assertTrue(called) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_library_register_torch_dispatch_low_level(self): modes = ["qualname", "opoverload"] calls = ["decorator", "function"] device_types_options = [("cpu", "cuda"), "cpu", None] for mode, call, device_types in itertools.product( modes, calls, device_types_options ): with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: lib.define("add10(Tensor x, float y) -> Tensor") if mode == "qualname": op = "_torch_testing::add10" else: assert mode == "opoverload" op = torch.ops._torch_testing.add10.default called = False class MyMode(torch.utils._python_dispatch.TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): return func(*args, **kwargs) if call == "decorator": @torch.library.register_torch_dispatch(op, MyMode, lib=lib) def _(mode, func, types, args, kwargs): x, y = args nonlocal called called = True return x + y else: assert call == "function" def add_stuff(mode, func, types, args, kwargs): x, y = args nonlocal called called = True return x + y torch.library.register_torch_dispatch( op, MyMode, add_stuff, lib=lib ) x = torch.randn(3) y = 3.14 with MyMode(): z = torch.ops._torch_testing.add10.default(x, y) self.assertEqual(z, x + y) self.assertTrue(called) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_library_register_kernel(self): modes = ["function", "qualname", "opoverload"] calls = ["decorator", "function"] device_types_options = ["cpu", None] for mode, call, device_types in itertools.product( modes, calls, device_types_options ): @torch.library.custom_op( "_torch_testing::add", mutates_args=(), device_types="cuda" ) def add(x: Tensor, y: float) -> Tensor: x_np = x.cpu().numpy() out_np = x_np + y return torch.from_numpy(out_np).to(x.device) if mode == "function": op = add elif mode == "qualname": op = "_torch_testing::add" else: assert mode == "opoverload" op = torch.ops._torch_testing.add.default called = False if call == "decorator": @torch.library.register_kernel(op, device_types) def _(x, y): nonlocal called called = True x_np = x.numpy() out_np = x_np + y return torch.from_numpy(out_np) else: assert call == "function" def add_cpu(x, y): nonlocal called called = True x_np = x.numpy() out_np = x_np + y return torch.from_numpy(out_np) torch.library.register_kernel(op, device_types, add_cpu) x = torch.randn(3) y = 3.14 z = add(x, y) self.assertEqual(z, x + y) self.assertTrue(called) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_library_register_kernel_low_level(self): modes = ["qualname", "opoverload"] calls = ["decorator", "function"] device_types_options = [("cpu", "cuda"), "cpu", None] for mode, call, device_types in itertools.product( modes, calls, device_types_options ): with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: lib.define("add9(Tensor x, float y) -> Tensor") if mode == "qualname": op = "_torch_testing::add9" else: assert mode == "opoverload" op = torch.ops._torch_testing.add9.default called = False if call == "decorator": @torch.library.register_kernel(op, device_types, lib=lib) def _(x, y): nonlocal called called = True x_np = x.numpy() out_np = x_np + y return torch.from_numpy(out_np) else: assert call == "function" def add_cpu(x, y): nonlocal called called = True x_np = x.numpy() out_np = x_np + y return torch.from_numpy(out_np) torch.library.register_kernel(op, device_types, 
add_cpu, lib=lib) x = torch.randn(3) y = 3.14 z = torch.ops._torch_testing.add9.default(x, y) self.assertEqual(z, x + y) self.assertTrue(called) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_library_register_autograd(self): for mode in ["function", "qualname", "opoverload"]: @torch.library.custom_op("mylib::numpy_sin", mutates_args=()) def numpy_sin(x: Tensor) -> Tensor: x_np = x.cpu().numpy() y_np = np.sin(x_np) return torch.from_numpy(y_np).to(device=x.device) def setup_context(ctx, inputs, output): (x,) = inputs ctx.save_for_backward(x) called = False def backward(ctx, grad): nonlocal called called = True (x,) = ctx.saved_tensors return grad * x.cos() if mode == "function": torch.library.register_autograd( numpy_sin, backward, setup_context=setup_context ) elif mode == "qualname": torch.library.register_autograd( "mylib::numpy_sin", backward, setup_context=setup_context ) elif mode == "opoverload": torch.library.register_autograd( torch.ops.mylib.numpy_sin.default, backward, setup_context=setup_context, ) x = torch.randn(3, requires_grad=True) y = numpy_sin(x) (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) self.assertTrue(called) self.assertEqual(grad_x, x.cos()) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_library_register_autograd_low_level(self): for mode in ["qualname", "opoverload"]: with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: lib.define("sin5(Tensor x) -> Tensor") def numpy_sin(x: Tensor) -> Tensor: x_np = x.cpu().detach().numpy() y_np = np.sin(x_np) return torch.from_numpy(y_np).to(device=x.device) def setup_context(ctx, inputs, output): (x,) = inputs ctx.save_for_backward(x) called = False def backward(ctx, grad): nonlocal called called = True (x,) = ctx.saved_tensors return grad * x.cos() lib.impl("sin5", numpy_sin, "CPU") called = False if mode == "qualname": torch.library.register_autograd( "_torch_testing::sin5", backward, setup_context=setup_context, lib=lib, ) elif mode == "opoverload": torch.library.register_autograd( torch.ops._torch_testing.sin5.default, backward, setup_context=setup_context, lib=lib, ) x = torch.randn(3, requires_grad=True) y = torch.ops._torch_testing.sin5(x) (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y)) self.assertTrue(called) self.assertEqual(grad_x, x.cos()) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_fake(self): @torch.library.custom_op("_torch_testing::add", mutates_args=()) def add(x: Tensor, y: float) -> Tensor: x_np = x.cpu().numpy() out_np = x_np + y return torch.from_numpy(out_np).to(x.device) x = torch.randn(3) y = 3.14 z = add(x, y) self.assertEqual(z, x + y) try: with torch._subclasses.fake_tensor.FakeTensorMode(): x = torch.randn(3) add(x, y) raise AssertionError("should not be hit") except RuntimeError as e: abstract_impl_error_msg = str(e) abstract_impl_error_msg = re.sub( r"0x.*>\)>", "0xDEADBEEF>)>", abstract_impl_error_msg ).replace(". ", ".\n") self.assertExpectedInline( abstract_impl_error_msg, """\ There was no fake impl registered for <CustomOpDef(_torch_testing::add, <function add at 0xDEADBEEF>)>. This is necessary for torch.compile/export/fx tracing to work.
Please use `add.register_fake` to add an fake impl.""", ) if not IS_WINDOWS: @torch.compile(backend="eager") def f(x, y): return add(x, y) x = torch.randn(3) with self.assertRaisesRegex(RuntimeError, "no fake impl"): f(x, y) abstract_called = False @add.register_fake def _(x, y): nonlocal abstract_called abstract_called = True return torch.empty_like(x) with torch._subclasses.fake_tensor.FakeTensorMode(): x = torch.randn(3) z = add(x, y) self.assertEqual(z.shape, x.shape) self.assertTrue(abstract_called) @skipIfTorchDynamo("recursive dynamo") @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows") def test_compile(self): called_impl = False called_abstract = False @torch.library.custom_op("_torch_testing::linear", mutates_args=()) def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: nonlocal called_impl called_impl = True x_np = x.numpy() w_np = weight.numpy() b_np = bias.numpy() out_np = np.add(x_np @ w_np.T, b_np) return torch.from_numpy(out_np) @custom_linear.register_fake def _(x, weight, bias): nonlocal called_abstract called_abstract = True assert x.dim() == 2 assert weight.dim() == 2 assert bias.dim() == 1 assert x.shape[1] == weight.shape[1] assert weight.shape[0] == bias.shape[0] assert x.device == weight.device return x.new_empty(x.size(0), weight.size(0)) x = torch.randn(2, 2) weight = torch.randn(2, 2) bias = torch.randn(2) out = torch.compile(custom_linear, backend="eager", fullgraph=True)( x, weight, bias ) self.assertEqual(out, torch.nn.functional.linear(x, weight, bias)) self.assertTrue(called_impl) self.assertTrue(called_abstract) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_register_autograd_error_cases(self): @torch.library.custom_op("_torch_testing::g", mutates_args=()) def g(x: Tensor) -> Tensor: return x.sin() x = torch.randn(3, requires_grad=True) y = g(x) with self.assertRaisesRegex(RuntimeError, "no autograd formula"): y.sum().backward() @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_replacement(self): @torch.library.custom_op("_torch_testing::f", mutates_args=()) def f(x: Tensor) -> Tensor: return x.sin() x = torch.randn(3) y = f(x) self.assertEqual(y, x.sin()) @torch.library.custom_op("_torch_testing::f", mutates_args=()) def f(x: Tensor) -> Tensor: return x.cos() y = f(x) self.assertEqual(y, x.cos()) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") @unittest.skipIf(not TEST_CUDA, "requires CUDA") def test_split_device(self): cpu_call_count = 0 cuda_call_count = 0 @torch.library.custom_op( "_torch_testing::f", mutates_args=(), device_types="cpu" ) def f(x: Tensor) -> Tensor: nonlocal cpu_call_count cpu_call_count += 1 x_np = x.numpy() out_np = np.sin(x_np) return torch.from_numpy(out_np) @f.register_kernel("cuda") def _(x: Tensor) -> Tensor: nonlocal cuda_call_count cuda_call_count += 1 x_np = x.cpu().numpy() out_np = np.sin(x_np) return torch.from_numpy(out_np).to(x.device) x = torch.randn(3) y = f(x) self.assertEqual(y, x.sin()) self.assertEqual(cpu_call_count, 1) self.assertEqual(cuda_call_count, 0) x = x.cuda() y = f(x) self.assertEqual(y, x.sin()) self.assertEqual(cpu_call_count, 1) self.assertEqual(cuda_call_count, 1) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") @unittest.skipIf(not TEST_CUDA, "requires CUDA") def test_multi_types(self): @torch.library.custom_op( "_torch_testing::f", mutates_args=(), device_types=("cpu", "cuda") ) def f(x: Tensor) -> Tensor: x_np = x.cpu().numpy() out_np =
np.sin(x_np) return torch.from_numpy(out_np).to(x.device) x = torch.randn(3) y = f(x) self.assertEqual(y, x.sin()) x = x.cuda() y = f(x) self.assertEqual(y, x.sin()) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_overloading(self): called_f = 0 called_f1 = 0 @torch.library.custom_op("_torch_testing::f", mutates_args=()) def f(x: Tensor) -> Tensor: nonlocal called_f called_f += 1 return x.clone() x = torch.randn(2, 3) torch.ops._torch_testing.f(x) self.assertEqual(called_f, 1) @torch.library.custom_op("_torch_testing::f.overload", mutates_args=()) def f1(x: Tensor, y: Tensor) -> Tensor: nonlocal called_f1 called_f1 += 1 return x.clone() torch.ops._torch_testing.f(x, x) self.assertEqual(called_f1, 1) def test_disallows_output_aliasing(self): @torch.library.custom_op("_torch_testing::f", mutates_args=()) def f(x: Tensor) -> Tensor: return x.view(-1) x = torch.randn(3) with self.assertRaisesRegex(RuntimeError, "may not alias"): f(x) @torch.library.custom_op("_torch_testing::f", mutates_args=()) def f(x: Tensor) -> Tensor: return x x = torch.randn(3) with self.assertRaisesRegex(RuntimeError, "may not alias"): f(x) @torch.library.custom_op( "_torch_testing::f", mutates_args={"x"}, device_types="cpu" ) def numpy_sin_inplace(x: Tensor) -> Tensor: x_np = x.numpy() np.sin(x_np, out=x_np) return x x = torch.randn(3) with self.assertRaisesRegex(RuntimeError, "may not alias"): numpy_sin_inplace(x) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_factory_function(self): @torch.library.custom_op( "_torch_testing::f", mutates_args={}, device_types="cpu" ) def f(device: torch.device) -> Tensor: return torch.ones(3) result = f(device="cpu") self.assertEqual(result.device, torch.device("cpu")) self.assertEqual(result, torch.ones(3)) with self.assertRaisesRegex( RuntimeError, "f does not have a kernel registered for cuda" ): f("cuda") with self.assertRaisesRegex( ValueError, "Functions without tensor inputs are required to have a `device: torch.device` argument", ): @torch.library.custom_op( "_torch_testing::f2", mutates_args={}, device_types="cpu" ) def f2() -> Tensor: return torch.ones(3) @torch.library.custom_op("_torch_testing::f3", mutates_args={}) def f3() -> Tensor: raise NotImplementedError("NYI") with self.assertRaisesRegex( ValueError, "Functions without tensor inputs are required to have a `device: torch.device` argument", ): @f3.register_kernel("cpu") def _(): return torch.zeros(3) @torch.library.custom_op("_torch_testing::f4", mutates_args={}) def f4(device: torch.device) -> Tensor: raise NotImplementedError("NYI") @f4.register_kernel("cpu") def _(device: torch.device): return torch.zeros(3) result = f4(device="cpu") self.assertEqual(result.device, torch.device("cpu")) self.assertEqual(result, torch.zeros(3)) def test_library_schema_infer(self): def foo_impl(x: torch.Tensor) -> torch.Tensor: return x.sin() schema = torch.library.infer_schema(foo_impl, op_name="myop", mutates_args={}) self.assertExpectedInline(schema, "myop(Tensor x) -> Tensor") schema = torch.library.infer_schema(foo_impl, mutates_args={}) self.assertExpectedInline(schema, "(Tensor x) -> Tensor") @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_set_kernel_enabled(self): x = torch.ones(1) @torch.library.custom_op("mylib::f", mutates_args=()) def f(x: Tensor) -> Tensor: return x + 1 self.assertEqual(f(x), x + 1) with self.assertLogs("torch._library.custom_ops") as captured: with
f.set_kernel_enabled("gpu", enabled=False): self.assertEqual(f(x), x + 1) self.assertIn( "no kernel was registered for this device type", captured.output[0] ) @f.register_kernel("cpu") def _(x): return x + 2 self.assertEqual(f(x), x + 2) with self.assertLogs("torch._library.custom_ops") as captured: with f.set_kernel_enabled("cpu", enabled=True): self.assertEqual(f(x), x + 2) self.assertIn("already enabled", captured.output[0]) with f.set_kernel_enabled("cpu", enabled=False): self.assertEqual(f(x), x + 1) with self.assertLogs("torch._library.custom_ops") as captured: with f.set_kernel_enabled("cpu", enabled=False): self.assertEqual(f(x), x + 1) self.assertIn("already disabled", captured.output[0]) self.assertEqual(f(x), x + 1) with f.set_kernel_enabled("cpu", enabled=True): self.assertEqual(f(x), x + 2) with f.set_kernel_enabled("cpu", enabled=False): self.assertEqual(f(x), x + 1) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_register_vmap_kwargonly_low_level(self): with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: lib.define("foo(Tensor x, *, float y) -> Tensor") called = False def foo_impl(x, *, y): return x * y lib.impl("foo", foo_impl, "CPU") def vmap(info, in_dims, x, *, y): nonlocal called called = True return x * y, 0 torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib) x = torch.ones(3) result = torch.vmap(torch.ops._torch_testing.foo)(x, y=3.14) self.assertTrue(called) self.assertEqual(result, torch.tensor([3.14, 3.14, 3.14])) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_register_vmap_defaults(self): with torch.library._scoped_library("_torch_testing", "FRAGMENT") as lib: lib.define("foo(Tensor w, int x = 2, *, int y = 3, int z) -> Tensor") def foo_impl(w, x=2, *, y=3, z): return w * x * y * z lib.impl("foo", foo_impl, "CPU") called = False def vmap(info, in_dims, w, x=2, *, y=3, z): nonlocal called called = True return w * x * y * z, 0 torch.library.register_vmap("_torch_testing::foo", vmap, lib=lib) w = torch.ones(3) result = torch.vmap(torch.ops._torch_testing.foo)(w, z=42) self.assertTrue(called) self.assertEqual(result, w * 2 * 3 * 42) def test_layout_constraint_tags(self): needs_fixed_stride_order = torch._C.Tag.needs_fixed_stride_order flexible_layout = torch._C.Tag.flexible_layout # (tags, the result of the tag inference) tests = [ ({needs_fixed_stride_order}, needs_fixed_stride_order), ({flexible_layout}, flexible_layout), # If no tags are provided, then the following is the default (set(), flexible_layout), # If multiple tags are provided, then we use the most constrained tag. 
({flexible_layout, needs_fixed_stride_order}, needs_fixed_stride_order), ] from torch._inductor.lowering import get_layout_constraint_tag for tags, expected in tests: with torch.library._scoped_library("mylib", "FRAGMENT") as m: m.define("foobar(Tensor x) -> Tensor", tags=tags) result = get_layout_constraint_tag(torch.ops.mylib.foobar.default) self.assertEqual(result, expected) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_library_register_vmap(self): for mode in ["function", "qualname", "opoverload", "c_opdef"]: @torch.library.custom_op("mylib::f", mutates_args=()) def f(x: Tensor, y: Tensor) -> Tensor: return x * y called = False def fvmap(info, in_dims, x, y): nonlocal called called = True x_bdim, y_bdim = in_dims x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) result = x * y result = result.movedim(-1, 0) return result, 0 if mode == "function": torch.library.register_vmap(f, fvmap) elif mode == "qualname": torch.library.register_vmap("mylib::f", fvmap) elif mode == "opoverload": torch.library.register_vmap(torch.ops.mylib.f.default, fvmap) elif mode == "c_opdef": f.register_vmap(fvmap) x = torch.randn(2, 2) y = torch.randn(2, 2) result = torch.vmap(f)(x, y) self.assertTrue(called) self.assertEqual(result, x * y) called = False result = torch.vmap(f, out_dims=1)(x, y) self.assertEqual(result, (x * y).T) self.assertTrue(called) called = False result = torch.vmap(f, in_dims=1)(x, y) self.assertEqual(result, (x * y).T) self.assertTrue(called) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_library_register_vmap_library_decorator(self): @torch.library.custom_op("mylib::f", mutates_args=()) def f(x: Tensor, y: Tensor) -> Tensor: return x * y called = False @torch.library.register_vmap("mylib::f") def fvmap(info, in_dims, x, y): nonlocal called called = True x_bdim, y_bdim = in_dims x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) result = x * y result = result.movedim(-1, 0) return result, 0 x = torch.randn(2, 2) y = torch.randn(2, 2) result = torch.vmap(f)(x, y) self.assertTrue(called) self.assertEqual(result, x * y) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_library_register_vmap_op_decorator(self): @torch.library.custom_op("mylib::f", mutates_args=()) def f(x: Tensor, y: Tensor) -> Tensor: return x * y called = False @f.register_vmap def fvmap(info, in_dims, x, y): nonlocal called called = True x_bdim, y_bdim = in_dims x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) result = x * y result = result.movedim(-1, 0) return result, 0 x = torch.randn(2, 2) y = torch.randn(2, 2) result = torch.vmap(f)(x, y) self.assertTrue(called) self.assertEqual(result, x * y) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_library_register_vmap_register_multiple_times(self): @torch.library.custom_op("mylib::f", mutates_args=()) def f(x: Tensor, y: Tensor) -> Tensor: return x * y called = False @f.register_vmap def fvmap(info, in_dims, x, y): nonlocal called called = True x_bdim, y_bdim = in_dims x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) result = x * y result = result.movedim(-1, 0) return 
result, 0 x = torch.randn(2, 2) y = torch.randn(2, 2) result = torch.vmap(f)(x, y) self.assertTrue(called) self.assertEqual(result, x * y) called = False @f.register_vmap def fvmap2(info, in_dims, x, y): nonlocal called called = True x_bdim, y_bdim = in_dims x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) result = x + y result = result.movedim(-1, 0) return result, 0 result = torch.vmap(f)(x, y) self.assertTrue(called) self.assertEqual(result, x + y) @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") def test_library_register_vmap_register_multiple_times_2(self): @torch.library.custom_op("mylib::f", mutates_args=()) def f(x: Tensor, y: Tensor) -> Tensor: return x * y called = False @torch.library.register_vmap("mylib::f") def fvmap(info, in_dims, x, y): nonlocal called called = True x_bdim, y_bdim = in_dims x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) result = x * y result = result.movedim(-1, 0) return result, 0 x = torch.randn(2, 2) y = torch.randn(2, 2) result = torch.vmap(f)(x, y) self.assertTrue(called) self.assertEqual(result, x * y) called = False @torch.library.register_vmap("mylib::f") def fvmap2(info, in_dims, x, y): nonlocal called called = True x_bdim, y_bdim = in_dims x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) result = x + y result = result.movedim(-1, 0) return result, 0 result = torch.vmap(f)(x, y) self.assertTrue(called) self.assertEqual(result, x + y) class MiniOpTestOther(CustomOpTestCaseBase): test_ns = "mini_op_test" def test_nonzero_again(self): x = torch.tensor([0, 1, 2, 0, 0]) y = torch.ops.aten.nonzero.default(x) self.assertEqual(y, torch.tensor([[1], [2]])) optests.generate_opcheck_tests( MiniOpTest, ["aten", "mini_op_test"], get_file_path_2(os.path.dirname(__file__), "minioptest_failures_dict.json"), additional_decorators={ "test_pt2_compliant_tag_mini_op_test_no_abstract": [unittest.expectedFailure] }, test_utils=optests.generate_tests.DEPRECATED_DEFAULT_TEST_UTILS, ) optests.generate_opcheck_tests( MiniOpTestOther, ["aten", "mini_op_test"], get_file_path_2(os.path.dirname(__file__), "minioptest_failures_dict.json"), test_utils=optests.generate_tests.DEPRECATED_DEFAULT_TEST_UTILS, ) class TestGenerateOpcheckTests(CustomOpTestCaseBase): def test_MiniOpTest(self): for orig_test in ["test_mm", "test_nonzero"]: for ( test ) in torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS: expected_test = f"{test}__{orig_test}" self.assertTrue(hasattr(MiniOpTest, expected_test), msg=expected_test) def test_generate_repro_save_data(self): from torch.testing._internal.optests.generate_tests import generate_repro args = (torch.ones(2, 2),) kwargs = {"mat2": torch.zeros(2, 2)} actual = generate_repro( "test_schema", torch.ops.aten.sin.default, args, kwargs, save_data=True, dry_run=True, ) actual = re.sub(r"torch.load\(\".*\.pt\"\)", 'torch.load("repro.pt")', actual) self.assertExpectedInline( actual, """\ # ========================================================= # BEGIN REPRO SCRIPT # ========================================================= import torch from torch.testing._internal.optests import opcheck # Make sure you have loaded the library that contains the op # via an import or torch.ops.load_library(...) 
op = torch.ops.aten.sin.default args, kwargs = torch.load("repro.pt") opcheck(op, args, kwargs, test_utils="test_schema") # ========================================================= # END REPRO SCRIPT # ========================================================= """, ) def test_generate_repro_no_save_data(self): from torch.testing._internal.optests.generate_tests import generate_repro args = (torch.ones(2, 2),) kwargs = {"mat2": torch.zeros(2, 2)} actual = generate_repro( "test_schema", torch.ops.aten.sin.default, args, kwargs, save_data=False, dry_run=True, ) self.assertExpectedInline( actual, """\ # ========================================================= # BEGIN REPRO SCRIPT # ========================================================= import torch from torch.testing._internal.optests import opcheck # Make sure you have loaded the library that contains the op # via an import or torch.ops.load_library(...) op = torch.ops.aten.sin.default # If you rerun your test with PYTORCH_OPCHECK_PRINT_BETTER_REPRO=1 # we will fill them in same (args, kwargs) as in your test args = () # args to the operator kwargs = {} # kwargs to the operator opcheck(op, args, kwargs, test_utils="test_schema") # ========================================================= # END REPRO SCRIPT # ========================================================= """, ) def test_failures_dict_validation(self): from torch.testing._internal.optests.generate_tests import ( FailuresDict, validate_failures_dict_structure, ) failures = { "mini_op_test::incorrect_schema": { "MiniOpTest.test_aot_dispatch_dynamic__test_delayed_error": { "comment": "", "status": "success", } } } with self.assertRaisesRegex(RuntimeError, "got status=success"): validate_failures_dict_structure( FailuresDict("", failures), torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS, MiniOpTest, ) failures = { "mini_op_test::incorrect_schema": { "MiniOpTest.test_aot_dispatch__test_delayed_error": { "comment": "", "status": "xfail", }, } } with self.assertRaisesRegex(RuntimeError, "should begin with one of"): validate_failures_dict_structure( FailuresDict("", failures), torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS, MiniOpTest, ) failures = { "mini_op_test::incorrect_schema": { "MiniOpTest.test_aot_dispatch_dynamic__test_delayed_error_nopenopenope": { "comment": "", "status": "xfail", }, } } with self.assertRaisesRegex(RuntimeError, "does not exist on the TestCase"): validate_failures_dict_structure( FailuresDict("", failures), torch.testing._internal.optests.generate_tests.DEFAULT_TEST_UTILS, MiniOpTest, ) def test_dont_generate_decorator(self): self.assertTrue(hasattr(MiniOpTest, "test_dont_generate")) self.assertFalse(hasattr(MiniOpTest, "test_schema__test_dont_generate")) def test_opcheck(self): x = torch.randn(3, requires_grad=True) with self.assertRaisesRegex(ValueError, "OpOverload"): torch.library.opcheck(torch.sin, (x,)) with self.assertRaisesRegex(ValueError, "test_utils to be subset of"): torch.library.opcheck(torch.ops.aten.sin.default, (x,), test_utils="blah") result = torch.library.opcheck(torch.ops.aten.sin.default, (x,)) self.assertEqual( result, { "test_schema": "SUCCESS", "test_autograd_registration": "SUCCESS", "test_faketensor": "SUCCESS", "test_aot_dispatch_dynamic": "SUCCESS", }, ) result = torch.library.opcheck( torch.ops.aten.sin.default, (x,), test_utils="test_schema" ) self.assertEqual(result, {"test_schema": "SUCCESS"}) result = torch.library.opcheck( torch.ops.aten.sin.default, (x,), test_utils=["test_schema", 
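            # The accepted test_utils names are the four keys seen in the full
            # opcheck result above: "test_schema", "test_autograd_registration",
            # "test_faketensor", "test_aot_dispatch_dynamic".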
"test_faketensor"], ) self.assertEqual( result, { "test_schema": "SUCCESS", "test_faketensor": "SUCCESS", }, ) def test_opcheck_customopdef(self): sample_inputs = [ (torch.randn(3),), (torch.randn(3, requires_grad=True),), ] if torch.cuda.is_available(): sample_inputs.extend( [ (torch.randn(3, device="cuda"),), (torch.randn(3, device="cuda", requires_grad=True),), ] ) for args in sample_inputs: torch.library.opcheck(custom_op_db.numpy_cube, args) def test_is_inside_opcheck_mode(self): self.assertFalse(optests.is_inside_opcheck_mode()) with optests.generate_tests.OpCheckMode( ["foo"], "bar", lambda x: x, None, "baz", "brr" ): self.assertTrue(optests.is_inside_opcheck_mode()) def test_opcheck_bad_op(self): op = op_with_incorrect_schema(self, "foo") x = torch.randn(3) with self.assertRaisesRegex(Exception, "is not defined to alias output"): torch.library.opcheck(op, (x,)) result = torch.library.opcheck(op, (x,), raise_exception=False) self.assertTrue(isinstance(result["test_schema"], RuntimeError)) del result["test_schema"] self.assertEqual( result, { "test_autograd_registration": "SUCCESS", "test_faketensor": "SUCCESS", "test_aot_dispatch_dynamic": "SUCCESS", }, ) def test_opcheck_does_not_require_extra_deps(self): # torch.testing._internal.common_utils comes with a lot of additional # test-time dependencies. Since opcheck is public API, it should be # usable only with pytorch install-time dependencies. cmd = [ sys.executable, "-c", "import torch; import sys; \ x = torch.randn(3, requires_grad=True); \ torch.library.opcheck(torch.ops.aten.sin.default, (x,)); \ assert 'expecttest' not in sys.modules; \ assert 'torch.testing._internal.common_utils' not in sys.modules", ] subprocess.check_output(cmd, shell=False) class TestTypeConversion(TestCase): """In infer_schema(), we try to suggest a correct type when the type annotation is wrong.""" def setUp(self): self.supported_base_types = [ int, float, bool, str, torch.device, torch.Tensor, torch.dtype, torch.types.Number, ] def test_simple_tuple(self): self.assertEqual(List, tuple_to_list(Tuple)) def test_supported_types(self): for t in self.supported_base_types: result_type = tuple_to_list(Tuple[t, t, t]) self.assertEqual(result_type, List[t]) result_type = tuple_to_list(Tuple[t]) self.assertEqual(result_type, List[t]) def test_optional(self): for t in self.supported_base_types: result_type = tuple_to_list(Tuple[t, Optional[t]]) self.assertEqual(result_type, List[Optional[t]]) result_type = tuple_to_list(Tuple[t, t, Optional[t]]) self.assertEqual(result_type, List[Optional[t]]) result_type = tuple_to_list(Tuple[t, ...]) self.assertEqual(result_type, List[t]) def test_mixed_types(self): result_type = tuple_to_list(Tuple[int, float]) self.assertEqual(result_type, List[typing.Union[int, float]]) result_type = tuple_to_list(Tuple[int, float, str]) self.assertEqual(result_type, List[typing.Union[int, float, str]]) only_for = ("cpu", "cuda") instantiate_device_type_tests(TestCustomOpTesting, globals(), only_for=only_for) instantiate_parametrized_tests(TestCustomOp) instantiate_parametrized_tests(TestCustomOpAPI) if __name__ == "__main__": run_tests()