1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Workerimport functools 3*da0073e9SAndroid Build Coastguard Workerimport itertools 4*da0073e9SAndroid Build Coastguard Workerimport unittest 5*da0073e9SAndroid Build Coastguard Workerfrom functools import partial 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport torch 8*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case 9*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing 10*da0073e9SAndroid Build Coastguard Workerimport torch._functorch.config 11*da0073e9SAndroid Build Coastguard Workerimport torch.utils._pytree as pytree 12*da0073e9SAndroid Build Coastguard Workerimport torch.utils.checkpoint 13*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import normalize_gm 14*da0073e9SAndroid Build Coastguard Workerfrom torch._higher_order_ops.wrap import wrap 15*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.symbolic_shapes import ( 16*da0073e9SAndroid Build Coastguard Worker DimDynamic, 17*da0073e9SAndroid Build Coastguard Worker ShapeEnv, 18*da0073e9SAndroid Build Coastguard Worker StatelessSymbolicContext, 19*da0073e9SAndroid Build Coastguard Worker) 20*da0073e9SAndroid Build Coastguard Workerfrom torch.nested._internal.nested_tensor import ( 21*da0073e9SAndroid Build Coastguard Worker jagged_from_list, 22*da0073e9SAndroid Build Coastguard Worker jagged_from_tensor_and_lengths, 23*da0073e9SAndroid Build Coastguard Worker nested_view_from_values_offsets, 24*da0073e9SAndroid Build Coastguard Worker) 25*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 26*da0073e9SAndroid Build Coastguard Worker instantiate_parametrized_tests, 27*da0073e9SAndroid Build Coastguard Worker NestedTensorTestCase, 28*da0073e9SAndroid Build Coastguard Worker parametrize, 29*da0073e9SAndroid Build Coastguard Worker subtest, 30*da0073e9SAndroid Build Coastguard Worker) 31*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.inductor_utils import HAS_CUDA 32*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.two_tensor import TwoTensor 33*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._python_dispatch import return_and_correct_aliasing 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Workerdef traceable_subclass(c): 37*da0073e9SAndroid Build Coastguard Worker return torch._dynamo.config.patch("traceable_tensor_subclasses", {c}) 38*da0073e9SAndroid Build Coastguard Worker 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Workerdef _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles): 41*da0073e9SAndroid Build Coastguard Worker actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2) 42*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_recompiles, expected_recompiles) 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Workerdef get_jagged_tensor(nested_size, offsets, requires_grad=True): 46*da0073e9SAndroid Build Coastguard Worker # Makes a jagged tensor with N constituent tensors with size 47*da0073e9SAndroid Build Coastguard Worker # as specified ((S0, S1, S2), D) 48*da0073e9SAndroid Build Coastguard Worker D = nested_size[1] 49*da0073e9SAndroid Build Coastguard Worker out = [] 50*da0073e9SAndroid Build Coastguard Worker for s in nested_size[0]: 51*da0073e9SAndroid Build Coastguard Worker out.append(torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64)) 52*da0073e9SAndroid Build Coastguard Worker return jagged_from_list(out, offsets) 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Workerdef get_view_test_cases(): 56*da0073e9SAndroid Build Coastguard Worker # Test all cases with both an NT base and a dense base 57*da0073e9SAndroid Build Coastguard Worker # Subclass -> Subclass 58*da0073e9SAndroid Build Coastguard Worker # Dense -> Subclass 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Worker # NB: Don't close over loop variables, they will not get copied into the 61*da0073e9SAndroid Build Coastguard Worker # closure 62*da0073e9SAndroid Build Coastguard Worker # 63*da0073e9SAndroid Build Coastguard Worker # NB: These return functions so we don't generate tensors during test 64*da0073e9SAndroid Build Coastguard Worker # collection time 65*da0073e9SAndroid Build Coastguard Worker 66*da0073e9SAndroid Build Coastguard Worker def mk_basic(base_is_nt): 67*da0073e9SAndroid Build Coastguard Worker # There are three cases to consider here based on the logic in 68*da0073e9SAndroid Build Coastguard Worker # meta_utils.py 69*da0073e9SAndroid Build Coastguard Worker # 70*da0073e9SAndroid Build Coastguard Worker # (1) basic case: 71*da0073e9SAndroid Build Coastguard Worker # view is not a leaf and has the same requires grad as its basic case 72*da0073e9SAndroid Build Coastguard Worker x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True) 73*da0073e9SAndroid Build Coastguard Worker x = x.clone() if base_is_nt else x 74*da0073e9SAndroid Build Coastguard Worker assert not x.is_leaf 75*da0073e9SAndroid Build Coastguard Worker return x.unsqueeze(-1) 76*da0073e9SAndroid Build Coastguard Worker 77*da0073e9SAndroid Build Coastguard Worker def mk_leaf(base_is_nt, requires_grad_1, requires_grad_2): 78*da0073e9SAndroid Build Coastguard Worker x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=requires_grad_1) 79*da0073e9SAndroid Build Coastguard Worker x = x.clone() if base_is_nt else x 80*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 81*da0073e9SAndroid Build Coastguard Worker x_view = x.unsqueeze(-1) 82*da0073e9SAndroid Build Coastguard Worker # The issue is this doesn't quite work 83*da0073e9SAndroid Build Coastguard Worker x_view.requires_grad_(requires_grad_2) 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker return x_view 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker def mk_obscure(base_is_nt): 88*da0073e9SAndroid Build Coastguard Worker x, _ = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False) 89*da0073e9SAndroid Build Coastguard Worker x = x.clone() if base_is_nt else x 90*da0073e9SAndroid Build Coastguard Worker # intermediate leaf view 91*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 92*da0073e9SAndroid Build Coastguard Worker x_view = x.unsqueeze(-1) 93*da0073e9SAndroid Build Coastguard Worker x_view.requires_grad_(True) 94*da0073e9SAndroid Build Coastguard Worker x_view_view = x_view.unsqueeze(-1) 95*da0073e9SAndroid Build Coastguard Worker return x_view_view 96*da0073e9SAndroid Build Coastguard Worker 97*da0073e9SAndroid Build Coastguard Worker for base_is_nt in [False, True]: 98*da0073e9SAndroid Build Coastguard Worker prefix = f"base_is_nt_{base_is_nt}" 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker yield partial(mk_basic, base_is_nt), f"{prefix}_basic" 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker # (2) leaf view case: 103*da0073e9SAndroid Build Coastguard Worker # the view has to be a leaf (w/ requires_grad True or requires_grad False) 104*da0073e9SAndroid Build Coastguard Worker # base w/ requires_grad True or requires_grad False 105*da0073e9SAndroid Build Coastguard Worker for requires_grad_1, requires_grad_2 in itertools.product( 106*da0073e9SAndroid Build Coastguard Worker [True, False], repeat=2 107*da0073e9SAndroid Build Coastguard Worker ): 108*da0073e9SAndroid Build Coastguard Worker yield partial( 109*da0073e9SAndroid Build Coastguard Worker mk_leaf, base_is_nt, requires_grad_1, requires_grad_2 110*da0073e9SAndroid Build Coastguard Worker ), f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}" 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Worker # (3) obscure case: 113*da0073e9SAndroid Build Coastguard Worker # view is not a leaf (implies requires_grad True) 114*da0073e9SAndroid Build Coastguard Worker # base w/ requires_grad False) 115*da0073e9SAndroid Build Coastguard Worker yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure" 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker # Subclass -> Dense 118*da0073e9SAndroid Build Coastguard Worker yield lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[ 119*da0073e9SAndroid Build Coastguard Worker 0 120*da0073e9SAndroid Build Coastguard Worker ].clone(), "subclass_dense" 121*da0073e9SAndroid Build Coastguard Worker 122*da0073e9SAndroid Build Coastguard Worker # Dense -> Subclass -> Dense -> Subclass 123*da0073e9SAndroid Build Coastguard Worker def mk_dense_subclass_dense_subclass(): 124*da0073e9SAndroid Build Coastguard Worker values = torch.randn(10, 5) 125*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 3, 6, 10]) 126*da0073e9SAndroid Build Coastguard Worker offsets2 = offsets.clone().detach() 127*da0073e9SAndroid Build Coastguard Worker return nested_view_from_values_offsets( 128*da0073e9SAndroid Build Coastguard Worker nested_view_from_values_offsets(values, offsets).values(), offsets 129*da0073e9SAndroid Build Coastguard Worker ) 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker yield mk_dense_subclass_dense_subclass, "dense_subclass_dense_subclass" 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker def mk_subclass_dense_subclass_dense(): 134*da0073e9SAndroid Build Coastguard Worker x = get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone() 135*da0073e9SAndroid Build Coastguard Worker offsets2 = x.offsets().clone().detach() 136*da0073e9SAndroid Build Coastguard Worker nt_view = nested_view_from_values_offsets(x.values(), offsets2).values() 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker yield mk_subclass_dense_subclass_dense, "subclass_dense_subclass_dense" 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker 141*da0073e9SAndroid Build Coastguard WorkerVIEW_TEST_CASES = {k: v for v, k in get_view_test_cases()} 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Workerrequires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Workercompile_full_eager = torch.compile(backend="eager", fullgraph=True) 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker 149*da0073e9SAndroid Build Coastguard Workerclass BaseTorchFunction(torch.Tensor): 150*da0073e9SAndroid Build Coastguard Worker @classmethod 151*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 152*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 153*da0073e9SAndroid Build Coastguard Worker kwargs = {} 154*da0073e9SAndroid Build Coastguard Worker return super().__torch_function__(func, types, args, kwargs) 155*da0073e9SAndroid Build Coastguard Worker 156*da0073e9SAndroid Build Coastguard Worker 157*da0073e9SAndroid Build Coastguard Workerclass MockSubclass(torch.Tensor): 158*da0073e9SAndroid Build Coastguard Worker @classmethod 159*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 160*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 161*da0073e9SAndroid Build Coastguard Worker kwargs = {} 162*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Worker 165*da0073e9SAndroid Build Coastguard Workerclass AttrSubclass(torch.Tensor): 166*da0073e9SAndroid Build Coastguard Worker x: int = 10 167*da0073e9SAndroid Build Coastguard Worker size: int = 10 168*da0073e9SAndroid Build Coastguard Worker 169*da0073e9SAndroid Build Coastguard Worker @classmethod 170*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 171*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 172*da0073e9SAndroid Build Coastguard Worker kwargs = {} 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard Worker 177*da0073e9SAndroid Build Coastguard Workerclass DummyNDim(torch.Tensor): 178*da0073e9SAndroid Build Coastguard Worker @classmethod 179*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 180*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 181*da0073e9SAndroid Build Coastguard Worker kwargs = {} 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Worker if func == torch.Tensor.ndim.__get__: 184*da0073e9SAndroid Build Coastguard Worker return 10 185*da0073e9SAndroid Build Coastguard Worker 186*da0073e9SAndroid Build Coastguard Worker return super().__torch_function__(func, types, args, kwargs) 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker 189*da0073e9SAndroid Build Coastguard Workerclass WrapperSubclass: 190*da0073e9SAndroid Build Coastguard Worker def __init__(self, tensor): 191*da0073e9SAndroid Build Coastguard Worker self.tensor = tensor 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker @classmethod 194*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 195*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 196*da0073e9SAndroid Build Coastguard Worker kwargs = {} 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard Worker args = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, args) 199*da0073e9SAndroid Build Coastguard Worker kwargs = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, kwargs) 200*da0073e9SAndroid Build Coastguard Worker 201*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 202*da0073e9SAndroid Build Coastguard Worker 203*da0073e9SAndroid Build Coastguard Worker 204*da0073e9SAndroid Build Coastguard Workerclass SigmoidToExpSubclass(torch.Tensor): 205*da0073e9SAndroid Build Coastguard Worker @classmethod 206*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 207*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 208*da0073e9SAndroid Build Coastguard Worker kwargs = {} 209*da0073e9SAndroid Build Coastguard Worker 210*da0073e9SAndroid Build Coastguard Worker if func == torch.Tensor.sigmoid: 211*da0073e9SAndroid Build Coastguard Worker return super().__torch_function__(torch.Tensor.exp, types, args, kwargs) 212*da0073e9SAndroid Build Coastguard Worker 213*da0073e9SAndroid Build Coastguard Worker return super().__torch_function__(func, types, args, kwargs) 214*da0073e9SAndroid Build Coastguard Worker 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker# Wrapper subclass with two inner tensors: data and scale 217*da0073e9SAndroid Build Coastguard Worker# data has same shape as outer, and scale has single dim size 218*da0073e9SAndroid Build Coastguard Workerclass ScaledTensor(torch.Tensor): 219*da0073e9SAndroid Build Coastguard Worker def __new__( 220*da0073e9SAndroid Build Coastguard Worker cls, 221*da0073e9SAndroid Build Coastguard Worker data: torch.Tensor, 222*da0073e9SAndroid Build Coastguard Worker scale: torch.Tensor, 223*da0073e9SAndroid Build Coastguard Worker *, 224*da0073e9SAndroid Build Coastguard Worker constant: int = 0, 225*da0073e9SAndroid Build Coastguard Worker ): 226*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_wrapper_subclass( 227*da0073e9SAndroid Build Coastguard Worker cls, 228*da0073e9SAndroid Build Coastguard Worker data.size(), 229*da0073e9SAndroid Build Coastguard Worker strides=data.stride(), 230*da0073e9SAndroid Build Coastguard Worker storage_offset=data.storage_offset(), 231*da0073e9SAndroid Build Coastguard Worker dtype=data.dtype, 232*da0073e9SAndroid Build Coastguard Worker layout=data.layout, 233*da0073e9SAndroid Build Coastguard Worker requires_grad=data.requires_grad, 234*da0073e9SAndroid Build Coastguard Worker device=data.device, 235*da0073e9SAndroid Build Coastguard Worker ) 236*da0073e9SAndroid Build Coastguard Worker 237*da0073e9SAndroid Build Coastguard Worker def __init__(self, data: torch.Tensor, scale: torch.Tensor, constant: int = 0): 238*da0073e9SAndroid Build Coastguard Worker self._data = data 239*da0073e9SAndroid Build Coastguard Worker self._scale = scale 240*da0073e9SAndroid Build Coastguard Worker self._constant = constant 241*da0073e9SAndroid Build Coastguard Worker 242*da0073e9SAndroid Build Coastguard Worker def __tensor_flatten__(self): 243*da0073e9SAndroid Build Coastguard Worker ctx = {"_constant": self._constant} 244*da0073e9SAndroid Build Coastguard Worker return ["_data", "_scale"], ctx 245*da0073e9SAndroid Build Coastguard Worker 246*da0073e9SAndroid Build Coastguard Worker @staticmethod 247*da0073e9SAndroid Build Coastguard Worker def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): 248*da0073e9SAndroid Build Coastguard Worker assert len(inner_tensors) == 2 249*da0073e9SAndroid Build Coastguard Worker return ScaledTensor( 250*da0073e9SAndroid Build Coastguard Worker inner_tensors["_data"], 251*da0073e9SAndroid Build Coastguard Worker inner_tensors["_scale"], 252*da0073e9SAndroid Build Coastguard Worker constant=metadata["_constant"], 253*da0073e9SAndroid Build Coastguard Worker ) 254*da0073e9SAndroid Build Coastguard Worker 255*da0073e9SAndroid Build Coastguard Worker @classmethod 256*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs=None): 257*da0073e9SAndroid Build Coastguard Worker scaled_tensor = args[0] 258*da0073e9SAndroid Build Coastguard Worker out = func(scaled_tensor._data, *args[1:], **kwargs) 259*da0073e9SAndroid Build Coastguard Worker return ScaledTensor(out, scaled_tensor._scale, constant=scaled_tensor._constant) 260*da0073e9SAndroid Build Coastguard Worker 261*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 262*da0073e9SAndroid Build Coastguard Worker return f"{self._data.__repr__()}\n{self._scale.__repr__()}" 263*da0073e9SAndroid Build Coastguard Worker 264*da0073e9SAndroid Build Coastguard Worker 265*da0073e9SAndroid Build Coastguard Workerclass OptionalScaledTensor(torch.Tensor): 266*da0073e9SAndroid Build Coastguard Worker def __new__( 267*da0073e9SAndroid Build Coastguard Worker cls, 268*da0073e9SAndroid Build Coastguard Worker data, 269*da0073e9SAndroid Build Coastguard Worker scale, 270*da0073e9SAndroid Build Coastguard Worker *, 271*da0073e9SAndroid Build Coastguard Worker constant: int = 0, 272*da0073e9SAndroid Build Coastguard Worker ): 273*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_wrapper_subclass( 274*da0073e9SAndroid Build Coastguard Worker cls, 275*da0073e9SAndroid Build Coastguard Worker data.size(), 276*da0073e9SAndroid Build Coastguard Worker strides=data.stride(), 277*da0073e9SAndroid Build Coastguard Worker storage_offset=data.storage_offset(), 278*da0073e9SAndroid Build Coastguard Worker dtype=data.dtype, 279*da0073e9SAndroid Build Coastguard Worker layout=data.layout, 280*da0073e9SAndroid Build Coastguard Worker requires_grad=data.requires_grad, 281*da0073e9SAndroid Build Coastguard Worker device=data.device, 282*da0073e9SAndroid Build Coastguard Worker ) 283*da0073e9SAndroid Build Coastguard Worker 284*da0073e9SAndroid Build Coastguard Worker def __init__(self, data: torch.Tensor, scale, constant: int = 0): 285*da0073e9SAndroid Build Coastguard Worker self._data = data 286*da0073e9SAndroid Build Coastguard Worker self._scale = scale 287*da0073e9SAndroid Build Coastguard Worker self._constant = constant 288*da0073e9SAndroid Build Coastguard Worker 289*da0073e9SAndroid Build Coastguard Worker def __tensor_flatten__(self): 290*da0073e9SAndroid Build Coastguard Worker ctx = {"_constant": self._constant} 291*da0073e9SAndroid Build Coastguard Worker if self._scale is not None: 292*da0073e9SAndroid Build Coastguard Worker return ["_data", "_scale"], ctx 293*da0073e9SAndroid Build Coastguard Worker else: 294*da0073e9SAndroid Build Coastguard Worker return ["_data"], ctx 295*da0073e9SAndroid Build Coastguard Worker 296*da0073e9SAndroid Build Coastguard Worker @staticmethod 297*da0073e9SAndroid Build Coastguard Worker def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): 298*da0073e9SAndroid Build Coastguard Worker return OptionalScaledTensor( 299*da0073e9SAndroid Build Coastguard Worker inner_tensors["_data"], 300*da0073e9SAndroid Build Coastguard Worker inner_tensors["_scale"] if "_scale" in inner_tensors else None, 301*da0073e9SAndroid Build Coastguard Worker constant=metadata["_constant"], 302*da0073e9SAndroid Build Coastguard Worker ) 303*da0073e9SAndroid Build Coastguard Worker 304*da0073e9SAndroid Build Coastguard Worker @classmethod 305*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs=None): 306*da0073e9SAndroid Build Coastguard Worker scaled_tensor = args[0] 307*da0073e9SAndroid Build Coastguard Worker out = func(scaled_tensor._data, *args[1:], **kwargs) 308*da0073e9SAndroid Build Coastguard Worker if scaled_tensor._scale is not None: 309*da0073e9SAndroid Build Coastguard Worker out = out * scaled_tensor._scale 310*da0073e9SAndroid Build Coastguard Worker return OptionalScaledTensor( 311*da0073e9SAndroid Build Coastguard Worker out, scaled_tensor._scale, constant=scaled_tensor._constant 312*da0073e9SAndroid Build Coastguard Worker ) 313*da0073e9SAndroid Build Coastguard Worker 314*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 315*da0073e9SAndroid Build Coastguard Worker return ( 316*da0073e9SAndroid Build Coastguard Worker f"OptionalScaledTensor({self._data.__repr__()}\n{self._scale.__repr__()})" 317*da0073e9SAndroid Build Coastguard Worker ) 318*da0073e9SAndroid Build Coastguard Worker 319*da0073e9SAndroid Build Coastguard Worker 320*da0073e9SAndroid Build Coastguard Workerclass CtxSubclassTensor(torch.Tensor): 321*da0073e9SAndroid Build Coastguard Worker """ 322*da0073e9SAndroid Build Coastguard Worker Class used to verify guarding on the subclass metadata 323*da0073e9SAndroid Build Coastguard Worker """ 324*da0073e9SAndroid Build Coastguard Worker 325*da0073e9SAndroid Build Coastguard Worker @staticmethod 326*da0073e9SAndroid Build Coastguard Worker def __new__(cls, a, constant): 327*da0073e9SAndroid Build Coastguard Worker shape = a.shape 328*da0073e9SAndroid Build Coastguard Worker kwargs = {} 329*da0073e9SAndroid Build Coastguard Worker kwargs["strides"] = a.stride() 330*da0073e9SAndroid Build Coastguard Worker kwargs["storage_offset"] = a.storage_offset() 331*da0073e9SAndroid Build Coastguard Worker kwargs["device"] = a.device 332*da0073e9SAndroid Build Coastguard Worker kwargs["layout"] = a.layout 333*da0073e9SAndroid Build Coastguard Worker kwargs["requires_grad"] = a.requires_grad 334*da0073e9SAndroid Build Coastguard Worker kwargs["dtype"] = a.dtype 335*da0073e9SAndroid Build Coastguard Worker out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) 336*da0073e9SAndroid Build Coastguard Worker return out 337*da0073e9SAndroid Build Coastguard Worker 338*da0073e9SAndroid Build Coastguard Worker def __init__(self, a, constant): 339*da0073e9SAndroid Build Coastguard Worker self.a = a 340*da0073e9SAndroid Build Coastguard Worker self.constant = constant 341*da0073e9SAndroid Build Coastguard Worker 342*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 343*da0073e9SAndroid Build Coastguard Worker a_repr = repr(self.a) 344*da0073e9SAndroid Build Coastguard Worker return f"CtxSubclassTensor({a_repr})" 345*da0073e9SAndroid Build Coastguard Worker 346*da0073e9SAndroid Build Coastguard Worker def __tensor_flatten__(self): 347*da0073e9SAndroid Build Coastguard Worker return ["a"], (self.constant,) 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker @staticmethod 350*da0073e9SAndroid Build Coastguard Worker def __tensor_unflatten__(inner_tensors, meta, sizes, strides): 351*da0073e9SAndroid Build Coastguard Worker constant = meta[0] 352*da0073e9SAndroid Build Coastguard Worker a = inner_tensors["a"] 353*da0073e9SAndroid Build Coastguard Worker return CtxSubclassTensor(a, constant) 354*da0073e9SAndroid Build Coastguard Worker 355*da0073e9SAndroid Build Coastguard Worker @classmethod 356*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 357*da0073e9SAndroid Build Coastguard Worker from torch.utils._python_dispatch import return_and_correct_aliasing 358*da0073e9SAndroid Build Coastguard Worker 359*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 360*da0073e9SAndroid Build Coastguard Worker kwargs = {} 361*da0073e9SAndroid Build Coastguard Worker biggest_constant = max( 362*da0073e9SAndroid Build Coastguard Worker [ 363*da0073e9SAndroid Build Coastguard Worker x.constant 364*da0073e9SAndroid Build Coastguard Worker for x in pytree.tree_flatten(args)[0] 365*da0073e9SAndroid Build Coastguard Worker if isinstance(x, CtxSubclassTensor) 366*da0073e9SAndroid Build Coastguard Worker ] 367*da0073e9SAndroid Build Coastguard Worker ) 368*da0073e9SAndroid Build Coastguard Worker args_a = pytree.tree_map( 369*da0073e9SAndroid Build Coastguard Worker lambda x: x.a if isinstance(x, CtxSubclassTensor) else x, args 370*da0073e9SAndroid Build Coastguard Worker ) 371*da0073e9SAndroid Build Coastguard Worker kwargs_a = pytree.tree_map( 372*da0073e9SAndroid Build Coastguard Worker lambda x: x.a if isinstance(x, CtxSubclassTensor) else x, kwargs 373*da0073e9SAndroid Build Coastguard Worker ) 374*da0073e9SAndroid Build Coastguard Worker out_a = func(*args_a, **kwargs_a) 375*da0073e9SAndroid Build Coastguard Worker out = pytree.tree_map( 376*da0073e9SAndroid Build Coastguard Worker lambda x: CtxSubclassTensor(x, biggest_constant) 377*da0073e9SAndroid Build Coastguard Worker if isinstance(x, torch.Tensor) 378*da0073e9SAndroid Build Coastguard Worker else x, 379*da0073e9SAndroid Build Coastguard Worker out_a, 380*da0073e9SAndroid Build Coastguard Worker ) 381*da0073e9SAndroid Build Coastguard Worker 382*da0073e9SAndroid Build Coastguard Worker if func == torch.ops.aten.mul.Tensor: 383*da0073e9SAndroid Build Coastguard Worker out = out + out.constant 384*da0073e9SAndroid Build Coastguard Worker 385*da0073e9SAndroid Build Coastguard Worker return return_and_correct_aliasing(func, args, kwargs, out) 386*da0073e9SAndroid Build Coastguard Worker 387*da0073e9SAndroid Build Coastguard Worker 388*da0073e9SAndroid Build Coastguard Workerdef func(a): 389*da0073e9SAndroid Build Coastguard Worker return a.sin() 390*da0073e9SAndroid Build Coastguard Worker 391*da0073e9SAndroid Build Coastguard Worker 392*da0073e9SAndroid Build Coastguard Workerclass EagerRecordGraphAndInputs: 393*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 394*da0073e9SAndroid Build Coastguard Worker self.graphs = [] 395*da0073e9SAndroid Build Coastguard Worker self.example_inputs = [] 396*da0073e9SAndroid Build Coastguard Worker 397*da0073e9SAndroid Build Coastguard Worker def __call__(self, gm: torch.fx.GraphModule, example_inputs): 398*da0073e9SAndroid Build Coastguard Worker self.graphs.append(gm) 399*da0073e9SAndroid Build Coastguard Worker self.example_inputs.append(example_inputs) 400*da0073e9SAndroid Build Coastguard Worker return gm 401*da0073e9SAndroid Build Coastguard Worker 402*da0073e9SAndroid Build Coastguard Worker 403*da0073e9SAndroid Build Coastguard WorkerGLOBAL_TEST_SUBCLASSES = { 404*da0073e9SAndroid Build Coastguard Worker MockSubclass, 405*da0073e9SAndroid Build Coastguard Worker DummyNDim, 406*da0073e9SAndroid Build Coastguard Worker SigmoidToExpSubclass, 407*da0073e9SAndroid Build Coastguard Worker BaseTorchFunction, 408*da0073e9SAndroid Build Coastguard Worker} 409*da0073e9SAndroid Build Coastguard Worker 410*da0073e9SAndroid Build Coastguard Worker 411*da0073e9SAndroid Build Coastguard Worker# Returns True if the function recompiles between inputs1 and inputs2 with the 412*da0073e9SAndroid Build Coastguard Worker# specified dynamic setting. 413*da0073e9SAndroid Build Coastguard Workerdef _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True): 414*da0073e9SAndroid Build Coastguard Worker compile_count = [0] 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Worker def counter(gm, example_inputs): 417*da0073e9SAndroid Build Coastguard Worker compile_count[0] += 1 418*da0073e9SAndroid Build Coastguard Worker return gm 419*da0073e9SAndroid Build Coastguard Worker 420*da0073e9SAndroid Build Coastguard Worker compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic) 421*da0073e9SAndroid Build Coastguard Worker compiled_f(*inputs1) 422*da0073e9SAndroid Build Coastguard Worker compiled_f(*inputs2) 423*da0073e9SAndroid Build Coastguard Worker return compile_count[0] > 1 424*da0073e9SAndroid Build Coastguard Worker 425*da0073e9SAndroid Build Coastguard Worker 426*da0073e9SAndroid Build Coastguard Workerclass SubclassTests(torch._dynamo.test_case.TestCase): 427*da0073e9SAndroid Build Coastguard Worker @classmethod 428*da0073e9SAndroid Build Coastguard Worker def setUpClass(cls): 429*da0073e9SAndroid Build Coastguard Worker super().setUpClass() 430*da0073e9SAndroid Build Coastguard Worker cls._exit_stack.enter_context( 431*da0073e9SAndroid Build Coastguard Worker torch._dynamo.config.patch( 432*da0073e9SAndroid Build Coastguard Worker "traceable_tensor_subclasses", GLOBAL_TEST_SUBCLASSES 433*da0073e9SAndroid Build Coastguard Worker ) 434*da0073e9SAndroid Build Coastguard Worker ) 435*da0073e9SAndroid Build Coastguard Worker 436*da0073e9SAndroid Build Coastguard Worker @classmethod 437*da0073e9SAndroid Build Coastguard Worker def tearDownClass(cls): 438*da0073e9SAndroid Build Coastguard Worker cls._exit_stack.close() 439*da0073e9SAndroid Build Coastguard Worker 440*da0073e9SAndroid Build Coastguard Worker def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles): 441*da0073e9SAndroid Build Coastguard Worker _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles) 442*da0073e9SAndroid Build Coastguard Worker 443*da0073e9SAndroid Build Coastguard Worker def test_no_call_to_new(self): 444*da0073e9SAndroid Build Coastguard Worker class BadNewTorchFunction(torch.Tensor): 445*da0073e9SAndroid Build Coastguard Worker def __new__(cls, *args, **kwargs): 446*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Oops!") 447*da0073e9SAndroid Build Coastguard Worker 448*da0073e9SAndroid Build Coastguard Worker @classmethod 449*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 450*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 451*da0073e9SAndroid Build Coastguard Worker kwargs = {} 452*da0073e9SAndroid Build Coastguard Worker return super().__torch_function__(func, types, args, kwargs) 453*da0073e9SAndroid Build Coastguard Worker 454*da0073e9SAndroid Build Coastguard Worker with torch._dynamo.config.patch( 455*da0073e9SAndroid Build Coastguard Worker "traceable_tensor_subclasses", {BadNewTorchFunction} 456*da0073e9SAndroid Build Coastguard Worker ): 457*da0073e9SAndroid Build Coastguard Worker 458*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 459*da0073e9SAndroid Build Coastguard Worker def fn(x): 460*da0073e9SAndroid Build Coastguard Worker return torch.add(x, 1) 461*da0073e9SAndroid Build Coastguard Worker 462*da0073e9SAndroid Build Coastguard Worker input = torch.ones(2, 2).as_subclass(BadNewTorchFunction) 463*da0073e9SAndroid Build Coastguard Worker 464*da0073e9SAndroid Build Coastguard Worker res = fn(input) 465*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(res, BadNewTorchFunction) 466*da0073e9SAndroid Build Coastguard Worker 467*da0073e9SAndroid Build Coastguard Worker def test_no_torch_function_recompiles(self): 468*da0073e9SAndroid Build Coastguard Worker class NJT: 469*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 470*da0073e9SAndroid Build Coastguard Worker return f"NJT(shape={self.shape})" 471*da0073e9SAndroid Build Coastguard Worker 472*da0073e9SAndroid Build Coastguard Worker def __init__(self, values, offsets): 473*da0073e9SAndroid Build Coastguard Worker self._values = values 474*da0073e9SAndroid Build Coastguard Worker self._offsets = offsets 475*da0073e9SAndroid Build Coastguard Worker 476*da0073e9SAndroid Build Coastguard Worker def sin(self): 477*da0073e9SAndroid Build Coastguard Worker return torch.sin(self) 478*da0073e9SAndroid Build Coastguard Worker 479*da0073e9SAndroid Build Coastguard Worker @classmethod 480*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 481*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 482*da0073e9SAndroid Build Coastguard Worker kwargs = {} 483*da0073e9SAndroid Build Coastguard Worker if func == torch.sin: 484*da0073e9SAndroid Build Coastguard Worker self = args[0] 485*da0073e9SAndroid Build Coastguard Worker return NJT(func(self._values), self._offsets) 486*da0073e9SAndroid Build Coastguard Worker raise AssertionError("should not get here") 487*da0073e9SAndroid Build Coastguard Worker 488*da0073e9SAndroid Build Coastguard Worker values1 = torch.randn(10, 3, 4, requires_grad=True) 489*da0073e9SAndroid Build Coastguard Worker values2 = torch.randn(10, 3, 4, requires_grad=True) 490*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 3, 10]) 491*da0073e9SAndroid Build Coastguard Worker njt1 = NJT(values1, offsets) 492*da0073e9SAndroid Build Coastguard Worker njt2 = NJT(values2, offsets) 493*da0073e9SAndroid Build Coastguard Worker 494*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 495*da0073e9SAndroid Build Coastguard Worker def f(x): 496*da0073e9SAndroid Build Coastguard Worker return torch.sin(x) 497*da0073e9SAndroid Build Coastguard Worker 498*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 499*da0073e9SAndroid Build Coastguard Worker f(njt1) 500*da0073e9SAndroid Build Coastguard Worker f(njt2) 501*da0073e9SAndroid Build Coastguard Worker 502*da0073e9SAndroid Build Coastguard Worker def test_base_torch_function_tracing(self): 503*da0073e9SAndroid Build Coastguard Worker def fn(x): 504*da0073e9SAndroid Build Coastguard Worker return torch.add(x, 1) 505*da0073e9SAndroid Build Coastguard Worker 506*da0073e9SAndroid Build Coastguard Worker input = torch.ones(2, 2).as_subclass(BaseTorchFunction) 507*da0073e9SAndroid Build Coastguard Worker out = fn(input) 508*da0073e9SAndroid Build Coastguard Worker out_opt = compile_full_eager(fn)(input) 509*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(out, BaseTorchFunction) 510*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, out_opt) 511*da0073e9SAndroid Build Coastguard Worker 512*da0073e9SAndroid Build Coastguard Worker def test_torch_function_state_graph_break(self): 513*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 514*da0073e9SAndroid Build Coastguard Worker def fn(x): 515*da0073e9SAndroid Build Coastguard Worker with torch._C.DisableTorchFunctionSubclass(): 516*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 517*da0073e9SAndroid Build Coastguard Worker return torch._C._is_torch_function_enabled(), torch.add(x, 1.0) 518*da0073e9SAndroid Build Coastguard Worker 519*da0073e9SAndroid Build Coastguard Worker input = torch.ones(2, 2) 520*da0073e9SAndroid Build Coastguard Worker res, _ = fn(input) 521*da0073e9SAndroid Build Coastguard Worker self.assertFalse(res) 522*da0073e9SAndroid Build Coastguard Worker 523*da0073e9SAndroid Build Coastguard Worker def test_torch_function_state_nested(self): 524*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 525*da0073e9SAndroid Build Coastguard Worker def fn(x): 526*da0073e9SAndroid Build Coastguard Worker with torch._C.DisableTorchFunctionSubclass(): 527*da0073e9SAndroid Build Coastguard Worker with torch._C.DisableTorchFunctionSubclass(): 528*da0073e9SAndroid Build Coastguard Worker x = x + 1 529*da0073e9SAndroid Build Coastguard Worker # Should reset to the outer state (disabled) after exiting ctx manager 530*da0073e9SAndroid Build Coastguard Worker return torch._C._is_torch_function_enabled(), torch.add(x, 1.0) 531*da0073e9SAndroid Build Coastguard Worker 532*da0073e9SAndroid Build Coastguard Worker input = torch.ones(2, 2) 533*da0073e9SAndroid Build Coastguard Worker res, _ = fn(input) 534*da0073e9SAndroid Build Coastguard Worker self.assertFalse(res) 535*da0073e9SAndroid Build Coastguard Worker 536*da0073e9SAndroid Build Coastguard Worker def test_torch_function_state_tracing(self): 537*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 538*da0073e9SAndroid Build Coastguard Worker def fn(x): 539*da0073e9SAndroid Build Coastguard Worker with torch._C.DisableTorchFunctionSubclass(): 540*da0073e9SAndroid Build Coastguard Worker torch.add(x, 1.0) 541*da0073e9SAndroid Build Coastguard Worker 542*da0073e9SAndroid Build Coastguard Worker input = torch.ones(2, 2) 543*da0073e9SAndroid Build Coastguard Worker 544*da0073e9SAndroid Build Coastguard Worker res = fn(input) 545*da0073e9SAndroid Build Coastguard Worker 546*da0073e9SAndroid Build Coastguard Worker def test_torch_function_state_guards(self): 547*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 548*da0073e9SAndroid Build Coastguard Worker 549*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt, fullgraph=True) 550*da0073e9SAndroid Build Coastguard Worker def fn(x): 551*da0073e9SAndroid Build Coastguard Worker torch.add(x, 1.0) 552*da0073e9SAndroid Build Coastguard Worker 553*da0073e9SAndroid Build Coastguard Worker input = torch.ones(2, 2) 554*da0073e9SAndroid Build Coastguard Worker 555*da0073e9SAndroid Build Coastguard Worker with torch._C.DisableTorchFunctionSubclass(): 556*da0073e9SAndroid Build Coastguard Worker res = fn(input) 557*da0073e9SAndroid Build Coastguard Worker 558*da0073e9SAndroid Build Coastguard Worker res = fn(input) 559*da0073e9SAndroid Build Coastguard Worker 560*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 2) 561*da0073e9SAndroid Build Coastguard Worker 562*da0073e9SAndroid Build Coastguard Worker def test_return_subclass(self): 563*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 564*da0073e9SAndroid Build Coastguard Worker def fn(x): 565*da0073e9SAndroid Build Coastguard Worker return MockSubclass(torch.add(x, 1.0)) 566*da0073e9SAndroid Build Coastguard Worker 567*da0073e9SAndroid Build Coastguard Worker input = torch.ones(2, 2) 568*da0073e9SAndroid Build Coastguard Worker 569*da0073e9SAndroid Build Coastguard Worker res = fn(input) 570*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(res, MockSubclass) 571*da0073e9SAndroid Build Coastguard Worker 572*da0073e9SAndroid Build Coastguard Worker def test_return_as_subclass(self): 573*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 574*da0073e9SAndroid Build Coastguard Worker def fn(x): 575*da0073e9SAndroid Build Coastguard Worker return torch.add(x, 1.0).as_subclass(MockSubclass) 576*da0073e9SAndroid Build Coastguard Worker 577*da0073e9SAndroid Build Coastguard Worker input = torch.ones(2, 2) 578*da0073e9SAndroid Build Coastguard Worker 579*da0073e9SAndroid Build Coastguard Worker res = fn(input) 580*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(res, MockSubclass) 581*da0073e9SAndroid Build Coastguard Worker 582*da0073e9SAndroid Build Coastguard Worker def test_return_local_subclass(self): 583*da0073e9SAndroid Build Coastguard Worker class LocalSubclass(torch.Tensor): 584*da0073e9SAndroid Build Coastguard Worker @classmethod 585*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 586*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 587*da0073e9SAndroid Build Coastguard Worker kwargs = {} 588*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 589*da0073e9SAndroid Build Coastguard Worker 590*da0073e9SAndroid Build Coastguard Worker with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): 591*da0073e9SAndroid Build Coastguard Worker 592*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 593*da0073e9SAndroid Build Coastguard Worker def fn(x): 594*da0073e9SAndroid Build Coastguard Worker return LocalSubclass(torch.add(x, 1.0)) 595*da0073e9SAndroid Build Coastguard Worker 596*da0073e9SAndroid Build Coastguard Worker input = torch.ones(2, 2) 597*da0073e9SAndroid Build Coastguard Worker 598*da0073e9SAndroid Build Coastguard Worker res = fn(input) 599*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(res, LocalSubclass) 600*da0073e9SAndroid Build Coastguard Worker 601*da0073e9SAndroid Build Coastguard Worker def test_torch_function_list_args(self): 602*da0073e9SAndroid Build Coastguard Worker HANDLED_FUNCTIONS = {} 603*da0073e9SAndroid Build Coastguard Worker 604*da0073e9SAndroid Build Coastguard Worker class MyClass: 605*da0073e9SAndroid Build Coastguard Worker def __init__(self, foo): 606*da0073e9SAndroid Build Coastguard Worker self.foo = foo 607*da0073e9SAndroid Build Coastguard Worker 608*da0073e9SAndroid Build Coastguard Worker @classmethod 609*da0073e9SAndroid Build Coastguard Worker def __torch_function__( 610*da0073e9SAndroid Build Coastguard Worker cls, 611*da0073e9SAndroid Build Coastguard Worker func, 612*da0073e9SAndroid Build Coastguard Worker types, 613*da0073e9SAndroid Build Coastguard Worker args=(), 614*da0073e9SAndroid Build Coastguard Worker kwargs=None, 615*da0073e9SAndroid Build Coastguard Worker ): 616*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 617*da0073e9SAndroid Build Coastguard Worker kwargs = {} 618*da0073e9SAndroid Build Coastguard Worker if func not in HANDLED_FUNCTIONS or not all( # noqa: C419 619*da0073e9SAndroid Build Coastguard Worker [ # noqa: C419 620*da0073e9SAndroid Build Coastguard Worker issubclass(t, (torch.Tensor, MyClass)) for t in types 621*da0073e9SAndroid Build Coastguard Worker ] 622*da0073e9SAndroid Build Coastguard Worker ): 623*da0073e9SAndroid Build Coastguard Worker return NotImplemented 624*da0073e9SAndroid Build Coastguard Worker return HANDLED_FUNCTIONS[func](*args, **kwargs) 625*da0073e9SAndroid Build Coastguard Worker 626*da0073e9SAndroid Build Coastguard Worker def _stack(input, dim=0, *, out=None): 627*da0073e9SAndroid Build Coastguard Worker return MyClass(sum([x.foo for x in input])) 628*da0073e9SAndroid Build Coastguard Worker 629*da0073e9SAndroid Build Coastguard Worker HANDLED_FUNCTIONS[torch.stack] = _stack 630*da0073e9SAndroid Build Coastguard Worker 631*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 632*da0073e9SAndroid Build Coastguard Worker def fn(v0, v1): 633*da0073e9SAndroid Build Coastguard Worker return torch.stack([v0, v1]) 634*da0073e9SAndroid Build Coastguard Worker 635*da0073e9SAndroid Build Coastguard Worker ret = fn(MyClass(1), MyClass(1)) 636*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ret.foo, 2) 637*da0073e9SAndroid Build Coastguard Worker 638*da0073e9SAndroid Build Coastguard Worker @parametrize( 639*da0073e9SAndroid Build Coastguard Worker "comparison", 640*da0073e9SAndroid Build Coastguard Worker [ 641*da0073e9SAndroid Build Coastguard Worker subtest(isinstance, "isinstance"), 642*da0073e9SAndroid Build Coastguard Worker subtest(lambda instance, type_: type(instance) == type_, "equality"), 643*da0073e9SAndroid Build Coastguard Worker subtest(lambda instance, type_: type(instance) is type_, "identity"), 644*da0073e9SAndroid Build Coastguard Worker ], 645*da0073e9SAndroid Build Coastguard Worker ) 646*da0073e9SAndroid Build Coastguard Worker @parametrize( 647*da0073e9SAndroid Build Coastguard Worker "input_type", 648*da0073e9SAndroid Build Coastguard Worker [ 649*da0073e9SAndroid Build Coastguard Worker subtest(torch.Tensor, "tensor"), 650*da0073e9SAndroid Build Coastguard Worker subtest(DummyNDim, "subclass"), 651*da0073e9SAndroid Build Coastguard Worker ], 652*da0073e9SAndroid Build Coastguard Worker ) 653*da0073e9SAndroid Build Coastguard Worker def test_type_check(self, comparison, input_type): 654*da0073e9SAndroid Build Coastguard Worker with torch._dynamo.config.patch("traceable_tensor_subclasses", {DummyNDim}): 655*da0073e9SAndroid Build Coastguard Worker 656*da0073e9SAndroid Build Coastguard Worker def fn(x): 657*da0073e9SAndroid Build Coastguard Worker if comparison(x, DummyNDim): 658*da0073e9SAndroid Build Coastguard Worker return torch.ones(1, 1) 659*da0073e9SAndroid Build Coastguard Worker else: 660*da0073e9SAndroid Build Coastguard Worker return torch.zeros(2, 2) 661*da0073e9SAndroid Build Coastguard Worker 662*da0073e9SAndroid Build Coastguard Worker input = torch.ones(2, 2).as_subclass(input_type) 663*da0073e9SAndroid Build Coastguard Worker exp_res = fn(input) 664*da0073e9SAndroid Build Coastguard Worker act_res = torch.compile(backend="eager", fullgraph=True)(fn)(input) 665*da0073e9SAndroid Build Coastguard Worker self.assertEqual(exp_res, act_res) 666*da0073e9SAndroid Build Coastguard Worker 667*da0073e9SAndroid Build Coastguard Worker def test_torch_function_call_on_method(self): 668*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2) 669*da0073e9SAndroid Build Coastguard Worker y = torch.ones(2, 2) 670*da0073e9SAndroid Build Coastguard Worker z = torch.ones(2, 2) 671*da0073e9SAndroid Build Coastguard Worker wrapped = x.as_subclass(SigmoidToExpSubclass) 672*da0073e9SAndroid Build Coastguard Worker wrapped2 = y.as_subclass(SigmoidToExpSubclass) 673*da0073e9SAndroid Build Coastguard Worker 674*da0073e9SAndroid Build Coastguard Worker def fn(w): 675*da0073e9SAndroid Build Coastguard Worker return w.sigmoid() 676*da0073e9SAndroid Build Coastguard Worker 677*da0073e9SAndroid Build Coastguard Worker fn_opt = compile_full_eager(fn) 678*da0073e9SAndroid Build Coastguard Worker 679*da0073e9SAndroid Build Coastguard Worker res_exp = fn(wrapped) 680*da0073e9SAndroid Build Coastguard Worker res_act = fn_opt(wrapped2) 681*da0073e9SAndroid Build Coastguard Worker res_exp2 = z.exp() 682*da0073e9SAndroid Build Coastguard Worker 683*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_exp, res_act) 684*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_exp, res_exp2) 685*da0073e9SAndroid Build Coastguard Worker 686*da0073e9SAndroid Build Coastguard Worker def test_user_overidden_method_unsupported(self): 687*da0073e9SAndroid Build Coastguard Worker class LocalSubclass(torch.Tensor): 688*da0073e9SAndroid Build Coastguard Worker @classmethod 689*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 690*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 691*da0073e9SAndroid Build Coastguard Worker kwargs = {} 692*da0073e9SAndroid Build Coastguard Worker return super().__torch_function__(func, types, args, kwargs) 693*da0073e9SAndroid Build Coastguard Worker 694*da0073e9SAndroid Build Coastguard Worker def sigmoid(self): 695*da0073e9SAndroid Build Coastguard Worker return None 696*da0073e9SAndroid Build Coastguard Worker 697*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 698*da0073e9SAndroid Build Coastguard Worker def fn(x): 699*da0073e9SAndroid Build Coastguard Worker x.sigmoid() 700*da0073e9SAndroid Build Coastguard Worker 701*da0073e9SAndroid Build Coastguard Worker msg = ( 702*da0073e9SAndroid Build Coastguard Worker "Accessing overridden method/attribute sigmoid on a tensor" 703*da0073e9SAndroid Build Coastguard Worker " subclass with a __torch_function__ override is not supported" 704*da0073e9SAndroid Build Coastguard Worker ) 705*da0073e9SAndroid Build Coastguard Worker with torch._dynamo.config.patch( 706*da0073e9SAndroid Build Coastguard Worker "traceable_tensor_subclasses", {LocalSubclass} 707*da0073e9SAndroid Build Coastguard Worker ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): 708*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2).as_subclass(LocalSubclass) 709*da0073e9SAndroid Build Coastguard Worker fn(x) 710*da0073e9SAndroid Build Coastguard Worker 711*da0073e9SAndroid Build Coastguard Worker def test_user_overidden_attr_unsupported(self): 712*da0073e9SAndroid Build Coastguard Worker class LocalSubclass(torch.Tensor): 713*da0073e9SAndroid Build Coastguard Worker @classmethod 714*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 715*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 716*da0073e9SAndroid Build Coastguard Worker kwargs = {} 717*da0073e9SAndroid Build Coastguard Worker return super().__torch_function__(func, types, args, kwargs) 718*da0073e9SAndroid Build Coastguard Worker 719*da0073e9SAndroid Build Coastguard Worker ndim = 10 720*da0073e9SAndroid Build Coastguard Worker 721*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 722*da0073e9SAndroid Build Coastguard Worker def fn(x): 723*da0073e9SAndroid Build Coastguard Worker return x.ndim 724*da0073e9SAndroid Build Coastguard Worker 725*da0073e9SAndroid Build Coastguard Worker msg = ( 726*da0073e9SAndroid Build Coastguard Worker "Accessing overridden method/attribute ndim on a tensor" 727*da0073e9SAndroid Build Coastguard Worker " subclass with a __torch_function__ override is not supported" 728*da0073e9SAndroid Build Coastguard Worker ) 729*da0073e9SAndroid Build Coastguard Worker with torch._dynamo.config.patch( 730*da0073e9SAndroid Build Coastguard Worker "traceable_tensor_subclasses", {LocalSubclass} 731*da0073e9SAndroid Build Coastguard Worker ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): 732*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2).as_subclass(LocalSubclass) 733*da0073e9SAndroid Build Coastguard Worker fn(x) 734*da0073e9SAndroid Build Coastguard Worker 735*da0073e9SAndroid Build Coastguard Worker def test_user_overidden_property_unsupported(self): 736*da0073e9SAndroid Build Coastguard Worker class LocalSubclass(torch.Tensor): 737*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 738*da0073e9SAndroid Build Coastguard Worker self._ndim = 10 739*da0073e9SAndroid Build Coastguard Worker 740*da0073e9SAndroid Build Coastguard Worker @classmethod 741*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 742*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 743*da0073e9SAndroid Build Coastguard Worker kwargs = {} 744*da0073e9SAndroid Build Coastguard Worker return super().__torch_function__(func, types, args, kwargs) 745*da0073e9SAndroid Build Coastguard Worker 746*da0073e9SAndroid Build Coastguard Worker @property 747*da0073e9SAndroid Build Coastguard Worker def ndim(self): 748*da0073e9SAndroid Build Coastguard Worker return self._ndim 749*da0073e9SAndroid Build Coastguard Worker 750*da0073e9SAndroid Build Coastguard Worker @ndim.setter 751*da0073e9SAndroid Build Coastguard Worker def ndim(self, value): 752*da0073e9SAndroid Build Coastguard Worker self._ndim = value 753*da0073e9SAndroid Build Coastguard Worker 754*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 755*da0073e9SAndroid Build Coastguard Worker def fn(x): 756*da0073e9SAndroid Build Coastguard Worker return x.ndim 757*da0073e9SAndroid Build Coastguard Worker 758*da0073e9SAndroid Build Coastguard Worker msg = ( 759*da0073e9SAndroid Build Coastguard Worker "Accessing overridden method/attribute ndim on a tensor" 760*da0073e9SAndroid Build Coastguard Worker " subclass with a __torch_function__ override is not supported" 761*da0073e9SAndroid Build Coastguard Worker ) 762*da0073e9SAndroid Build Coastguard Worker with torch._dynamo.config.patch( 763*da0073e9SAndroid Build Coastguard Worker "traceable_tensor_subclasses", {LocalSubclass} 764*da0073e9SAndroid Build Coastguard Worker ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): 765*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2).as_subclass(LocalSubclass) 766*da0073e9SAndroid Build Coastguard Worker fn(x) 767*da0073e9SAndroid Build Coastguard Worker 768*da0073e9SAndroid Build Coastguard Worker def test_overridden_method_guarding(self): 769*da0073e9SAndroid Build Coastguard Worker class LocalSubclass(torch.Tensor): 770*da0073e9SAndroid Build Coastguard Worker @classmethod 771*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 772*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 773*da0073e9SAndroid Build Coastguard Worker kwargs = {} 774*da0073e9SAndroid Build Coastguard Worker return super().__torch_function__(func, types, args, kwargs) 775*da0073e9SAndroid Build Coastguard Worker 776*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager") 777*da0073e9SAndroid Build Coastguard Worker def fn(x): 778*da0073e9SAndroid Build Coastguard Worker return x.sigmoid() 779*da0073e9SAndroid Build Coastguard Worker 780*da0073e9SAndroid Build Coastguard Worker with torch._dynamo.config.patch( 781*da0073e9SAndroid Build Coastguard Worker error_on_recompile=True, traceable_tensor_subclasses={LocalSubclass} 782*da0073e9SAndroid Build Coastguard Worker ): 783*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2).as_subclass(LocalSubclass) 784*da0073e9SAndroid Build Coastguard Worker fn(x) 785*da0073e9SAndroid Build Coastguard Worker fn(x) 786*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2).as_subclass(LocalSubclass) 787*da0073e9SAndroid Build Coastguard Worker fn(x) 788*da0073e9SAndroid Build Coastguard Worker 789*da0073e9SAndroid Build Coastguard Worker with torch._dynamo.config.patch( 790*da0073e9SAndroid Build Coastguard Worker traceable_tensor_subclasses={LocalSubclass} 791*da0073e9SAndroid Build Coastguard Worker ), self.assertRaisesRegex( 792*da0073e9SAndroid Build Coastguard Worker TypeError, 793*da0073e9SAndroid Build Coastguard Worker "'bool' object is not callable", 794*da0073e9SAndroid Build Coastguard Worker ): 795*da0073e9SAndroid Build Coastguard Worker LocalSubclass.sigmoid = False 796*da0073e9SAndroid Build Coastguard Worker fn(x) 797*da0073e9SAndroid Build Coastguard Worker 798*da0073e9SAndroid Build Coastguard Worker def test_torch_function_call_on_attr(self): 799*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2) 800*da0073e9SAndroid Build Coastguard Worker wrapped = x.as_subclass(DummyNDim) 801*da0073e9SAndroid Build Coastguard Worker 802*da0073e9SAndroid Build Coastguard Worker def fn(w): 803*da0073e9SAndroid Build Coastguard Worker return w.ndim + torch.ones(2) 804*da0073e9SAndroid Build Coastguard Worker 805*da0073e9SAndroid Build Coastguard Worker fn_opt = compile_full_eager(fn) 806*da0073e9SAndroid Build Coastguard Worker 807*da0073e9SAndroid Build Coastguard Worker res_exp = fn(wrapped) 808*da0073e9SAndroid Build Coastguard Worker res_act = fn_opt(wrapped) 809*da0073e9SAndroid Build Coastguard Worker 810*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_exp, res_act) 811*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_exp, torch.ones(2) + 10) 812*da0073e9SAndroid Build Coastguard Worker 813*da0073e9SAndroid Build Coastguard Worker def test_torch_function_wrapper_class(self): 814*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2) 815*da0073e9SAndroid Build Coastguard Worker wrapped = WrapperSubclass(x) 816*da0073e9SAndroid Build Coastguard Worker 817*da0073e9SAndroid Build Coastguard Worker def fn(w): 818*da0073e9SAndroid Build Coastguard Worker return torch.add(w, 1.0) 819*da0073e9SAndroid Build Coastguard Worker 820*da0073e9SAndroid Build Coastguard Worker fn_opt = compile_full_eager(fn) 821*da0073e9SAndroid Build Coastguard Worker 822*da0073e9SAndroid Build Coastguard Worker res_exp = fn(wrapped) 823*da0073e9SAndroid Build Coastguard Worker res_act = fn_opt(wrapped) 824*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_exp, res_act) 825*da0073e9SAndroid Build Coastguard Worker 826*da0073e9SAndroid Build Coastguard Worker def test_torch_function_wrapper_class_with_kwargs(self): 827*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2) 828*da0073e9SAndroid Build Coastguard Worker wrapped = WrapperSubclass(x) 829*da0073e9SAndroid Build Coastguard Worker 830*da0073e9SAndroid Build Coastguard Worker def fn(w): 831*da0073e9SAndroid Build Coastguard Worker return torch.add(w, 1.0, alpha=2.0) 832*da0073e9SAndroid Build Coastguard Worker 833*da0073e9SAndroid Build Coastguard Worker fn_opt = compile_full_eager(fn) 834*da0073e9SAndroid Build Coastguard Worker 835*da0073e9SAndroid Build Coastguard Worker res_exp = fn(wrapped) 836*da0073e9SAndroid Build Coastguard Worker res_act = fn_opt(wrapped) 837*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_exp, res_act) 838*da0073e9SAndroid Build Coastguard Worker 839*da0073e9SAndroid Build Coastguard Worker def test_tensor_subclass_custom_attr(self): 840*da0073e9SAndroid Build Coastguard Worker class AttrSubclass(torch.Tensor): 841*da0073e9SAndroid Build Coastguard Worker x: int = 10 842*da0073e9SAndroid Build Coastguard Worker 843*da0073e9SAndroid Build Coastguard Worker @classmethod 844*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 845*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 846*da0073e9SAndroid Build Coastguard Worker kwargs = {} 847*da0073e9SAndroid Build Coastguard Worker 848*da0073e9SAndroid Build Coastguard Worker return super().__torch_function__(func, types, args, kwargs) 849*da0073e9SAndroid Build Coastguard Worker 850*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 851*da0073e9SAndroid Build Coastguard Worker def fn(x): 852*da0073e9SAndroid Build Coastguard Worker return x.x + torch.ones(2, 2) 853*da0073e9SAndroid Build Coastguard Worker 854*da0073e9SAndroid Build Coastguard Worker with traceable_subclass(AttrSubclass): 855*da0073e9SAndroid Build Coastguard Worker input = torch.ones(2, 2).as_subclass(AttrSubclass) 856*da0073e9SAndroid Build Coastguard Worker fn_opt = compile_full_eager(fn) 857*da0073e9SAndroid Build Coastguard Worker 858*da0073e9SAndroid Build Coastguard Worker res_exp = fn(input) 859*da0073e9SAndroid Build Coastguard Worker res_act = fn_opt(input) 860*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_exp, res_act) 861*da0073e9SAndroid Build Coastguard Worker 862*da0073e9SAndroid Build Coastguard Worker def test_compile_with_fake_tensor_dynamic_dim(self): 863*da0073e9SAndroid Build Coastguard Worker x = torch.randn([3, 4]) 864*da0073e9SAndroid Build Coastguard Worker 865*da0073e9SAndroid Build Coastguard Worker def f(x): 866*da0073e9SAndroid Build Coastguard Worker return torch.sin(x) 867*da0073e9SAndroid Build Coastguard Worker 868*da0073e9SAndroid Build Coastguard Worker def test_dynamic_dim(f, x, dim_dynamic, exp_frame_count, exp_op_count): 869*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 870*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 871*da0073e9SAndroid Build Coastguard Worker 872*da0073e9SAndroid Build Coastguard Worker opt_f = torch.compile(f, backend=cnt, fullgraph=True) 873*da0073e9SAndroid Build Coastguard Worker 874*da0073e9SAndroid Build Coastguard Worker x1 = torch.rand_like(x) 875*da0073e9SAndroid Build Coastguard Worker f(x) 876*da0073e9SAndroid Build Coastguard Worker f(torch.randn([4, 3])) 877*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 878*da0073e9SAndroid Build Coastguard Worker with torch._subclasses.fake_tensor.FakeTensorMode( 879*da0073e9SAndroid Build Coastguard Worker shape_env=shape_env 880*da0073e9SAndroid Build Coastguard Worker ) as fake_mode: 881*da0073e9SAndroid Build Coastguard Worker x_fake = fake_mode.from_tensor( 882*da0073e9SAndroid Build Coastguard Worker x, 883*da0073e9SAndroid Build Coastguard Worker symbolic_context=StatelessSymbolicContext( 884*da0073e9SAndroid Build Coastguard Worker dynamic_sizes=[dim_dynamic for i in range(x.dim())] 885*da0073e9SAndroid Build Coastguard Worker ), 886*da0073e9SAndroid Build Coastguard Worker ) 887*da0073e9SAndroid Build Coastguard Worker x1_fake = fake_mode.from_tensor( 888*da0073e9SAndroid Build Coastguard Worker x1, 889*da0073e9SAndroid Build Coastguard Worker symbolic_context=StatelessSymbolicContext( 890*da0073e9SAndroid Build Coastguard Worker dynamic_sizes=[dim_dynamic for i in range(x.dim())] 891*da0073e9SAndroid Build Coastguard Worker ), 892*da0073e9SAndroid Build Coastguard Worker ) 893*da0073e9SAndroid Build Coastguard Worker opt_f(x_fake) 894*da0073e9SAndroid Build Coastguard Worker opt_f(x1_fake) 895*da0073e9SAndroid Build Coastguard Worker 896*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, exp_frame_count) 897*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.op_count, exp_op_count) 898*da0073e9SAndroid Build Coastguard Worker 899*da0073e9SAndroid Build Coastguard Worker test_dynamic_dim(f, x, DimDynamic.DYNAMIC, 1, 1) 900*da0073e9SAndroid Build Coastguard Worker test_dynamic_dim(f, x, DimDynamic.DUCK, 1, 1) 901*da0073e9SAndroid Build Coastguard Worker test_dynamic_dim(f, x, DimDynamic.STATIC, 1, 1) 902*da0073e9SAndroid Build Coastguard Worker 903*da0073e9SAndroid Build Coastguard Worker def test_compile_with_fake_tensor_automatic_dynamic(self): 904*da0073e9SAndroid Build Coastguard Worker def f(x): 905*da0073e9SAndroid Build Coastguard Worker return torch.sin(x) 906*da0073e9SAndroid Build Coastguard Worker 907*da0073e9SAndroid Build Coastguard Worker def test_automatic_dynamic(f, inps, dim_dynamic, exp_frame_count, exp_op_count): 908*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 909*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 910*da0073e9SAndroid Build Coastguard Worker opt_f = torch.compile(f, backend=cnt, fullgraph=True) 911*da0073e9SAndroid Build Coastguard Worker 912*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 913*da0073e9SAndroid Build Coastguard Worker with torch._subclasses.fake_tensor.FakeTensorMode( 914*da0073e9SAndroid Build Coastguard Worker shape_env=shape_env 915*da0073e9SAndroid Build Coastguard Worker ) as fake_mode: 916*da0073e9SAndroid Build Coastguard Worker for inp in inps: 917*da0073e9SAndroid Build Coastguard Worker fake_inp = fake_mode.from_tensor( 918*da0073e9SAndroid Build Coastguard Worker inp, 919*da0073e9SAndroid Build Coastguard Worker symbolic_context=StatelessSymbolicContext( 920*da0073e9SAndroid Build Coastguard Worker [dim_dynamic for i in range(x.dim())] 921*da0073e9SAndroid Build Coastguard Worker ), 922*da0073e9SAndroid Build Coastguard Worker ) 923*da0073e9SAndroid Build Coastguard Worker opt_f(fake_inp) 924*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, exp_frame_count) 925*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.op_count, exp_op_count) 926*da0073e9SAndroid Build Coastguard Worker 927*da0073e9SAndroid Build Coastguard Worker x = torch.randn([3, 4]) 928*da0073e9SAndroid Build Coastguard Worker y = torch.randn([4, 5]) 929*da0073e9SAndroid Build Coastguard Worker z = torch.randn([5, 6]) 930*da0073e9SAndroid Build Coastguard Worker a = torch.randn([3, 5]) 931*da0073e9SAndroid Build Coastguard Worker b = torch.randn([4, 4]) 932*da0073e9SAndroid Build Coastguard Worker # When inputs' DimDynamic is DYNAMIC or DUCK, the inputs 933*da0073e9SAndroid Build Coastguard Worker # to opt_f will be tensors with SymInt sizes. Dynamo will treat input 934*da0073e9SAndroid Build Coastguard Worker # as dynamic automatically and will only compile once 935*da0073e9SAndroid Build Coastguard Worker for dim_dynamic in [DimDynamic.DYNAMIC, DimDynamic.DUCK]: 936*da0073e9SAndroid Build Coastguard Worker test_automatic_dynamic(f, [x, y, z], dim_dynamic, 1, 1) 937*da0073e9SAndroid Build Coastguard Worker test_automatic_dynamic(f, [x, a, z], dim_dynamic, 1, 1) 938*da0073e9SAndroid Build Coastguard Worker test_automatic_dynamic(f, [x, b, z], dim_dynamic, 1, 1) 939*da0073e9SAndroid Build Coastguard Worker 940*da0073e9SAndroid Build Coastguard Worker for dim_dynamic in [DimDynamic.STATIC]: 941*da0073e9SAndroid Build Coastguard Worker # Recompile once, first with dim 0 and 1 become Dynamic 942*da0073e9SAndroid Build Coastguard Worker test_automatic_dynamic(f, [x, y, z], dim_dynamic, 2, 2) 943*da0073e9SAndroid Build Coastguard Worker # Recompile 2 times, first with dim 1 become Dynamic, second with dim 0 becomes Dynamic. 944*da0073e9SAndroid Build Coastguard Worker test_automatic_dynamic(f, [x, a, z], dim_dynamic, 3, 3) 945*da0073e9SAndroid Build Coastguard Worker # Recompile 2 times, first with dim 0 become Dynamic, second with dim 1 becomes Dynamic. 946*da0073e9SAndroid Build Coastguard Worker test_automatic_dynamic(f, [x, b, z], dim_dynamic, 3, 3) 947*da0073e9SAndroid Build Coastguard Worker 948*da0073e9SAndroid Build Coastguard Worker def test_compile_with_functionalization(self): 949*da0073e9SAndroid Build Coastguard Worker x = torch.randn([3, 4]) 950*da0073e9SAndroid Build Coastguard Worker x_clone = x.clone() 951*da0073e9SAndroid Build Coastguard Worker x_clone2 = x.clone() 952*da0073e9SAndroid Build Coastguard Worker backend = EagerRecordGraphAndInputs() 953*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) 954*da0073e9SAndroid Build Coastguard Worker 955*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt, fullgraph=True) 956*da0073e9SAndroid Build Coastguard Worker def f(x): 957*da0073e9SAndroid Build Coastguard Worker return x.add_(1.0) + torch.nn.functional.relu_(x) 958*da0073e9SAndroid Build Coastguard Worker 959*da0073e9SAndroid Build Coastguard Worker f_out = f(x) 960*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 961*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.op_count, 3) 962*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(backend.graphs), 1) 963*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(backend.example_inputs), 1) 964*da0073e9SAndroid Build Coastguard Worker 965*da0073e9SAndroid Build Coastguard Worker actual = normalize_gm(backend.graphs[0].print_readable(print_output=False)) 966*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 967*da0073e9SAndroid Build Coastguard Worker actual, 968*da0073e9SAndroid Build Coastguard Worker """\ 969*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module): 970*da0073e9SAndroid Build Coastguard Worker def forward(self, L_x_: "f32[3, 4]"): 971*da0073e9SAndroid Build Coastguard Worker l_x_ = L_x_ 972*da0073e9SAndroid Build Coastguard Worker 973*da0073e9SAndroid Build Coastguard Worker add_: "f32[3, 4]" = l_x_.add_(1.0) 974*da0073e9SAndroid Build Coastguard Worker relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None 975*da0073e9SAndroid Build Coastguard Worker add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None 976*da0073e9SAndroid Build Coastguard Worker return (add,) 977*da0073e9SAndroid Build Coastguard Worker""", 978*da0073e9SAndroid Build Coastguard Worker ) 979*da0073e9SAndroid Build Coastguard Worker 980*da0073e9SAndroid Build Coastguard Worker ff = torch.func.functionalize(f) 981*da0073e9SAndroid Build Coastguard Worker ff_out = ff(x_clone) 982*da0073e9SAndroid Build Coastguard Worker 983*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 2) 984*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.op_count, 6) 985*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(backend.graphs), 2) 986*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(backend.example_inputs), 2) 987*da0073e9SAndroid Build Coastguard Worker actual = normalize_gm(backend.graphs[1].print_readable(print_output=False)) 988*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 989*da0073e9SAndroid Build Coastguard Worker actual, 990*da0073e9SAndroid Build Coastguard Worker """\ 991*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module): 992*da0073e9SAndroid Build Coastguard Worker def forward(self, L_x_: "f32[3, 4]"): 993*da0073e9SAndroid Build Coastguard Worker l_x_ = L_x_ 994*da0073e9SAndroid Build Coastguard Worker 995*da0073e9SAndroid Build Coastguard Worker add_: "f32[3, 4]" = l_x_.add_(1.0) 996*da0073e9SAndroid Build Coastguard Worker relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None 997*da0073e9SAndroid Build Coastguard Worker add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None 998*da0073e9SAndroid Build Coastguard Worker return (add,) 999*da0073e9SAndroid Build Coastguard Worker""", 1000*da0073e9SAndroid Build Coastguard Worker ) 1001*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0])) 1002*da0073e9SAndroid Build Coastguard Worker 1003*da0073e9SAndroid Build Coastguard Worker # Cannot re-use the version from AOTAutograd, since that uses python functional tensors. 1004*da0073e9SAndroid Build Coastguard Worker def to_fun(x): 1005*da0073e9SAndroid Build Coastguard Worker x_functional = torch._to_functional_tensor(x) 1006*da0073e9SAndroid Build Coastguard Worker torch._mirror_autograd_meta_to(x, x_functional) 1007*da0073e9SAndroid Build Coastguard Worker return x_functional 1008*da0073e9SAndroid Build Coastguard Worker 1009*da0073e9SAndroid Build Coastguard Worker def aot_f_wrapper(func): 1010*da0073e9SAndroid Build Coastguard Worker @functools.wraps(func) 1011*da0073e9SAndroid Build Coastguard Worker def wrapper(*args, **kwargs): 1012*da0073e9SAndroid Build Coastguard Worker torch._enable_functionalization(reapply_views=False) 1013*da0073e9SAndroid Build Coastguard Worker try: 1014*da0073e9SAndroid Build Coastguard Worker func_args = pytree.tree_map(to_fun, args) 1015*da0073e9SAndroid Build Coastguard Worker func_kwargs = pytree.tree_map(to_fun, kwargs) 1016*da0073e9SAndroid Build Coastguard Worker return func(*func_args, **func_kwargs) 1017*da0073e9SAndroid Build Coastguard Worker finally: 1018*da0073e9SAndroid Build Coastguard Worker torch._disable_functionalization() 1019*da0073e9SAndroid Build Coastguard Worker 1020*da0073e9SAndroid Build Coastguard Worker return wrapper 1021*da0073e9SAndroid Build Coastguard Worker 1022*da0073e9SAndroid Build Coastguard Worker aot_ff = aot_f_wrapper(f) 1023*da0073e9SAndroid Build Coastguard Worker aot_ff_out = aot_ff(x_clone2) 1024*da0073e9SAndroid Build Coastguard Worker 1025*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 3) 1026*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.op_count, 9) 1027*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(backend.graphs), 3) 1028*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(backend.example_inputs), 3) 1029*da0073e9SAndroid Build Coastguard Worker actual = normalize_gm(backend.graphs[2].print_readable(print_output=False)) 1030*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1031*da0073e9SAndroid Build Coastguard Worker actual, 1032*da0073e9SAndroid Build Coastguard Worker """\ 1033*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module): 1034*da0073e9SAndroid Build Coastguard Worker def forward(self, L_x_: "f32[3, 4]"): 1035*da0073e9SAndroid Build Coastguard Worker l_x_ = L_x_ 1036*da0073e9SAndroid Build Coastguard Worker 1037*da0073e9SAndroid Build Coastguard Worker add_: "f32[3, 4]" = l_x_.add_(1.0) 1038*da0073e9SAndroid Build Coastguard Worker relu_: "f32[3, 4]" = torch.relu_(l_x_); l_x_ = None 1039*da0073e9SAndroid Build Coastguard Worker add: "f32[3, 4]" = add_ + relu_; add_ = relu_ = None 1040*da0073e9SAndroid Build Coastguard Worker return (add,) 1041*da0073e9SAndroid Build Coastguard Worker""", 1042*da0073e9SAndroid Build Coastguard Worker ) 1043*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0])) 1044*da0073e9SAndroid Build Coastguard Worker 1045*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f_out, ff_out) 1046*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f_out, aot_ff_out) 1047*da0073e9SAndroid Build Coastguard Worker 1048*da0073e9SAndroid Build Coastguard Worker try: 1049*da0073e9SAndroid Build Coastguard Worker torch._enable_functionalization(reapply_views=False) 1050*da0073e9SAndroid Build Coastguard Worker xf = pytree.tree_map(to_fun, x) 1051*da0073e9SAndroid Build Coastguard Worker x_view = xf.t() 1052*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Cannot safely fakify a view"): 1053*da0073e9SAndroid Build Coastguard Worker f(x_view) 1054*da0073e9SAndroid Build Coastguard Worker finally: 1055*da0073e9SAndroid Build Coastguard Worker torch._disable_functionalization() 1056*da0073e9SAndroid Build Coastguard Worker 1057*da0073e9SAndroid Build Coastguard Worker def test_compile_higher_order_with_functionalization(self): 1058*da0073e9SAndroid Build Coastguard Worker backend = EagerRecordGraphAndInputs() 1059*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) 1060*da0073e9SAndroid Build Coastguard Worker 1061*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt, fullgraph=True) 1062*da0073e9SAndroid Build Coastguard Worker def f(x): 1063*da0073e9SAndroid Build Coastguard Worker return wrap(lambda x: x.add_(1.0), x) 1064*da0073e9SAndroid Build Coastguard Worker 1065*da0073e9SAndroid Build Coastguard Worker def check_count_and_graph( 1066*da0073e9SAndroid Build Coastguard Worker exp_frame_count, exp_op_count, exp_n_graph, exp_graph 1067*da0073e9SAndroid Build Coastguard Worker ): 1068*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, exp_frame_count) 1069*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.op_count, exp_op_count) 1070*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(backend.graphs), exp_n_graph) 1071*da0073e9SAndroid Build Coastguard Worker actual = normalize_gm( 1072*da0073e9SAndroid Build Coastguard Worker backend.graphs[exp_n_graph - 1].print_readable(print_output=False) 1073*da0073e9SAndroid Build Coastguard Worker ) 1074*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(actual, exp_graph, skip=1) 1075*da0073e9SAndroid Build Coastguard Worker 1076*da0073e9SAndroid Build Coastguard Worker t = torch.randn([3, 4]) 1077*da0073e9SAndroid Build Coastguard Worker t_clone = t.clone() 1078*da0073e9SAndroid Build Coastguard Worker t_clone2 = t.clone() 1079*da0073e9SAndroid Build Coastguard Worker f(t) 1080*da0073e9SAndroid Build Coastguard Worker 1081*da0073e9SAndroid Build Coastguard Worker check_count_and_graph( 1082*da0073e9SAndroid Build Coastguard Worker 1, 1083*da0073e9SAndroid Build Coastguard Worker 2, 1084*da0073e9SAndroid Build Coastguard Worker 1, 1085*da0073e9SAndroid Build Coastguard Worker """\ 1086*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module): 1087*da0073e9SAndroid Build Coastguard Worker def forward(self, L_x_: "f32[3, 4]"): 1088*da0073e9SAndroid Build Coastguard Worker l_x_ = L_x_ 1089*da0073e9SAndroid Build Coastguard Worker 1090*da0073e9SAndroid Build Coastguard Worker wrap_body_0 = self.wrap_body_0 1091*da0073e9SAndroid Build Coastguard Worker wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None 1092*da0073e9SAndroid Build Coastguard Worker getitem: "f32[3, 4]" = wrap[0]; wrap = None 1093*da0073e9SAndroid Build Coastguard Worker return (getitem,) 1094*da0073e9SAndroid Build Coastguard Worker 1095*da0073e9SAndroid Build Coastguard Worker class wrap_body_0(torch.nn.Module): 1096*da0073e9SAndroid Build Coastguard Worker def forward(self, l_x_: "f32[3, 4]"): 1097*da0073e9SAndroid Build Coastguard Worker add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None 1098*da0073e9SAndroid Build Coastguard Worker return (add_,) 1099*da0073e9SAndroid Build Coastguard Worker""", 1100*da0073e9SAndroid Build Coastguard Worker ) 1101*da0073e9SAndroid Build Coastguard Worker 1102*da0073e9SAndroid Build Coastguard Worker ff = torch.func.functionalize(f) 1103*da0073e9SAndroid Build Coastguard Worker ff_out = ff(t_clone) 1104*da0073e9SAndroid Build Coastguard Worker # frame count and op count are incremented due to re-compilation 1105*da0073e9SAndroid Build Coastguard Worker check_count_and_graph( 1106*da0073e9SAndroid Build Coastguard Worker 2, 1107*da0073e9SAndroid Build Coastguard Worker 4, 1108*da0073e9SAndroid Build Coastguard Worker 2, 1109*da0073e9SAndroid Build Coastguard Worker """\ 1110*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module): 1111*da0073e9SAndroid Build Coastguard Worker def forward(self, L_x_: "f32[3, 4]"): 1112*da0073e9SAndroid Build Coastguard Worker l_x_ = L_x_ 1113*da0073e9SAndroid Build Coastguard Worker 1114*da0073e9SAndroid Build Coastguard Worker wrap_body_0 = self.wrap_body_0 1115*da0073e9SAndroid Build Coastguard Worker wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None 1116*da0073e9SAndroid Build Coastguard Worker getitem: "f32[3, 4]" = wrap[0]; wrap = None 1117*da0073e9SAndroid Build Coastguard Worker return (getitem,) 1118*da0073e9SAndroid Build Coastguard Worker 1119*da0073e9SAndroid Build Coastguard Worker class wrap_body_0(torch.nn.Module): 1120*da0073e9SAndroid Build Coastguard Worker def forward(self, l_x_: "f32[3, 4]"): 1121*da0073e9SAndroid Build Coastguard Worker add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None 1122*da0073e9SAndroid Build Coastguard Worker return (add_,) 1123*da0073e9SAndroid Build Coastguard Worker""", 1124*da0073e9SAndroid Build Coastguard Worker ) 1125*da0073e9SAndroid Build Coastguard Worker 1126*da0073e9SAndroid Build Coastguard Worker try: 1127*da0073e9SAndroid Build Coastguard Worker x = torch._to_functional_tensor(t_clone2) 1128*da0073e9SAndroid Build Coastguard Worker torch._mirror_autograd_meta_to(t_clone2, x) 1129*da0073e9SAndroid Build Coastguard Worker torch._enable_functionalization(reapply_views=False) 1130*da0073e9SAndroid Build Coastguard Worker aot_f_out = f(x) 1131*da0073e9SAndroid Build Coastguard Worker finally: 1132*da0073e9SAndroid Build Coastguard Worker torch._disable_functionalization() 1133*da0073e9SAndroid Build Coastguard Worker 1134*da0073e9SAndroid Build Coastguard Worker # frame count and op count are incremented due to re-compilation 1135*da0073e9SAndroid Build Coastguard Worker check_count_and_graph( 1136*da0073e9SAndroid Build Coastguard Worker 3, 1137*da0073e9SAndroid Build Coastguard Worker 6, 1138*da0073e9SAndroid Build Coastguard Worker 3, 1139*da0073e9SAndroid Build Coastguard Worker """\ 1140*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module): 1141*da0073e9SAndroid Build Coastguard Worker def forward(self, L_x_: "f32[3, 4]"): 1142*da0073e9SAndroid Build Coastguard Worker l_x_ = L_x_ 1143*da0073e9SAndroid Build Coastguard Worker 1144*da0073e9SAndroid Build Coastguard Worker wrap_body_0 = self.wrap_body_0 1145*da0073e9SAndroid Build Coastguard Worker wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None 1146*da0073e9SAndroid Build Coastguard Worker getitem: "f32[3, 4]" = wrap[0]; wrap = None 1147*da0073e9SAndroid Build Coastguard Worker return (getitem,) 1148*da0073e9SAndroid Build Coastguard Worker 1149*da0073e9SAndroid Build Coastguard Worker class wrap_body_0(torch.nn.Module): 1150*da0073e9SAndroid Build Coastguard Worker def forward(self, l_x_: "f32[3, 4]"): 1151*da0073e9SAndroid Build Coastguard Worker add_: "f32[3, 4]" = l_x_.add_(1.0); l_x_ = None 1152*da0073e9SAndroid Build Coastguard Worker return (add_,) 1153*da0073e9SAndroid Build Coastguard Worker""", 1154*da0073e9SAndroid Build Coastguard Worker ) 1155*da0073e9SAndroid Build Coastguard Worker 1156*da0073e9SAndroid Build Coastguard Worker def test_has_torch_function(self): 1157*da0073e9SAndroid Build Coastguard Worker class MyTensor: 1158*da0073e9SAndroid Build Coastguard Worker @classmethod 1159*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 1160*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 1161*da0073e9SAndroid Build Coastguard Worker kwargs = {} 1162*da0073e9SAndroid Build Coastguard Worker 1163*da0073e9SAndroid Build Coastguard Worker if func is torch.max: 1164*da0073e9SAndroid Build Coastguard Worker return torch.tensor(123) 1165*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 1166*da0073e9SAndroid Build Coastguard Worker 1167*da0073e9SAndroid Build Coastguard Worker class LocalSubclass(torch.Tensor): 1168*da0073e9SAndroid Build Coastguard Worker @classmethod 1169*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 1170*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 1171*da0073e9SAndroid Build Coastguard Worker kwargs = {} 1172*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 1173*da0073e9SAndroid Build Coastguard Worker 1174*da0073e9SAndroid Build Coastguard Worker def fn(x): 1175*da0073e9SAndroid Build Coastguard Worker return torch.overrides.has_torch_function_unary( 1176*da0073e9SAndroid Build Coastguard Worker x 1177*da0073e9SAndroid Build Coastguard Worker ), torch.overrides.has_torch_function_variadic(x) 1178*da0073e9SAndroid Build Coastguard Worker 1179*da0073e9SAndroid Build Coastguard Worker for test_class in [MyTensor, LocalSubclass]: 1180*da0073e9SAndroid Build Coastguard Worker x = test_class() 1181*da0073e9SAndroid Build Coastguard Worker ref0 = fn(x) 1182*da0073e9SAndroid Build Coastguard Worker ref1 = fn(4) 1183*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 1184*da0073e9SAndroid Build Coastguard Worker res0 = opt_fn(x) 1185*da0073e9SAndroid Build Coastguard Worker res1 = opt_fn(4) 1186*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref0, res0) 1187*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref1, res1) 1188*da0073e9SAndroid Build Coastguard Worker 1189*da0073e9SAndroid Build Coastguard Worker def test_wrapper_subclass_guards_on_inner_tensor(self): 1190*da0073e9SAndroid Build Coastguard Worker # Holds an inner tensor, that has a distinct shape from the outer wrapper tensor. 1191*da0073e9SAndroid Build Coastguard Worker # Also adds additional guards on the inner tensor's sizes. 1192*da0073e9SAndroid Build Coastguard Worker # When the first input to an op has x.shape[0] > 5, we insert an extra add node. 1193*da0073e9SAndroid Build Coastguard Worker class DoubleSizeMaybeAddGeThreeTensor(torch.Tensor): 1194*da0073e9SAndroid Build Coastguard Worker @staticmethod 1195*da0073e9SAndroid Build Coastguard Worker def __new__(cls, inner): 1196*da0073e9SAndroid Build Coastguard Worker # Double the outer-most dimension 1197*da0073e9SAndroid Build Coastguard Worker outer_shape = (inner.shape[0] * 2,) + inner.shape[1:] 1198*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_wrapper_subclass( 1199*da0073e9SAndroid Build Coastguard Worker # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great. 1200*da0073e9SAndroid Build Coastguard Worker # Calling the overload that has kwargs causes us to go down the first overload path, 1201*da0073e9SAndroid Build Coastguard Worker # which will **always** specialize sizes. 1202*da0073e9SAndroid Build Coastguard Worker # We should probably eventually fix this so that the first overload can just handle dynamic shapes. 1203*da0073e9SAndroid Build Coastguard Worker cls, 1204*da0073e9SAndroid Build Coastguard Worker outer_shape, 1205*da0073e9SAndroid Build Coastguard Worker inner.stride(), 1206*da0073e9SAndroid Build Coastguard Worker None, 1207*da0073e9SAndroid Build Coastguard Worker None, 1208*da0073e9SAndroid Build Coastguard Worker inner.dtype, 1209*da0073e9SAndroid Build Coastguard Worker inner.layout, 1210*da0073e9SAndroid Build Coastguard Worker inner.device, 1211*da0073e9SAndroid Build Coastguard Worker False, 1212*da0073e9SAndroid Build Coastguard Worker inner.requires_grad, 1213*da0073e9SAndroid Build Coastguard Worker ) 1214*da0073e9SAndroid Build Coastguard Worker 1215*da0073e9SAndroid Build Coastguard Worker def __init__(self, inner): 1216*da0073e9SAndroid Build Coastguard Worker self.inner_elem = inner 1217*da0073e9SAndroid Build Coastguard Worker 1218*da0073e9SAndroid Build Coastguard Worker def __tensor_flatten__(self): 1219*da0073e9SAndroid Build Coastguard Worker return ["inner_elem"], None 1220*da0073e9SAndroid Build Coastguard Worker 1221*da0073e9SAndroid Build Coastguard Worker @staticmethod 1222*da0073e9SAndroid Build Coastguard Worker def __tensor_unflatten__(inner_tensors, _, outer_size, outer_stride): 1223*da0073e9SAndroid Build Coastguard Worker return DoubleSizeMaybeAddGeThreeTensor(inner_tensors["inner_elem"]) 1224*da0073e9SAndroid Build Coastguard Worker 1225*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 1226*da0073e9SAndroid Build Coastguard Worker return f"DoubleSizeMayberAddGeThreeTensor({repr(self.inner_elem)})" 1227*da0073e9SAndroid Build Coastguard Worker 1228*da0073e9SAndroid Build Coastguard Worker @classmethod 1229*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1230*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 1231*da0073e9SAndroid Build Coastguard Worker kwargs = {} 1232*da0073e9SAndroid Build Coastguard Worker 1233*da0073e9SAndroid Build Coastguard Worker args_inner = torch.utils._pytree.tree_map_only( 1234*da0073e9SAndroid Build Coastguard Worker DoubleSizeMaybeAddGeThreeTensor, lambda x: x.inner_elem, args 1235*da0073e9SAndroid Build Coastguard Worker ) 1236*da0073e9SAndroid Build Coastguard Worker out_inner = func(*args_inner, **kwargs) 1237*da0073e9SAndroid Build Coastguard Worker 1238*da0073e9SAndroid Build Coastguard Worker # Add guards on the inner tensor's sizes 1239*da0073e9SAndroid Build Coastguard Worker if args_inner[0].shape[0] > 3: 1240*da0073e9SAndroid Build Coastguard Worker out_inner += 2 1241*da0073e9SAndroid Build Coastguard Worker 1242*da0073e9SAndroid Build Coastguard Worker return DoubleSizeMaybeAddGeThreeTensor(out_inner) 1243*da0073e9SAndroid Build Coastguard Worker 1244*da0073e9SAndroid Build Coastguard Worker curr_var_to_val = None 1245*da0073e9SAndroid Build Coastguard Worker curr_var_to_sources = None 1246*da0073e9SAndroid Build Coastguard Worker guards = None 1247*da0073e9SAndroid Build Coastguard Worker 1248*da0073e9SAndroid Build Coastguard Worker def backend(gm, args): 1249*da0073e9SAndroid Build Coastguard Worker context = torch._guards.TracingContext.get() 1250*da0073e9SAndroid Build Coastguard Worker 1251*da0073e9SAndroid Build Coastguard Worker # Grab info on sources and guards from the shapeenv 1252*da0073e9SAndroid Build Coastguard Worker nonlocal curr_var_to_val 1253*da0073e9SAndroid Build Coastguard Worker nonlocal curr_var_to_sources 1254*da0073e9SAndroid Build Coastguard Worker nonlocal guards 1255*da0073e9SAndroid Build Coastguard Worker 1256*da0073e9SAndroid Build Coastguard Worker guards = [str(g.expr) for g in context.fake_mode.shape_env.guards] 1257*da0073e9SAndroid Build Coastguard Worker curr_var_to_val = { 1258*da0073e9SAndroid Build Coastguard Worker str(k): v for k, v in context.fake_mode.shape_env.var_to_val.items() 1259*da0073e9SAndroid Build Coastguard Worker } 1260*da0073e9SAndroid Build Coastguard Worker curr_var_to_sources = { 1261*da0073e9SAndroid Build Coastguard Worker str(k): v[0].name() 1262*da0073e9SAndroid Build Coastguard Worker for k, v in context.fake_mode.shape_env.var_to_sources.items() 1263*da0073e9SAndroid Build Coastguard Worker } 1264*da0073e9SAndroid Build Coastguard Worker return gm 1265*da0073e9SAndroid Build Coastguard Worker 1266*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=backend) 1267*da0073e9SAndroid Build Coastguard Worker def fn(x): 1268*da0073e9SAndroid Build Coastguard Worker if x.shape[0] < 13: 1269*da0073e9SAndroid Build Coastguard Worker return torch.mul(x, x) 1270*da0073e9SAndroid Build Coastguard Worker else: 1271*da0073e9SAndroid Build Coastguard Worker return torch.div(x, x) 1272*da0073e9SAndroid Build Coastguard Worker 1273*da0073e9SAndroid Build Coastguard Worker inp = torch.ones(4, 4) 1274*da0073e9SAndroid Build Coastguard Worker 1275*da0073e9SAndroid Build Coastguard Worker x = DoubleSizeMaybeAddGeThreeTensor(inp) 1276*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(x, 0) 1277*da0073e9SAndroid Build Coastguard Worker res = fn(x) 1278*da0073e9SAndroid Build Coastguard Worker # During fakeifying, we end up allocating a separate symint 1279*da0073e9SAndroid Build Coastguard Worker # for the outer and inner tensor (in this test, s0 is unused). 1280*da0073e9SAndroid Build Coastguard Worker expected_var_to_val = { 1281*da0073e9SAndroid Build Coastguard Worker "s0": 8, 1282*da0073e9SAndroid Build Coastguard Worker "s1": 4, 1283*da0073e9SAndroid Build Coastguard Worker } 1284*da0073e9SAndroid Build Coastguard Worker expected_var_to_sources = { 1285*da0073e9SAndroid Build Coastguard Worker "s0": "L['x'].size()[0]", 1286*da0073e9SAndroid Build Coastguard Worker "s1": "L['x'].inner_elem.size()[0]", 1287*da0073e9SAndroid Build Coastguard Worker } 1288*da0073e9SAndroid Build Coastguard Worker self.assertEqual(curr_var_to_val, expected_var_to_val) 1289*da0073e9SAndroid Build Coastguard Worker self.assertEqual(curr_var_to_sources, expected_var_to_sources) 1290*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1291*da0073e9SAndroid Build Coastguard Worker "\n".join(guards), 1292*da0073e9SAndroid Build Coastguard Worker """\ 1293*da0073e9SAndroid Build Coastguard WorkerEq(2*s1, s0) 1294*da0073e9SAndroid Build Coastguard Worker2*s1 < 13 1295*da0073e9SAndroid Build Coastguard Workers1 > 3""", 1296*da0073e9SAndroid Build Coastguard Worker ) 1297*da0073e9SAndroid Build Coastguard Worker 1298*da0073e9SAndroid Build Coastguard Worker def test_wrapper_subclass_with_same_sized_inner_tensor(self): 1299*da0073e9SAndroid Build Coastguard Worker # shouldn't recompile for different sizes when dynamic=True 1300*da0073e9SAndroid Build Coastguard Worker sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) 1301*da0073e9SAndroid Build Coastguard Worker sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(7)) 1302*da0073e9SAndroid Build Coastguard Worker self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=True)) 1303*da0073e9SAndroid Build Coastguard Worker 1304*da0073e9SAndroid Build Coastguard Worker # should recompile for different data size when dynamic=False 1305*da0073e9SAndroid Build Coastguard Worker sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) 1306*da0073e9SAndroid Build Coastguard Worker sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6)) 1307*da0073e9SAndroid Build Coastguard Worker self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) 1308*da0073e9SAndroid Build Coastguard Worker 1309*da0073e9SAndroid Build Coastguard Worker # avoid recompile using manual mark_dynamic() for different data size 1310*da0073e9SAndroid Build Coastguard Worker sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6)) 1311*da0073e9SAndroid Build Coastguard Worker # NB: mark_dynamic() on outer tensor should translate to inner tensors of the same size 1312*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(sub1, 0) 1313*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(sub1, 1) 1314*da0073e9SAndroid Build Coastguard Worker sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6)) 1315*da0073e9SAndroid Build Coastguard Worker self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) 1316*da0073e9SAndroid Build Coastguard Worker 1317*da0073e9SAndroid Build Coastguard Worker def test_wrapper_subclass_with_differently_sized_inner_tensor(self): 1318*da0073e9SAndroid Build Coastguard Worker # should recompile for different scale size when dynamic=False 1319*da0073e9SAndroid Build Coastguard Worker sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3)) 1320*da0073e9SAndroid Build Coastguard Worker sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5)) 1321*da0073e9SAndroid Build Coastguard Worker self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) 1322*da0073e9SAndroid Build Coastguard Worker 1323*da0073e9SAndroid Build Coastguard Worker # still recompiles using manual mark_dynamic() on outer for different scale size 1324*da0073e9SAndroid Build Coastguard Worker sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3)) 1325*da0073e9SAndroid Build Coastguard Worker # NB: mark_dynamic() on outer tensor doesn't translate to inner tensors of different size 1326*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(sub1, 0) 1327*da0073e9SAndroid Build Coastguard Worker torch._dynamo.mark_dynamic(sub1, 1) 1328*da0073e9SAndroid Build Coastguard Worker sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5)) 1329*da0073e9SAndroid Build Coastguard Worker self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False)) 1330*da0073e9SAndroid Build Coastguard Worker 1331*da0073e9SAndroid Build Coastguard Worker def test_recompiles_with_optional_inner_tensor(self): 1332*da0073e9SAndroid Build Coastguard Worker def f(x): 1333*da0073e9SAndroid Build Coastguard Worker return x + 1 1334*da0073e9SAndroid Build Coastguard Worker 1335*da0073e9SAndroid Build Coastguard Worker # sub1 does not have the optional tensor specified while sub2 does 1336*da0073e9SAndroid Build Coastguard Worker sub1 = OptionalScaledTensor(torch.randn(2, 4), None) 1337*da0073e9SAndroid Build Coastguard Worker sub2 = OptionalScaledTensor(torch.randn(2, 4), torch.randn(2, 4)) 1338*da0073e9SAndroid Build Coastguard Worker 1339*da0073e9SAndroid Build Coastguard Worker # sanity check; don't recompile for same input 1340*da0073e9SAndroid Build Coastguard Worker self.assertFalse(_recompiles_for_inputs(f, (sub1,), (sub1,), dynamic=True)) 1341*da0073e9SAndroid Build Coastguard Worker self.assertFalse(_recompiles_for_inputs(f, (sub2,), (sub2,), dynamic=True)) 1342*da0073e9SAndroid Build Coastguard Worker 1343*da0073e9SAndroid Build Coastguard Worker # these should recompile; optional tensor changes between specified and unspecified 1344*da0073e9SAndroid Build Coastguard Worker self.assertTrue(_recompiles_for_inputs(f, (sub1,), (sub2,), dynamic=True)) 1345*da0073e9SAndroid Build Coastguard Worker self.assertTrue(_recompiles_for_inputs(f, (sub2,), (sub1,), dynamic=True)) 1346*da0073e9SAndroid Build Coastguard Worker 1347*da0073e9SAndroid Build Coastguard Worker f_compiled = torch.compile(f, backend="aot_eager") 1348*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(sub1)._data, f_compiled(sub1)._data) 1349*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(sub2)._data, f_compiled(sub2)._data) 1350*da0073e9SAndroid Build Coastguard Worker 1351*da0073e9SAndroid Build Coastguard Worker def test_torch_dispatch_subclass_guard_recompile(self): 1352*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2) 1353*da0073e9SAndroid Build Coastguard Worker x_two = TwoTensor(x.clone(), x.clone()) 1354*da0073e9SAndroid Build Coastguard Worker 1355*da0073e9SAndroid Build Coastguard Worker def fn(w): 1356*da0073e9SAndroid Build Coastguard Worker return torch.add(w, 1.0) 1357*da0073e9SAndroid Build Coastguard Worker 1358*da0073e9SAndroid Build Coastguard Worker fn_opt = torch.compile(backend="eager")(fn) 1359*da0073e9SAndroid Build Coastguard Worker 1360*da0073e9SAndroid Build Coastguard Worker ref = fn(x_two) 1361*da0073e9SAndroid Build Coastguard Worker res = fn_opt(x_two) 1362*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 1363*da0073e9SAndroid Build Coastguard Worker 1364*da0073e9SAndroid Build Coastguard Worker # ensure no recompilation on same input type 1365*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 1366*da0073e9SAndroid Build Coastguard Worker fn_opt(TwoTensor(x + 1, x + 2)) 1367*da0073e9SAndroid Build Coastguard Worker 1368*da0073e9SAndroid Build Coastguard Worker # recompile! 1369*da0073e9SAndroid Build Coastguard Worker ref = fn(x) 1370*da0073e9SAndroid Build Coastguard Worker res = fn_opt(x) 1371*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 1372*da0073e9SAndroid Build Coastguard Worker 1373*da0073e9SAndroid Build Coastguard Worker def test_tensor_subclass_ctx_guards(self): 1374*da0073e9SAndroid Build Coastguard Worker x = CtxSubclassTensor(torch.ones(2), 3) 1375*da0073e9SAndroid Build Coastguard Worker x2 = CtxSubclassTensor(torch.ones(2), 3) 1376*da0073e9SAndroid Build Coastguard Worker x3 = CtxSubclassTensor(torch.ones(2), 4) 1377*da0073e9SAndroid Build Coastguard Worker _check_recompiles(self, lambda x: x * x, (x,), (x2,), False) 1378*da0073e9SAndroid Build Coastguard Worker _check_recompiles(self, lambda x: x * x, (x,), (x3,), True) 1379*da0073e9SAndroid Build Coastguard Worker 1380*da0073e9SAndroid Build Coastguard Worker def test_tensor_subclass_ctx_recursive_guards(self): 1381*da0073e9SAndroid Build Coastguard Worker x0 = torch.ones(2, 2) 1382*da0073e9SAndroid Build Coastguard Worker x1 = CtxSubclassTensor(x0.clone(), 2) 1383*da0073e9SAndroid Build Coastguard Worker x2 = CtxSubclassTensor(x0.clone(), 3) 1384*da0073e9SAndroid Build Coastguard Worker tt0 = TwoTensor(x0.clone(), x1) 1385*da0073e9SAndroid Build Coastguard Worker tt1 = TwoTensor(x0.clone(), x2) 1386*da0073e9SAndroid Build Coastguard Worker 1387*da0073e9SAndroid Build Coastguard Worker _check_recompiles(self, lambda x: x * x, (tt0,), (tt1,), True) 1388*da0073e9SAndroid Build Coastguard Worker 1389*da0073e9SAndroid Build Coastguard Worker def test_tensor_subclass_ctx_custom_guards_override(self): 1390*da0073e9SAndroid Build Coastguard Worker class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor): 1391*da0073e9SAndroid Build Coastguard Worker @classmethod 1392*da0073e9SAndroid Build Coastguard Worker def __metadata_guard__(cls, orig_data, other): 1393*da0073e9SAndroid Build Coastguard Worker return orig_data[0] <= other[0] 1394*da0073e9SAndroid Build Coastguard Worker 1395*da0073e9SAndroid Build Coastguard Worker x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 2) 1396*da0073e9SAndroid Build Coastguard Worker x2 = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3) 1397*da0073e9SAndroid Build Coastguard Worker x3 = CtxSubclassTensorCustomGuardFn(torch.ones(2), 1) 1398*da0073e9SAndroid Build Coastguard Worker _check_recompiles(self, lambda x: x * x, (x,), (x2,), False) 1399*da0073e9SAndroid Build Coastguard Worker _check_recompiles(self, lambda x: x * x, (x,), (x3,), True) 1400*da0073e9SAndroid Build Coastguard Worker 1401*da0073e9SAndroid Build Coastguard Worker def test_tensor_subclass_ctx_custom_guards_error_arg_num(self): 1402*da0073e9SAndroid Build Coastguard Worker import torch._dynamo.exc 1403*da0073e9SAndroid Build Coastguard Worker 1404*da0073e9SAndroid Build Coastguard Worker class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor): 1405*da0073e9SAndroid Build Coastguard Worker @classmethod 1406*da0073e9SAndroid Build Coastguard Worker def __metadata_guard__(cls, y): 1407*da0073e9SAndroid Build Coastguard Worker # Shouldn't reach here 1408*da0073e9SAndroid Build Coastguard Worker return False 1409*da0073e9SAndroid Build Coastguard Worker 1410*da0073e9SAndroid Build Coastguard Worker x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3) 1411*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 1412*da0073e9SAndroid Build Coastguard Worker torch._dynamo.exc.InternalTorchDynamoError, 1413*da0073e9SAndroid Build Coastguard Worker "Tensor subclass method __metadata_guard__ must take exactly two subclass metadata arguments", 1414*da0073e9SAndroid Build Coastguard Worker lambda: torch.compile(lambda x: x * x)(x), 1415*da0073e9SAndroid Build Coastguard Worker ) 1416*da0073e9SAndroid Build Coastguard Worker 1417*da0073e9SAndroid Build Coastguard Worker def test_tensor_subclass_ctx_custom_guards_error_not_classmethod(self): 1418*da0073e9SAndroid Build Coastguard Worker import torch._dynamo.exc 1419*da0073e9SAndroid Build Coastguard Worker 1420*da0073e9SAndroid Build Coastguard Worker class CtxSubclassTensorCustomGuardFn(CtxSubclassTensor): 1421*da0073e9SAndroid Build Coastguard Worker def __metadata_guard__(self, x, y): 1422*da0073e9SAndroid Build Coastguard Worker return False 1423*da0073e9SAndroid Build Coastguard Worker 1424*da0073e9SAndroid Build Coastguard Worker x = CtxSubclassTensorCustomGuardFn(torch.ones(2), 3) 1425*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 1426*da0073e9SAndroid Build Coastguard Worker torch._dynamo.exc.InternalTorchDynamoError, 1427*da0073e9SAndroid Build Coastguard Worker "Tensor subclass method __metadata_guard__ must be a classmethod", 1428*da0073e9SAndroid Build Coastguard Worker lambda: torch.compile(lambda x: x * x)(x), 1429*da0073e9SAndroid Build Coastguard Worker ) 1430*da0073e9SAndroid Build Coastguard Worker 1431*da0073e9SAndroid Build Coastguard Worker def test_subclass_constructor_proxying(self): 1432*da0073e9SAndroid Build Coastguard Worker import dataclasses 1433*da0073e9SAndroid Build Coastguard Worker from collections import namedtuple 1434*da0073e9SAndroid Build Coastguard Worker from typing import Any 1435*da0073e9SAndroid Build Coastguard Worker 1436*da0073e9SAndroid Build Coastguard Worker @dataclasses.dataclass(frozen=True) 1437*da0073e9SAndroid Build Coastguard Worker class SubclassTensorArgs: 1438*da0073e9SAndroid Build Coastguard Worker original_shape: torch.Size 1439*da0073e9SAndroid Build Coastguard Worker device: torch.device 1440*da0073e9SAndroid Build Coastguard Worker inner_meta: Any 1441*da0073e9SAndroid Build Coastguard Worker 1442*da0073e9SAndroid Build Coastguard Worker SubclassTensorArgs2 = namedtuple( 1443*da0073e9SAndroid Build Coastguard Worker "SubclassTensorArgs2", 1444*da0073e9SAndroid Build Coastguard Worker [ 1445*da0073e9SAndroid Build Coastguard Worker "original_shape", 1446*da0073e9SAndroid Build Coastguard Worker "device", 1447*da0073e9SAndroid Build Coastguard Worker "inner_meta", 1448*da0073e9SAndroid Build Coastguard Worker ], 1449*da0073e9SAndroid Build Coastguard Worker ) 1450*da0073e9SAndroid Build Coastguard Worker 1451*da0073e9SAndroid Build Coastguard Worker class SubclassTensor(torch.Tensor): 1452*da0073e9SAndroid Build Coastguard Worker @staticmethod 1453*da0073e9SAndroid Build Coastguard Worker def __new__(cls, a, meta): 1454*da0073e9SAndroid Build Coastguard Worker shape = a.shape 1455*da0073e9SAndroid Build Coastguard Worker kwargs = {} 1456*da0073e9SAndroid Build Coastguard Worker kwargs["strides"] = a.stride() 1457*da0073e9SAndroid Build Coastguard Worker kwargs["storage_offset"] = a.storage_offset() 1458*da0073e9SAndroid Build Coastguard Worker kwargs["device"] = a.device 1459*da0073e9SAndroid Build Coastguard Worker kwargs["layout"] = a.layout 1460*da0073e9SAndroid Build Coastguard Worker kwargs["requires_grad"] = a.requires_grad 1461*da0073e9SAndroid Build Coastguard Worker kwargs["dtype"] = a.dtype 1462*da0073e9SAndroid Build Coastguard Worker out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) 1463*da0073e9SAndroid Build Coastguard Worker return out 1464*da0073e9SAndroid Build Coastguard Worker 1465*da0073e9SAndroid Build Coastguard Worker def __init__(self, a, meta): 1466*da0073e9SAndroid Build Coastguard Worker self.a = a 1467*da0073e9SAndroid Build Coastguard Worker self.meta = meta 1468*da0073e9SAndroid Build Coastguard Worker 1469*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 1470*da0073e9SAndroid Build Coastguard Worker a_repr = repr(self.a) 1471*da0073e9SAndroid Build Coastguard Worker return f"SubclassTensor({a_repr})" 1472*da0073e9SAndroid Build Coastguard Worker 1473*da0073e9SAndroid Build Coastguard Worker def __tensor_flatten__(self): 1474*da0073e9SAndroid Build Coastguard Worker return ["a"], self.meta 1475*da0073e9SAndroid Build Coastguard Worker 1476*da0073e9SAndroid Build Coastguard Worker @staticmethod 1477*da0073e9SAndroid Build Coastguard Worker def __tensor_unflatten__(inner_tensors, meta, _, __): 1478*da0073e9SAndroid Build Coastguard Worker a = inner_tensors["a"] 1479*da0073e9SAndroid Build Coastguard Worker return SubclassTensor(a, meta) 1480*da0073e9SAndroid Build Coastguard Worker 1481*da0073e9SAndroid Build Coastguard Worker @classmethod 1482*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 1483*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 1484*da0073e9SAndroid Build Coastguard Worker kwargs = {} 1485*da0073e9SAndroid Build Coastguard Worker args_a = pytree.tree_map( 1486*da0073e9SAndroid Build Coastguard Worker lambda x: x.a if isinstance(x, SubclassTensor) else x, args 1487*da0073e9SAndroid Build Coastguard Worker ) 1488*da0073e9SAndroid Build Coastguard Worker kwargs_a = pytree.tree_map( 1489*da0073e9SAndroid Build Coastguard Worker lambda x: x.a if isinstance(x, SubclassTensor) else x, kwargs 1490*da0073e9SAndroid Build Coastguard Worker ) 1491*da0073e9SAndroid Build Coastguard Worker out_a = func(*args_a, **kwargs_a) 1492*da0073e9SAndroid Build Coastguard Worker out = pytree.tree_map( 1493*da0073e9SAndroid Build Coastguard Worker lambda x: SubclassTensor( 1494*da0073e9SAndroid Build Coastguard Worker x, SubclassTensorArgs2(x.shape, x.device, None) 1495*da0073e9SAndroid Build Coastguard Worker ) 1496*da0073e9SAndroid Build Coastguard Worker if isinstance(x, torch.Tensor) 1497*da0073e9SAndroid Build Coastguard Worker else x, 1498*da0073e9SAndroid Build Coastguard Worker out_a, 1499*da0073e9SAndroid Build Coastguard Worker ) 1500*da0073e9SAndroid Build Coastguard Worker return return_and_correct_aliasing(func, args, kwargs, out) 1501*da0073e9SAndroid Build Coastguard Worker 1502*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True) 1503*da0073e9SAndroid Build Coastguard Worker def f1(x): 1504*da0073e9SAndroid Build Coastguard Worker meta = SubclassTensorArgs( 1505*da0073e9SAndroid Build Coastguard Worker x.shape, x.device, SubclassTensorArgs(x.shape, x.device, None) 1506*da0073e9SAndroid Build Coastguard Worker ) 1507*da0073e9SAndroid Build Coastguard Worker out = SubclassTensor(x, meta) 1508*da0073e9SAndroid Build Coastguard Worker return out * out 1509*da0073e9SAndroid Build Coastguard Worker 1510*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3) 1511*da0073e9SAndroid Build Coastguard Worker f1(x) 1512*da0073e9SAndroid Build Coastguard Worker 1513*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True) 1514*da0073e9SAndroid Build Coastguard Worker def f1(x): 1515*da0073e9SAndroid Build Coastguard Worker meta = SubclassTensorArgs2( 1516*da0073e9SAndroid Build Coastguard Worker x.shape, x.device, SubclassTensorArgs2(x.shape, x.device, None) 1517*da0073e9SAndroid Build Coastguard Worker ) 1518*da0073e9SAndroid Build Coastguard Worker out = SubclassTensor(x, meta) 1519*da0073e9SAndroid Build Coastguard Worker return out * out 1520*da0073e9SAndroid Build Coastguard Worker 1521*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3) 1522*da0073e9SAndroid Build Coastguard Worker f1(x) 1523*da0073e9SAndroid Build Coastguard Worker 1524*da0073e9SAndroid Build Coastguard Worker def test_torch_function_subclass_survives_into_aot_autograd(self): 1525*da0073e9SAndroid Build Coastguard Worker # If you have a tensor subclass that relies on dispatch into the same op 1526*da0073e9SAndroid Build Coastguard Worker # without unwrapping and calling torch._C.DisableTorchFunctionSubclass(), 1527*da0073e9SAndroid Build Coastguard Worker # the torch function-ness will survive into AOTAutograd. Today, NestedTensor 1528*da0073e9SAndroid Build Coastguard Worker # actually relies on this behavior! Because that torch function logic 1529*da0073e9SAndroid Build Coastguard Worker # runs during AOTAutograd, this test tests that there is no logic below 1530*da0073e9SAndroid Build Coastguard Worker # that relies torch function that gets unexpectedly disabled after we 1531*da0073e9SAndroid Build Coastguard Worker # redispatch from the subclass's torch function. 1532*da0073e9SAndroid Build Coastguard Worker class SubTensor(torch.Tensor): 1533*da0073e9SAndroid Build Coastguard Worker @staticmethod 1534*da0073e9SAndroid Build Coastguard Worker def __new__(cls, t): 1535*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_wrapper_subclass( 1536*da0073e9SAndroid Build Coastguard Worker cls, 1537*da0073e9SAndroid Build Coastguard Worker t.shape, 1538*da0073e9SAndroid Build Coastguard Worker t.stride(), 1539*da0073e9SAndroid Build Coastguard Worker t.storage_offset(), 1540*da0073e9SAndroid Build Coastguard Worker torch.contiguous_format, 1541*da0073e9SAndroid Build Coastguard Worker t.dtype, 1542*da0073e9SAndroid Build Coastguard Worker torch.strided, 1543*da0073e9SAndroid Build Coastguard Worker t.device, 1544*da0073e9SAndroid Build Coastguard Worker False, 1545*da0073e9SAndroid Build Coastguard Worker t.requires_grad, 1546*da0073e9SAndroid Build Coastguard Worker "sizes", 1547*da0073e9SAndroid Build Coastguard Worker False, 1548*da0073e9SAndroid Build Coastguard Worker False, 1549*da0073e9SAndroid Build Coastguard Worker None, 1550*da0073e9SAndroid Build Coastguard Worker ) 1551*da0073e9SAndroid Build Coastguard Worker 1552*da0073e9SAndroid Build Coastguard Worker def __init__(self, t): 1553*da0073e9SAndroid Build Coastguard Worker super().__init__() 1554*da0073e9SAndroid Build Coastguard Worker self._t = t 1555*da0073e9SAndroid Build Coastguard Worker 1556*da0073e9SAndroid Build Coastguard Worker def __tensor_flatten__(self): 1557*da0073e9SAndroid Build Coastguard Worker return ["_t"], {} 1558*da0073e9SAndroid Build Coastguard Worker 1559*da0073e9SAndroid Build Coastguard Worker @staticmethod 1560*da0073e9SAndroid Build Coastguard Worker def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): 1561*da0073e9SAndroid Build Coastguard Worker t = inner_tensors["_t"] 1562*da0073e9SAndroid Build Coastguard Worker return SubTensor(t) 1563*da0073e9SAndroid Build Coastguard Worker 1564*da0073e9SAndroid Build Coastguard Worker def __repr__(self): 1565*da0073e9SAndroid Build Coastguard Worker return f"SubTensor({self._t})" 1566*da0073e9SAndroid Build Coastguard Worker 1567*da0073e9SAndroid Build Coastguard Worker @classmethod 1568*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, func, types, args=(), kwargs=None): 1569*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 1570*da0073e9SAndroid Build Coastguard Worker kwargs = {} 1571*da0073e9SAndroid Build Coastguard Worker 1572*da0073e9SAndroid Build Coastguard Worker with torch._C.DisableTorchFunctionSubclass(): 1573*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 1574*da0073e9SAndroid Build Coastguard Worker 1575*da0073e9SAndroid Build Coastguard Worker @classmethod 1576*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1577*da0073e9SAndroid Build Coastguard Worker kwargs = {} if kwargs is None else kwargs 1578*da0073e9SAndroid Build Coastguard Worker new_args = pytree.tree_map_only(SubTensor, lambda s: s._t, args) 1579*da0073e9SAndroid Build Coastguard Worker output = func(*new_args, **kwargs) 1580*da0073e9SAndroid Build Coastguard Worker output = pytree.tree_map_only( 1581*da0073e9SAndroid Build Coastguard Worker torch.Tensor, lambda t: SubTensor(t), output 1582*da0073e9SAndroid Build Coastguard Worker ) 1583*da0073e9SAndroid Build Coastguard Worker return output 1584*da0073e9SAndroid Build Coastguard Worker 1585*da0073e9SAndroid Build Coastguard Worker @torch.compile(dynamic=True) 1586*da0073e9SAndroid Build Coastguard Worker def f(x): 1587*da0073e9SAndroid Build Coastguard Worker return x.unflatten(-1, [2, 5]) 1588*da0073e9SAndroid Build Coastguard Worker 1589*da0073e9SAndroid Build Coastguard Worker s = SubTensor(torch.randn(3, 10)) 1590*da0073e9SAndroid Build Coastguard Worker f(s) 1591*da0073e9SAndroid Build Coastguard Worker 1592*da0073e9SAndroid Build Coastguard Worker # Guard validation upsets the guard 1593*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/129936 1594*da0073e9SAndroid Build Coastguard Worker @unittest.expectedFailure 1595*da0073e9SAndroid Build Coastguard Worker def test_recompile_with_symbool_inputs(self): 1596*da0073e9SAndroid Build Coastguard Worker def f(pred: bool): 1597*da0073e9SAndroid Build Coastguard Worker if pred: 1598*da0073e9SAndroid Build Coastguard Worker return torch.ones([3, 4]) 1599*da0073e9SAndroid Build Coastguard Worker else: 1600*da0073e9SAndroid Build Coastguard Worker return torch.ones([4, 3]) 1601*da0073e9SAndroid Build Coastguard Worker 1602*da0073e9SAndroid Build Coastguard Worker def test_recompilation( 1603*da0073e9SAndroid Build Coastguard Worker f, x, sizes, exp_graphs, exp_frame_count, exp_shape_env_guards 1604*da0073e9SAndroid Build Coastguard Worker ): 1605*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 1606*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 1607*da0073e9SAndroid Build Coastguard Worker backend = torch._dynamo.testing.EagerAndRecordGraphs() 1608*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) 1609*da0073e9SAndroid Build Coastguard Worker f_cond = torch.compile(f, backend=cnt, fullgraph=True) 1610*da0073e9SAndroid Build Coastguard Worker with torch._subclasses.fake_tensor.FakeTensorMode( 1611*da0073e9SAndroid Build Coastguard Worker shape_env=shape_env 1612*da0073e9SAndroid Build Coastguard Worker ) as fake_mode: 1613*da0073e9SAndroid Build Coastguard Worker fake_inp = fake_mode.from_tensor( 1614*da0073e9SAndroid Build Coastguard Worker x, 1615*da0073e9SAndroid Build Coastguard Worker symbolic_context=StatelessSymbolicContext( 1616*da0073e9SAndroid Build Coastguard Worker dynamic_sizes=[DimDynamic.DYNAMIC for i in range(x.dim())] 1617*da0073e9SAndroid Build Coastguard Worker ), 1618*da0073e9SAndroid Build Coastguard Worker ) 1619*da0073e9SAndroid Build Coastguard Worker for i, size in enumerate(sizes): 1620*da0073e9SAndroid Build Coastguard Worker pred = fake_inp.size(0) == size 1621*da0073e9SAndroid Build Coastguard Worker f_cond(pred) 1622*da0073e9SAndroid Build Coastguard Worker actual = normalize_gm( 1623*da0073e9SAndroid Build Coastguard Worker backend.graphs[exp_frame_count[i] - 1].print_readable( 1624*da0073e9SAndroid Build Coastguard Worker print_output=False 1625*da0073e9SAndroid Build Coastguard Worker ) 1626*da0073e9SAndroid Build Coastguard Worker ) 1627*da0073e9SAndroid Build Coastguard Worker actual_guard_str = [str(guard.expr) for guard in shape_env.guards] 1628*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(actual, exp_graphs[i]) 1629*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, exp_frame_count[i]) 1630*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual_guard_str, exp_shape_env_guards[i]) 1631*da0073e9SAndroid Build Coastguard Worker 1632*da0073e9SAndroid Build Coastguard Worker true_graph = """\ 1633*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module): 1634*da0073e9SAndroid Build Coastguard Worker def forward(self): 1635*da0073e9SAndroid Build Coastguard Worker ones: "f32[3, 4]" = torch.ones([3, 4]) 1636*da0073e9SAndroid Build Coastguard Worker return (ones,) 1637*da0073e9SAndroid Build Coastguard Worker""" 1638*da0073e9SAndroid Build Coastguard Worker false_graph = """\ 1639*da0073e9SAndroid Build Coastguard Workerclass GraphModule(torch.nn.Module): 1640*da0073e9SAndroid Build Coastguard Worker def forward(self): 1641*da0073e9SAndroid Build Coastguard Worker ones: "f32[4, 3]" = torch.ones([4, 3]) 1642*da0073e9SAndroid Build Coastguard Worker return (ones,) 1643*da0073e9SAndroid Build Coastguard Worker""" 1644*da0073e9SAndroid Build Coastguard Worker test_recompilation( 1645*da0073e9SAndroid Build Coastguard Worker f, 1646*da0073e9SAndroid Build Coastguard Worker torch.randn([3, 4]), 1647*da0073e9SAndroid Build Coastguard Worker [3, 3, 4, 5], 1648*da0073e9SAndroid Build Coastguard Worker exp_graphs=[true_graph, true_graph, false_graph, false_graph], 1649*da0073e9SAndroid Build Coastguard Worker exp_frame_count=[1, 1, 2, 2], 1650*da0073e9SAndroid Build Coastguard Worker exp_shape_env_guards=[ 1651*da0073e9SAndroid Build Coastguard Worker [], 1652*da0073e9SAndroid Build Coastguard Worker # s0 is specialized and guarded in outter shape_env when dynamo checks the guards 1653*da0073e9SAndroid Build Coastguard Worker ["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"], 1654*da0073e9SAndroid Build Coastguard Worker [ 1655*da0073e9SAndroid Build Coastguard Worker "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", 1656*da0073e9SAndroid Build Coastguard Worker "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)", 1657*da0073e9SAndroid Build Coastguard Worker ], 1658*da0073e9SAndroid Build Coastguard Worker [ 1659*da0073e9SAndroid Build Coastguard Worker "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", 1660*da0073e9SAndroid Build Coastguard Worker "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)", 1661*da0073e9SAndroid Build Coastguard Worker "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)", 1662*da0073e9SAndroid Build Coastguard Worker ], 1663*da0073e9SAndroid Build Coastguard Worker ], 1664*da0073e9SAndroid Build Coastguard Worker ) 1665*da0073e9SAndroid Build Coastguard Worker 1666*da0073e9SAndroid Build Coastguard Worker test_recompilation( 1667*da0073e9SAndroid Build Coastguard Worker f, 1668*da0073e9SAndroid Build Coastguard Worker torch.randn([3, 4]), 1669*da0073e9SAndroid Build Coastguard Worker [4, 5, 3, 3], 1670*da0073e9SAndroid Build Coastguard Worker exp_graphs=[false_graph, false_graph, true_graph, true_graph], 1671*da0073e9SAndroid Build Coastguard Worker exp_frame_count=[1, 1, 2, 2], 1672*da0073e9SAndroid Build Coastguard Worker exp_shape_env_guards=[ 1673*da0073e9SAndroid Build Coastguard Worker [], 1674*da0073e9SAndroid Build Coastguard Worker # s0 is specialized and guarded in outter shape_env when dynamo checks the guards 1675*da0073e9SAndroid Build Coastguard Worker ["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"], 1676*da0073e9SAndroid Build Coastguard Worker [ 1677*da0073e9SAndroid Build Coastguard Worker "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)", 1678*da0073e9SAndroid Build Coastguard Worker "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", 1679*da0073e9SAndroid Build Coastguard Worker ], 1680*da0073e9SAndroid Build Coastguard Worker [ 1681*da0073e9SAndroid Build Coastguard Worker "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)", 1682*da0073e9SAndroid Build Coastguard Worker "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", 1683*da0073e9SAndroid Build Coastguard Worker "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)", 1684*da0073e9SAndroid Build Coastguard Worker ], 1685*da0073e9SAndroid Build Coastguard Worker ], 1686*da0073e9SAndroid Build Coastguard Worker ) 1687*da0073e9SAndroid Build Coastguard Worker 1688*da0073e9SAndroid Build Coastguard Worker def test_wrapper_subclass_dynamo_attribute_access_on_intermediate(self): 1689*da0073e9SAndroid Build Coastguard Worker def f(x_subclass): 1690*da0073e9SAndroid Build Coastguard Worker tmp_subclass = torch.add(x, 1) 1691*da0073e9SAndroid Build Coastguard Worker return torch.mul(tmp_subclass._scale, tmp_subclass._constant) 1692*da0073e9SAndroid Build Coastguard Worker 1693*da0073e9SAndroid Build Coastguard Worker x = ScaledTensor(torch.randn(2, 4), torch.randn(3), constant=2) 1694*da0073e9SAndroid Build Coastguard Worker out_ref = f(x) 1695*da0073e9SAndroid Build Coastguard Worker out_test = torch.compile(f, backend="aot_eager", fullgraph=True)(x) 1696*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, out_test) 1697*da0073e9SAndroid Build Coastguard Worker 1698*da0073e9SAndroid Build Coastguard Worker def test_support_bases(self): 1699*da0073e9SAndroid Build Coastguard Worker import abc 1700*da0073e9SAndroid Build Coastguard Worker 1701*da0073e9SAndroid Build Coastguard Worker import torch.fx._symbolic_trace 1702*da0073e9SAndroid Build Coastguard Worker 1703*da0073e9SAndroid Build Coastguard Worker class Meta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta): 1704*da0073e9SAndroid Build Coastguard Worker def __new__(cls, name, bases, dct): 1705*da0073e9SAndroid Build Coastguard Worker x = super().__new__(cls, name, bases, dct) 1706*da0073e9SAndroid Build Coastguard Worker x.attr = 100 1707*da0073e9SAndroid Build Coastguard Worker return x 1708*da0073e9SAndroid Build Coastguard Worker 1709*da0073e9SAndroid Build Coastguard Worker class Multistreamable(abc.ABC): # noqa: B024 1710*da0073e9SAndroid Build Coastguard Worker pass 1711*da0073e9SAndroid Build Coastguard Worker 1712*da0073e9SAndroid Build Coastguard Worker class Foo(Multistreamable, metaclass=Meta): 1713*da0073e9SAndroid Build Coastguard Worker pass 1714*da0073e9SAndroid Build Coastguard Worker 1715*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 1716*da0073e9SAndroid Build Coastguard Worker def f(x): 1717*da0073e9SAndroid Build Coastguard Worker typ = type(Foo()) 1718*da0073e9SAndroid Build Coastguard Worker typ.__bases__ 1719*da0073e9SAndroid Build Coastguard Worker return typ.__bases__ 1720*da0073e9SAndroid Build Coastguard Worker 1721*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(torch.randn(1)), (Multistreamable,)) 1722*da0073e9SAndroid Build Coastguard Worker 1723*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend="eager", fullgraph=True) 1724*da0073e9SAndroid Build Coastguard Worker def g(x): 1725*da0073e9SAndroid Build Coastguard Worker typ = type(Foo()) 1726*da0073e9SAndroid Build Coastguard Worker typ.__base__ 1727*da0073e9SAndroid Build Coastguard Worker return typ.__base__ 1728*da0073e9SAndroid Build Coastguard Worker 1729*da0073e9SAndroid Build Coastguard Worker self.assertEqual(g(torch.randn(1)), Multistreamable) 1730*da0073e9SAndroid Build Coastguard Worker 1731*da0073e9SAndroid Build Coastguard Worker @parametrize("dynamic", [False, True]) 1732*da0073e9SAndroid Build Coastguard Worker def test_subclass_views(self, dynamic): 1733*da0073e9SAndroid Build Coastguard Worker def _get_views(t): # returns (view: Tensor, expects_raises_false) 1734*da0073e9SAndroid Build Coastguard Worker # Note that any closed-over SymInts will be symbolicized during fake-ification. 1735*da0073e9SAndroid Build Coastguard Worker yield t.narrow(dim=-1, start=3, length=8), False 1736*da0073e9SAndroid Build Coastguard Worker yield t.split(5, -1)[2], False 1737*da0073e9SAndroid Build Coastguard Worker yield t.split_with_sizes([9, 6], -1)[1], False 1738*da0073e9SAndroid Build Coastguard Worker yield t.unsqueeze(-1).expand(4, 15, 10), False 1739*da0073e9SAndroid Build Coastguard Worker yield t.select(-1, 6), False 1740*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/128649 1741*da0073e9SAndroid Build Coastguard Worker yield t[2:3, 5:9], dynamic 1742*da0073e9SAndroid Build Coastguard Worker yield t.view(-1, 15), False 1743*da0073e9SAndroid Build Coastguard Worker 1744*da0073e9SAndroid Build Coastguard Worker def f(x): 1745*da0073e9SAndroid Build Coastguard Worker return x * 2 1746*da0073e9SAndroid Build Coastguard Worker 1747*da0073e9SAndroid Build Coastguard Worker compiled_f = torch.compile( 1748*da0073e9SAndroid Build Coastguard Worker f, backend="aot_eager", fullgraph=True, dynamic=dynamic 1749*da0073e9SAndroid Build Coastguard Worker ) 1750*da0073e9SAndroid Build Coastguard Worker 1751*da0073e9SAndroid Build Coastguard Worker # Take a view of a subclass to pass as input. 1752*da0073e9SAndroid Build Coastguard Worker t = TwoTensor(torch.randn(4, 15), torch.randn(4, 15)) 1753*da0073e9SAndroid Build Coastguard Worker for view, expects_raises in _get_views(t): 1754*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 1755*da0073e9SAndroid Build Coastguard Worker out_ref = f(view) 1756*da0073e9SAndroid Build Coastguard Worker if expects_raises: 1757*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 1758*da0073e9SAndroid Build Coastguard Worker out_test = compiled_f(view) 1759*da0073e9SAndroid Build Coastguard Worker else: 1760*da0073e9SAndroid Build Coastguard Worker out_test = compiled_f(view) 1761*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, out_test) 1762*da0073e9SAndroid Build Coastguard Worker 1763*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) 1764*da0073e9SAndroid Build Coastguard Worker def test_mark_static_with_subclass_desugaring(self): 1765*da0073e9SAndroid Build Coastguard Worker from typing import Any, Callable, Dict, List, Optional 1766*da0073e9SAndroid Build Coastguard Worker 1767*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.decorators import mark_static_address 1768*da0073e9SAndroid Build Coastguard Worker from torch._inductor.compile_fx import compile_fx 1769*da0073e9SAndroid Build Coastguard Worker from torch._inductor.cudagraph_utils import BoxedDeviceIndex 1770*da0073e9SAndroid Build Coastguard Worker from torch._inductor.utils import BoxedBool 1771*da0073e9SAndroid Build Coastguard Worker 1772*da0073e9SAndroid Build Coastguard Worker x_inner = torch.ones(4) 1773*da0073e9SAndroid Build Coastguard Worker x = TwoTensor(x_inner, x_inner) 1774*da0073e9SAndroid Build Coastguard Worker mark_static_address(x, guard=False) 1775*da0073e9SAndroid Build Coastguard Worker 1776*da0073e9SAndroid Build Coastguard Worker def inner_compile( 1777*da0073e9SAndroid Build Coastguard Worker gm: torch.fx.GraphModule, 1778*da0073e9SAndroid Build Coastguard Worker example_inputs: List[torch.Tensor], 1779*da0073e9SAndroid Build Coastguard Worker cudagraphs: Optional[BoxedBool] = None, 1780*da0073e9SAndroid Build Coastguard Worker static_input_idxs: Optional[List[int]] = None, 1781*da0073e9SAndroid Build Coastguard Worker is_backward: bool = False, 1782*da0073e9SAndroid Build Coastguard Worker graph_id: Optional[int] = None, 1783*da0073e9SAndroid Build Coastguard Worker cpp_wrapper: bool = False, 1784*da0073e9SAndroid Build Coastguard Worker aot_mode: bool = False, 1785*da0073e9SAndroid Build Coastguard Worker is_inference: bool = False, 1786*da0073e9SAndroid Build Coastguard Worker boxed_forward_device_index: Optional[BoxedDeviceIndex] = None, 1787*da0073e9SAndroid Build Coastguard Worker user_visible_outputs: Optional[Dict[str, None]] = None, 1788*da0073e9SAndroid Build Coastguard Worker layout_opt: Optional[bool] = None, 1789*da0073e9SAndroid Build Coastguard Worker extern_node_serializer: Optional[Callable[[List[Any]], Any]] = None, 1790*da0073e9SAndroid Build Coastguard Worker ): 1791*da0073e9SAndroid Build Coastguard Worker self.assertEqual(static_input_idxs, [1, 2]) 1792*da0073e9SAndroid Build Coastguard Worker return gm 1793*da0073e9SAndroid Build Coastguard Worker 1794*da0073e9SAndroid Build Coastguard Worker compiler = functools.partial(compile_fx, inner_compile=inner_compile) 1795*da0073e9SAndroid Build Coastguard Worker 1796*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=compiler) 1797*da0073e9SAndroid Build Coastguard Worker def fn(t0, t1, t2): 1798*da0073e9SAndroid Build Coastguard Worker return t0 + t1 + t2 + 2 1799*da0073e9SAndroid Build Coastguard Worker 1800*da0073e9SAndroid Build Coastguard Worker fn(torch.ones(4), x, torch.ones(4)) 1801*da0073e9SAndroid Build Coastguard Worker 1802*da0073e9SAndroid Build Coastguard Worker 1803*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(SubclassTests) 1804*da0073e9SAndroid Build Coastguard Worker 1805*da0073e9SAndroid Build Coastguard Worker 1806*da0073e9SAndroid Build Coastguard Workerclass TestNestedTensor(torch._dynamo.test_case.TestCase, NestedTensorTestCase): 1807*da0073e9SAndroid Build Coastguard Worker def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True): 1808*da0073e9SAndroid Build Coastguard Worker return get_jagged_tensor(nested_size, offsets, requires_grad) 1809*da0073e9SAndroid Build Coastguard Worker 1810*da0073e9SAndroid Build Coastguard Worker def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True): 1811*da0073e9SAndroid Build Coastguard Worker # Makes a jagged tensor with N constituent tensors with size 1812*da0073e9SAndroid Build Coastguard Worker # as specified ((S0, S1, S2), D) 1813*da0073e9SAndroid Build Coastguard Worker max_dim = (starts + lengths).max() 1814*da0073e9SAndroid Build Coastguard Worker values_tensor = torch.randn( 1815*da0073e9SAndroid Build Coastguard Worker starts.shape[0], 1816*da0073e9SAndroid Build Coastguard Worker max_dim.item(), 1817*da0073e9SAndroid Build Coastguard Worker inner_dim, 1818*da0073e9SAndroid Build Coastguard Worker requires_grad=requires_grad, 1819*da0073e9SAndroid Build Coastguard Worker dtype=torch.float64, 1820*da0073e9SAndroid Build Coastguard Worker ) 1821*da0073e9SAndroid Build Coastguard Worker return jagged_from_tensor_and_lengths(values_tensor, starts, lengths) 1822*da0073e9SAndroid Build Coastguard Worker 1823*da0073e9SAndroid Build Coastguard Worker def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles): 1824*da0073e9SAndroid Build Coastguard Worker _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles) 1825*da0073e9SAndroid Build Coastguard Worker 1826*da0073e9SAndroid Build Coastguard Worker def test_unary_does_not_recompile(self): 1827*da0073e9SAndroid Build Coastguard Worker nt1, _ = self._get_jagged_tensor(((2, 3, 4), 3), None) 1828*da0073e9SAndroid Build Coastguard Worker nt2, _ = self._get_jagged_tensor(((3, 4, 5, 6), 4), None) 1829*da0073e9SAndroid Build Coastguard Worker self._check_recompiles(lambda nt1: nt1.sin(), (nt1,), (nt2,), False) 1830*da0073e9SAndroid Build Coastguard Worker 1831*da0073e9SAndroid Build Coastguard Worker def test_binary_does_not_recompile(self): 1832*da0073e9SAndroid Build Coastguard Worker def binary(nt1, nt2): 1833*da0073e9SAndroid Build Coastguard Worker if nt1.shape == nt2.shape: 1834*da0073e9SAndroid Build Coastguard Worker return nt1 + nt2 1835*da0073e9SAndroid Build Coastguard Worker else: 1836*da0073e9SAndroid Build Coastguard Worker return nt1.sin() 1837*da0073e9SAndroid Build Coastguard Worker 1838*da0073e9SAndroid Build Coastguard Worker # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0). 1839*da0073e9SAndroid Build Coastguard Worker # This causes a recompile later on when it realizes the batch and last dim 1840*da0073e9SAndroid Build Coastguard Worker # should not always be equal. To avoid that, we use (3, j0, 5) here. 1841*da0073e9SAndroid Build Coastguard Worker nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None) 1842*da0073e9SAndroid Build Coastguard Worker nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets) 1843*da0073e9SAndroid Build Coastguard Worker nt3, offsets = self._get_jagged_tensor(((3, 4, 5), 4), None) 1844*da0073e9SAndroid Build Coastguard Worker nt4, _ = self._get_jagged_tensor(((3, 4, 5), 4), offsets) 1845*da0073e9SAndroid Build Coastguard Worker self._check_recompiles(binary, (nt1, nt2), (nt3, nt4), False) 1846*da0073e9SAndroid Build Coastguard Worker 1847*da0073e9SAndroid Build Coastguard Worker def test_binary_recompiles(self): 1848*da0073e9SAndroid Build Coastguard Worker def binary(nt1, nt2): 1849*da0073e9SAndroid Build Coastguard Worker if nt1.shape == nt2.shape: 1850*da0073e9SAndroid Build Coastguard Worker return nt1 + nt2 1851*da0073e9SAndroid Build Coastguard Worker else: 1852*da0073e9SAndroid Build Coastguard Worker return nt1.sin() 1853*da0073e9SAndroid Build Coastguard Worker 1854*da0073e9SAndroid Build Coastguard Worker # Binary recompiles because singleton ints no longer match 1855*da0073e9SAndroid Build Coastguard Worker nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None) 1856*da0073e9SAndroid Build Coastguard Worker nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets) 1857*da0073e9SAndroid Build Coastguard Worker nt3, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) 1858*da0073e9SAndroid Build Coastguard Worker self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True) 1859*da0073e9SAndroid Build Coastguard Worker 1860*da0073e9SAndroid Build Coastguard Worker def _validate_compile(self, fn, arg_fn): 1861*da0073e9SAndroid Build Coastguard Worker def _gen_grad_outputs(out_val): 1862*da0073e9SAndroid Build Coastguard Worker if isinstance(out_val, (list, tuple)): 1863*da0073e9SAndroid Build Coastguard Worker return tuple(torch.ones_like(c) for c in out_val) 1864*da0073e9SAndroid Build Coastguard Worker else: 1865*da0073e9SAndroid Build Coastguard Worker return (torch.ones_like(out_val),) 1866*da0073e9SAndroid Build Coastguard Worker 1867*da0073e9SAndroid Build Coastguard Worker with self.branch_nested_state(): 1868*da0073e9SAndroid Build Coastguard Worker from torch.nested._internal.nested_tensor import _tensor_symint_registry 1869*da0073e9SAndroid Build Coastguard Worker 1870*da0073e9SAndroid Build Coastguard Worker # Validate that compilation does not modify eager state 1871*da0073e9SAndroid Build Coastguard Worker registry_before = list(_tensor_symint_registry.items()) 1872*da0073e9SAndroid Build Coastguard Worker count_before = torch.nested._internal.nested_tensor._tensor_id_counter 1873*da0073e9SAndroid Build Coastguard Worker 1874*da0073e9SAndroid Build Coastguard Worker guards_exported = [] 1875*da0073e9SAndroid Build Coastguard Worker guards_failed = [] 1876*da0073e9SAndroid Build Coastguard Worker 1877*da0073e9SAndroid Build Coastguard Worker def append_guard_export(guards): 1878*da0073e9SAndroid Build Coastguard Worker for g in guards: 1879*da0073e9SAndroid Build Coastguard Worker if g.code_list is not None: 1880*da0073e9SAndroid Build Coastguard Worker guards_exported.append(g.code_list[0]) 1881*da0073e9SAndroid Build Coastguard Worker 1882*da0073e9SAndroid Build Coastguard Worker def append_guard_fail(guards): 1883*da0073e9SAndroid Build Coastguard Worker guards_failed.extend(guards) 1884*da0073e9SAndroid Build Coastguard Worker 1885*da0073e9SAndroid Build Coastguard Worker compiled = torch._dynamo.optimize( 1886*da0073e9SAndroid Build Coastguard Worker nopython=True, 1887*da0073e9SAndroid Build Coastguard Worker backend="aot_eager", 1888*da0073e9SAndroid Build Coastguard Worker guard_export_fn=append_guard_export, 1889*da0073e9SAndroid Build Coastguard Worker guard_fail_fn=append_guard_fail, 1890*da0073e9SAndroid Build Coastguard Worker )(fn) 1891*da0073e9SAndroid Build Coastguard Worker registry_after = list(_tensor_symint_registry.items()) 1892*da0073e9SAndroid Build Coastguard Worker count_after = torch.nested._internal.nested_tensor._tensor_id_counter 1893*da0073e9SAndroid Build Coastguard Worker self.assertEqual(registry_before, registry_after) 1894*da0073e9SAndroid Build Coastguard Worker self.assertEqual(count_before, count_after) 1895*da0073e9SAndroid Build Coastguard Worker 1896*da0073e9SAndroid Build Coastguard Worker args = arg_fn() 1897*da0073e9SAndroid Build Coastguard Worker compile_out = compiled(*args) 1898*da0073e9SAndroid Build Coastguard Worker compile_grads = [] 1899*da0073e9SAndroid Build Coastguard Worker g_args = [arg for arg in args if arg.requires_grad] 1900*da0073e9SAndroid Build Coastguard Worker if len(g_args) > 0: 1901*da0073e9SAndroid Build Coastguard Worker compile_grad_outputs = _gen_grad_outputs(compile_out) 1902*da0073e9SAndroid Build Coastguard Worker compile_grads = torch.autograd.grad( 1903*da0073e9SAndroid Build Coastguard Worker compile_out, inputs=g_args, grad_outputs=compile_grad_outputs 1904*da0073e9SAndroid Build Coastguard Worker ) 1905*da0073e9SAndroid Build Coastguard Worker 1906*da0073e9SAndroid Build Coastguard Worker with self.branch_nested_state(): 1907*da0073e9SAndroid Build Coastguard Worker args = arg_fn() 1908*da0073e9SAndroid Build Coastguard Worker ref_out = fn(*args) 1909*da0073e9SAndroid Build Coastguard Worker ref_grads = [] 1910*da0073e9SAndroid Build Coastguard Worker g_args = [arg for arg in args if arg.requires_grad] 1911*da0073e9SAndroid Build Coastguard Worker if len(g_args) > 0: 1912*da0073e9SAndroid Build Coastguard Worker ref_grad_outputs = _gen_grad_outputs(ref_out) 1913*da0073e9SAndroid Build Coastguard Worker ref_grads = torch.autograd.grad( 1914*da0073e9SAndroid Build Coastguard Worker ref_out, inputs=g_args, grad_outputs=ref_grad_outputs 1915*da0073e9SAndroid Build Coastguard Worker ) 1916*da0073e9SAndroid Build Coastguard Worker 1917*da0073e9SAndroid Build Coastguard Worker # Validate correctness forward 1918*da0073e9SAndroid Build Coastguard Worker if isinstance(compile_out, (list, tuple)): 1919*da0073e9SAndroid Build Coastguard Worker # TODO: Fix assertEqual() to support NJTs so this isn't necessary 1920*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(compile_out), len(ref_out)) 1921*da0073e9SAndroid Build Coastguard Worker for c, r in zip(compile_out, ref_out): 1922*da0073e9SAndroid Build Coastguard Worker self.assertEqualIgnoringNestedInts(c, r) 1923*da0073e9SAndroid Build Coastguard Worker else: 1924*da0073e9SAndroid Build Coastguard Worker self.assertEqualIgnoringNestedInts(compile_out, ref_out) 1925*da0073e9SAndroid Build Coastguard Worker 1926*da0073e9SAndroid Build Coastguard Worker # Validate correctness backward 1927*da0073e9SAndroid Build Coastguard Worker for compile_grad, ref_grad in zip(compile_grads, ref_grads): 1928*da0073e9SAndroid Build Coastguard Worker self.assertEqualIgnoringNestedInts(compile_grad, ref_grad) 1929*da0073e9SAndroid Build Coastguard Worker 1930*da0073e9SAndroid Build Coastguard Worker return guards_exported, guards_failed 1931*da0073e9SAndroid Build Coastguard Worker 1932*da0073e9SAndroid Build Coastguard Worker # Note: [What kind of guards are involved in nested tensor compilation] 1933*da0073e9SAndroid Build Coastguard Worker # 1934*da0073e9SAndroid Build Coastguard Worker # Until we implement UnionFind, dynamic shapes guards are not involved. 1935*da0073e9SAndroid Build Coastguard Worker # we rely only on dynamo's tensor aliasing guards. 1936*da0073e9SAndroid Build Coastguard Worker # 1937*da0073e9SAndroid Build Coastguard Worker # This is possible because dynamo able to generate tensor aliasing guards 1938*da0073e9SAndroid Build Coastguard Worker # not only for the outer tensor, but also for the inner tensor. 1939*da0073e9SAndroid Build Coastguard Worker # 1940*da0073e9SAndroid Build Coastguard Worker # The case where dynamic shapes guards would eventually come into play is 1941*da0073e9SAndroid Build Coastguard Worker # when my inputs are (1) two non-aliased tensors, but (2) declared as 1942*da0073e9SAndroid Build Coastguard Worker # equal using a "trust me assert equal" API. 1943*da0073e9SAndroid Build Coastguard Worker 1944*da0073e9SAndroid Build Coastguard Worker # Note: [Compiling nested tensor global state] 1945*da0073e9SAndroid Build Coastguard Worker # 1946*da0073e9SAndroid Build Coastguard Worker # Today there are two pieces of global eager state that NJTs deals with: 1947*da0073e9SAndroid Build Coastguard Worker # - tensor_id_counter: a global counter that assigns unique ids to tensors 1948*da0073e9SAndroid Build Coastguard Worker # - tensor_symint_registry: maps tensor to nested int 1949*da0073e9SAndroid Build Coastguard Worker # - this is used in eager only (we should get rid of this because it is 1950*da0073e9SAndroid Build Coastguard Worker # not necessary to cache nested int in eager) 1951*da0073e9SAndroid Build Coastguard Worker # - during tracing, we DO need to cache nested int, but we do so on 1952*da0073e9SAndroid Build Coastguard Worker # the FakeTensor. 1953*da0073e9SAndroid Build Coastguard Worker # 1954*da0073e9SAndroid Build Coastguard Worker # Ideally we would like to satisfy the following: 1955*da0073e9SAndroid Build Coastguard Worker # - (1) The eager state is not mutated during tracing 1956*da0073e9SAndroid Build Coastguard Worker # - (2) Running the compiled function should mutate the eager state in the 1957*da0073e9SAndroid Build Coastguard Worker # same way that running the eager function would 1958*da0073e9SAndroid Build Coastguard Worker # (a) The global counter should be incremented 1959*da0073e9SAndroid Build Coastguard Worker # (b) The registry is updated in the same way 1960*da0073e9SAndroid Build Coastguard Worker # 1961*da0073e9SAndroid Build Coastguard Worker # Today we can satisfy (1) and (2a) but cannot satisfy (2b) 1962*da0073e9SAndroid Build Coastguard Worker # 1963*da0073e9SAndroid Build Coastguard Worker # Today, (1) is satisfied because we maintain a separate counter during 1964*da0073e9SAndroid Build Coastguard Worker # tracing, and cache nested int on FakeTensor instead of relying on 1965*da0073e9SAndroid Build Coastguard Worker # tensor_symint_registry. 1966*da0073e9SAndroid Build Coastguard Worker # 1967*da0073e9SAndroid Build Coastguard Worker # (2) is cannot be completely satisfied because we trace away the 1968*da0073e9SAndroid Build Coastguard Worker # side-effectful operations (which we can fix this by wrapping the 1969*da0073e9SAndroid Build Coastguard Worker # side-effectful operations in a custom op, and threading through effect 1970*da0073e9SAndroid Build Coastguard Worker # tokens.) The current plan is to do that in the UnionFind impl. 1971*da0073e9SAndroid Build Coastguard Worker # 1972*da0073e9SAndroid Build Coastguard Worker # Interestingly, despite this, the state is mutated in a way that is somewhat 1973*da0073e9SAndroid Build Coastguard Worker # close to what we want, e.g. if I construct a nested tensor using an 1974*da0073e9SAndroid Build Coastguard Worker # offsets in the compiled region and return it, AOTAutograd runtime wrapper 1975*da0073e9SAndroid Build Coastguard Worker # must rewrap the inner->inner graph outputs back into subclass. This 1976*da0073e9SAndroid Build Coastguard Worker # triggers the eager logic to run, updating the counter and registry. 1977*da0073e9SAndroid Build Coastguard Worker # 1978*da0073e9SAndroid Build Coastguard Worker # Notably however, compile differs in two ways from eager: 1979*da0073e9SAndroid Build Coastguard Worker # (1) The order in which the offsets are assigned ids is differnet 1980*da0073e9SAndroid Build Coastguard Worker # the registry would be set in the order the offsets are returned 1981*da0073e9SAndroid Build Coastguard Worker # which is not necessarily the same order as they were constructed. 1982*da0073e9SAndroid Build Coastguard Worker # (2) If a NestedTensor is not returned, then the AOTAutograd wrapping 1983*da0073e9SAndroid Build Coastguard Worker # logic will not be triggered. 1984*da0073e9SAndroid Build Coastguard Worker # 1985*da0073e9SAndroid Build Coastguard Worker # I claim that correctness is not affected by these differences today. 1986*da0073e9SAndroid Build Coastguard Worker # e.g. there is never the case where two distinct offsets silently share 1987*da0073e9SAndroid Build Coastguard Worker # the same id. 1988*da0073e9SAndroid Build Coastguard Worker # 1989*da0073e9SAndroid Build Coastguard Worker # (1) is clearly not a problem, and (2) should only be a problem if 1990*da0073e9SAndroid Build Coastguard Worker # the nested int is returned on its own, without the corresponding NJT 1991*da0073e9SAndroid Build Coastguard Worker # being returned. This is not a problem in the current implementation 1992*da0073e9SAndroid Build Coastguard Worker # because returning only a shape is not supported! 1993*da0073e9SAndroid Build Coastguard Worker 1994*da0073e9SAndroid Build Coastguard Worker # Note: [Creating symbolic nested int] 1995*da0073e9SAndroid Build Coastguard Worker # 1996*da0073e9SAndroid Build Coastguard Worker # We must create a symbolic nested int when we construct a nested tensor 1997*da0073e9SAndroid Build Coastguard Worker # from a tensor. There are two main cases: 1998*da0073e9SAndroid Build Coastguard Worker # 1999*da0073e9SAndroid Build Coastguard Worker # 1. The offsets has NOT been used to construct a NJT 2000*da0073e9SAndroid Build Coastguard Worker # - Create a new plain nested int with current val of fake nt id counter 2001*da0073e9SAndroid Build Coastguard Worker # - Increment the fake nt id counter 2002*da0073e9SAndroid Build Coastguard Worker # - Create a new symint with plain nested int as hint 2003*da0073e9SAndroid Build Coastguard Worker # 2. The offsets HAS been used to construct a NJT 2004*da0073e9SAndroid Build Coastguard Worker # - Create a new symint with plain nested int as hint 2005*da0073e9SAndroid Build Coastguard Worker # 2006*da0073e9SAndroid Build Coastguard Worker # More details on case 2: 2007*da0073e9SAndroid Build Coastguard Worker # - During fakification of the offsets, we check the eager registry, and 2008*da0073e9SAndroid Build Coastguard Worker # if the tensor HAS been used to construct a NJT, 2009*da0073e9SAndroid Build Coastguard Worker # we create a symint, with the existing nested int as hint, and cache 2010*da0073e9SAndroid Build Coastguard Worker # it on to the FakeTensor. 2011*da0073e9SAndroid Build Coastguard Worker # 2012*da0073e9SAndroid Build Coastguard Worker # [ Always use ephemeral source ] 2013*da0073e9SAndroid Build Coastguard Worker # 2014*da0073e9SAndroid Build Coastguard Worker # We create the new symint ALWAYS with ephemeral source whether that is 2015*da0073e9SAndroid Build Coastguard Worker # in case (1) or (2) even though we could've had a proper source for case (2). 2016*da0073e9SAndroid Build Coastguard Worker # Using a proper source would enable a few more (edge) cases, but since 2017*da0073e9SAndroid Build Coastguard Worker # we plan to handle things more holistically in the future anyway, we don't 2018*da0073e9SAndroid Build Coastguard Worker # bother doing so today. 2019*da0073e9SAndroid Build Coastguard Worker # 2020*da0073e9SAndroid Build Coastguard Worker # Using an ephemeral source has some consequences. But we are happy if 2021*da0073e9SAndroid Build Coastguard Worker # - We do not silently miss recompiles, e.g. we guard when necessary. 2022*da0073e9SAndroid Build Coastguard Worker # We know that this is true, because dynamo guards alone are already 2023*da0073e9SAndroid Build Coastguard Worker # sufficient. 2024*da0073e9SAndroid Build Coastguard Worker # - We are not producing errors for the cases we care about 2025*da0073e9SAndroid Build Coastguard Worker # 2026*da0073e9SAndroid Build Coastguard Worker # The main case we care about is when we guard that two shapes are equal. 2027*da0073e9SAndroid Build Coastguard Worker # In this case, the replacements logic would simplify away the ephemeral 2028*da0073e9SAndroid Build Coastguard Worker # symbol, and there is no error produced. 2029*da0073e9SAndroid Build Coastguard Worker # The unsupported case is when we guard that two shapes are not equal, in 2030*da0073e9SAndroid Build Coastguard Worker # which, we will try and fail to generate a guard. 2031*da0073e9SAndroid Build Coastguard Worker 2032*da0073e9SAndroid Build Coastguard Worker # 2033*da0073e9SAndroid Build Coastguard Worker # Case 1: in-graph construction where the offsets are passed as inputs 2034*da0073e9SAndroid Build Coastguard Worker # 2035*da0073e9SAndroid Build Coastguard Worker def test_in_graph_construction_from_input(self): 2036*da0073e9SAndroid Build Coastguard Worker # The offsets is passed as an input 2037*da0073e9SAndroid Build Coastguard Worker def fn(values, offsets): 2038*da0073e9SAndroid Build Coastguard Worker return torch.nested.nested_tensor_from_jagged(values * 2, offsets) * 2 2039*da0073e9SAndroid Build Coastguard Worker 2040*da0073e9SAndroid Build Coastguard Worker values = torch.randn(10, 5, requires_grad=True) 2041*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) 2042*da0073e9SAndroid Build Coastguard Worker self._validate_compile(fn, arg_fn=lambda: (values, offsets)) 2043*da0073e9SAndroid Build Coastguard Worker 2044*da0073e9SAndroid Build Coastguard Worker # Do not specialize on the offsets 2045*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 2046*da0073e9SAndroid Build Coastguard Worker different_offsets = torch.tensor([0, 1, 5, 10], dtype=torch.int64) 2047*da0073e9SAndroid Build Coastguard Worker self._validate_compile(fn, arg_fn=lambda: (values, different_offsets)) 2048*da0073e9SAndroid Build Coastguard Worker 2049*da0073e9SAndroid Build Coastguard Worker def test_in_graph_construction_from_input_2(self): 2050*da0073e9SAndroid Build Coastguard Worker # Construct two NJTs, both are passed as inputs 2051*da0073e9SAndroid Build Coastguard Worker def fn(values, offsets1, offsets2): 2052*da0073e9SAndroid Build Coastguard Worker nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets1) 2053*da0073e9SAndroid Build Coastguard Worker nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2) 2054*da0073e9SAndroid Build Coastguard Worker return nt2, nt1 2055*da0073e9SAndroid Build Coastguard Worker 2056*da0073e9SAndroid Build Coastguard Worker values = torch.randn(10, 5, requires_grad=True) 2057*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) 2058*da0073e9SAndroid Build Coastguard Worker offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64) 2059*da0073e9SAndroid Build Coastguard Worker # 1. Offsets are different 2060*da0073e9SAndroid Build Coastguard Worker guards_exported, guards_failed = self._validate_compile( 2061*da0073e9SAndroid Build Coastguard Worker fn, arg_fn=lambda: (values, offsets, offsets2) 2062*da0073e9SAndroid Build Coastguard Worker ) 2063*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(guards_failed), 0) 2064*da0073e9SAndroid Build Coastguard Worker self.assertNotIn("L['offsets1'] is L['offsets2']", guards_exported) 2065*da0073e9SAndroid Build Coastguard Worker 2066*da0073e9SAndroid Build Coastguard Worker # TODO 2067*da0073e9SAndroid Build Coastguard Worker # 2. Offsets are the same 2068*da0073e9SAndroid Build Coastguard Worker new_guards_exported, _ = self._validate_compile( 2069*da0073e9SAndroid Build Coastguard Worker fn, arg_fn=lambda: (values, offsets, offsets) 2070*da0073e9SAndroid Build Coastguard Worker ) 2071*da0073e9SAndroid Build Coastguard Worker self.assertTrue(any("Duplicate tensors found" in g for g in guards_failed)) 2072*da0073e9SAndroid Build Coastguard Worker self.assertIn("L['offsets1'] is L['offsets2']", new_guards_exported) 2073*da0073e9SAndroid Build Coastguard Worker 2074*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 2075*da0073e9SAndroid Build Coastguard Worker offsets3 = offsets.clone() 2076*da0073e9SAndroid Build Coastguard Worker self._validate_compile(fn, arg_fn=lambda: (values, offsets3, offsets3)) 2077*da0073e9SAndroid Build Coastguard Worker 2078*da0073e9SAndroid Build Coastguard Worker # Do a binary op 2079*da0073e9SAndroid Build Coastguard Worker def fn(values, offsets, offsets2): 2080*da0073e9SAndroid Build Coastguard Worker nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets) 2081*da0073e9SAndroid Build Coastguard Worker nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets2) 2082*da0073e9SAndroid Build Coastguard Worker return nt1 * nt2 2083*da0073e9SAndroid Build Coastguard Worker 2084*da0073e9SAndroid Build Coastguard Worker self._validate_compile(fn, arg_fn=lambda: (values, offsets, offsets)) 2085*da0073e9SAndroid Build Coastguard Worker 2086*da0073e9SAndroid Build Coastguard Worker def test_in_graph_construction_from_input_4(self): 2087*da0073e9SAndroid Build Coastguard Worker # The offsets is taken from an NJT input 2088*da0073e9SAndroid Build Coastguard Worker def fn(nt, other_values): 2089*da0073e9SAndroid Build Coastguard Worker nt2 = torch.nested.nested_tensor_from_jagged(other_values, nt.offsets()) 2090*da0073e9SAndroid Build Coastguard Worker return nt + nt2 2091*da0073e9SAndroid Build Coastguard Worker 2092*da0073e9SAndroid Build Coastguard Worker values = torch.randn(9, 5, requires_grad=True) 2093*da0073e9SAndroid Build Coastguard Worker other_values = torch.randn(9, 5, requires_grad=True) 2094*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64) 2095*da0073e9SAndroid Build Coastguard Worker 2096*da0073e9SAndroid Build Coastguard Worker def arg_fn(values=values, other_values=other_values, offsets=offsets): 2097*da0073e9SAndroid Build Coastguard Worker nt = torch.nested.nested_tensor_from_jagged(values, offsets) 2098*da0073e9SAndroid Build Coastguard Worker return nt, other_values 2099*da0073e9SAndroid Build Coastguard Worker 2100*da0073e9SAndroid Build Coastguard Worker self._validate_compile(fn, arg_fn=arg_fn) 2101*da0073e9SAndroid Build Coastguard Worker 2102*da0073e9SAndroid Build Coastguard Worker # Do not specialize on the offsets 2103*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 2104*da0073e9SAndroid Build Coastguard Worker different_offsets = offsets.clone() 2105*da0073e9SAndroid Build Coastguard Worker 2106*da0073e9SAndroid Build Coastguard Worker def arg_fn( 2107*da0073e9SAndroid Build Coastguard Worker values=values, other_values=other_values, offsets=different_offsets 2108*da0073e9SAndroid Build Coastguard Worker ): 2109*da0073e9SAndroid Build Coastguard Worker nt = torch.nested.nested_tensor_from_jagged(values, different_offsets) 2110*da0073e9SAndroid Build Coastguard Worker return nt, other_values 2111*da0073e9SAndroid Build Coastguard Worker 2112*da0073e9SAndroid Build Coastguard Worker self._validate_compile(fn, arg_fn=arg_fn) 2113*da0073e9SAndroid Build Coastguard Worker 2114*da0073e9SAndroid Build Coastguard Worker def test_in_graph_construction_from_input_5(self): 2115*da0073e9SAndroid Build Coastguard Worker # Construct from lengths instead of offsets 2116*da0073e9SAndroid Build Coastguard Worker def fn(values, lengths): 2117*da0073e9SAndroid Build Coastguard Worker nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths) 2118*da0073e9SAndroid Build Coastguard Worker return nt.sin() 2119*da0073e9SAndroid Build Coastguard Worker 2120*da0073e9SAndroid Build Coastguard Worker values = torch.randn(9, 5, requires_grad=True) 2121*da0073e9SAndroid Build Coastguard Worker lengths = torch.tensor([2, 4, 3]) 2122*da0073e9SAndroid Build Coastguard Worker self._validate_compile(fn, arg_fn=lambda: (values, lengths)) 2123*da0073e9SAndroid Build Coastguard Worker 2124*da0073e9SAndroid Build Coastguard Worker # 2125*da0073e9SAndroid Build Coastguard Worker # Case 2: in-graph construction where offsets are graph intermediates 2126*da0073e9SAndroid Build Coastguard Worker # 2127*da0073e9SAndroid Build Coastguard Worker def test_in_graph_construction_from_intermediate(self): 2128*da0073e9SAndroid Build Coastguard Worker # offsets is an intermediate computed from lengths 2129*da0073e9SAndroid Build Coastguard Worker def fn(values, lengths): 2130*da0073e9SAndroid Build Coastguard Worker offsets = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)]) 2131*da0073e9SAndroid Build Coastguard Worker nt = torch.nested.nested_tensor_from_jagged(values, offsets) 2132*da0073e9SAndroid Build Coastguard Worker nt2 = torch.nested.nested_tensor_from_jagged(values, offsets) 2133*da0073e9SAndroid Build Coastguard Worker return (nt * nt2).sin() 2134*da0073e9SAndroid Build Coastguard Worker 2135*da0073e9SAndroid Build Coastguard Worker values = torch.randn(9, 5, requires_grad=True) 2136*da0073e9SAndroid Build Coastguard Worker lengths = torch.tensor([2, 4, 3]) 2137*da0073e9SAndroid Build Coastguard Worker self._validate_compile(fn, arg_fn=lambda: (values, lengths)) 2138*da0073e9SAndroid Build Coastguard Worker 2139*da0073e9SAndroid Build Coastguard Worker # Do not specialize on the lengths 2140*da0073e9SAndroid Build Coastguard Worker with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): 2141*da0073e9SAndroid Build Coastguard Worker different_lengths = lengths.clone() 2142*da0073e9SAndroid Build Coastguard Worker self._validate_compile(fn, arg_fn=lambda: (values, different_lengths)) 2143*da0073e9SAndroid Build Coastguard Worker 2144*da0073e9SAndroid Build Coastguard Worker def test_in_graph_construction_from_intermediate_2(self): 2145*da0073e9SAndroid Build Coastguard Worker def fn(values, offsets): 2146*da0073e9SAndroid Build Coastguard Worker return torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone()) 2147*da0073e9SAndroid Build Coastguard Worker 2148*da0073e9SAndroid Build Coastguard Worker values = torch.randn(10, 5, requires_grad=True) 2149*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) 2150*da0073e9SAndroid Build Coastguard Worker self._validate_compile(fn, arg_fn=lambda: (values, offsets)) 2151*da0073e9SAndroid Build Coastguard Worker 2152*da0073e9SAndroid Build Coastguard Worker def test_in_graph_construction_from_intermediate_3(self): 2153*da0073e9SAndroid Build Coastguard Worker # Note that due to CSE, clone is not necessarily called twice! 2154*da0073e9SAndroid Build Coastguard Worker def fn(values, offsets): 2155*da0073e9SAndroid Build Coastguard Worker nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets.clone()) 2156*da0073e9SAndroid Build Coastguard Worker nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets.clone()) 2157*da0073e9SAndroid Build Coastguard Worker return nt2, nt1 2158*da0073e9SAndroid Build Coastguard Worker 2159*da0073e9SAndroid Build Coastguard Worker values = torch.randn(10, 5, requires_grad=True) 2160*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) 2161*da0073e9SAndroid Build Coastguard Worker self._validate_compile(fn, arg_fn=lambda: (values, offsets)) 2162*da0073e9SAndroid Build Coastguard Worker 2163*da0073e9SAndroid Build Coastguard Worker def test_in_graph_construction_from_intermediate_4(self): 2164*da0073e9SAndroid Build Coastguard Worker # Shared intermediate (should be same as case #1) 2165*da0073e9SAndroid Build Coastguard Worker def fn(values): 2166*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) 2167*da0073e9SAndroid Build Coastguard Worker nt = torch.nested.nested_tensor_from_jagged(values, offsets) 2168*da0073e9SAndroid Build Coastguard Worker values2 = torch.ones_like(values) 2169*da0073e9SAndroid Build Coastguard Worker nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets) 2170*da0073e9SAndroid Build Coastguard Worker return nt * nt2 2171*da0073e9SAndroid Build Coastguard Worker 2172*da0073e9SAndroid Build Coastguard Worker values = torch.randn(10, 5).requires_grad_(True) 2173*da0073e9SAndroid Build Coastguard Worker self._validate_compile(fn, arg_fn=lambda: (values,)) 2174*da0073e9SAndroid Build Coastguard Worker 2175*da0073e9SAndroid Build Coastguard Worker # AssertionError: s2 (could be from ['<ephemeral: intermediate_offsets_or_lengths>', 2176*da0073e9SAndroid Build Coastguard Worker @unittest.expectedFailure 2177*da0073e9SAndroid Build Coastguard Worker def test_in_graph_construction_from_intermediate_5(self): 2178*da0073e9SAndroid Build Coastguard Worker # non-shared intermediate 2179*da0073e9SAndroid Build Coastguard Worker def fn(values): 2180*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) 2181*da0073e9SAndroid Build Coastguard Worker nt = torch.nested.nested_tensor_from_jagged(values, offsets) 2182*da0073e9SAndroid Build Coastguard Worker values2 = torch.ones_like(values) 2183*da0073e9SAndroid Build Coastguard Worker nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets.clone()) 2184*da0073e9SAndroid Build Coastguard Worker if nt2.shape[1] != nt.shape[1]: 2185*da0073e9SAndroid Build Coastguard Worker return nt * 2 2186*da0073e9SAndroid Build Coastguard Worker else: 2187*da0073e9SAndroid Build Coastguard Worker return nt * 3 2188*da0073e9SAndroid Build Coastguard Worker 2189*da0073e9SAndroid Build Coastguard Worker values = torch.randn(10, 5).requires_grad_(True) 2190*da0073e9SAndroid Build Coastguard Worker self._validate_compile(fn, arg_fn=lambda: (values,)) 2191*da0073e9SAndroid Build Coastguard Worker 2192*da0073e9SAndroid Build Coastguard Worker # 2193*da0073e9SAndroid Build Coastguard Worker # Case 3: in-graph construction where offsets are both direct graph inputs 2194*da0073e9SAndroid Build Coastguard Worker # and passed in as part of an NJT's offsets. 2195*da0073e9SAndroid Build Coastguard Worker # 2196*da0073e9SAndroid Build Coastguard Worker def test_in_graph_construction_mixed(self): 2197*da0073e9SAndroid Build Coastguard Worker def fn(nt, values, offsets): 2198*da0073e9SAndroid Build Coastguard Worker nt2 = torch.nested.nested_tensor_from_jagged(values, offsets) 2199*da0073e9SAndroid Build Coastguard Worker return nt * nt2 2200*da0073e9SAndroid Build Coastguard Worker 2201*da0073e9SAndroid Build Coastguard Worker values = torch.randn(10, 5, requires_grad=True) 2202*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) 2203*da0073e9SAndroid Build Coastguard Worker 2204*da0073e9SAndroid Build Coastguard Worker def arg_fn(values=values, offsets=offsets): 2205*da0073e9SAndroid Build Coastguard Worker nt = torch.nested.nested_tensor_from_jagged(values, offsets) 2206*da0073e9SAndroid Build Coastguard Worker return nt, values, offsets 2207*da0073e9SAndroid Build Coastguard Worker 2208*da0073e9SAndroid Build Coastguard Worker self._validate_compile(fn, arg_fn) 2209*da0073e9SAndroid Build Coastguard Worker 2210*da0073e9SAndroid Build Coastguard Worker # See Note: [Creating symbolic nested int] 2211*da0073e9SAndroid Build Coastguard Worker # AssertionError: s2 (could be from ['<ephemeral: intermediate_offsets_or_lengths>', 2212*da0073e9SAndroid Build Coastguard Worker @unittest.expectedFailure 2213*da0073e9SAndroid Build Coastguard Worker def test_in_graph_construction_mixed_2(self): 2214*da0073e9SAndroid Build Coastguard Worker def fn(nt, values, offsets, nt2): 2215*da0073e9SAndroid Build Coastguard Worker # Intermediate offsets has ephemeral source 2216*da0073e9SAndroid Build Coastguard Worker intermediate_nt = torch.nested.nested_tensor_from_jagged( 2217*da0073e9SAndroid Build Coastguard Worker values, offsets.clone() 2218*da0073e9SAndroid Build Coastguard Worker ) 2219*da0073e9SAndroid Build Coastguard Worker # This creates a dynamic shapes neq guard 2220*da0073e9SAndroid Build Coastguard Worker if nt2.shape[1] != intermediate_nt.shape[1]: 2221*da0073e9SAndroid Build Coastguard Worker # We should always go here. 2222*da0073e9SAndroid Build Coastguard Worker nt = nt * 2 2223*da0073e9SAndroid Build Coastguard Worker return nt 2224*da0073e9SAndroid Build Coastguard Worker 2225*da0073e9SAndroid Build Coastguard Worker values = torch.randn(10, 5, requires_grad=True) 2226*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 2, 6, 10], dtype=torch.int64) 2227*da0073e9SAndroid Build Coastguard Worker offsets2 = torch.tensor([0, 1, 4, 10], dtype=torch.int64) 2228*da0073e9SAndroid Build Coastguard Worker 2229*da0073e9SAndroid Build Coastguard Worker def arg_fn(values=values, offsets=offsets, offsets2=offsets2): 2230*da0073e9SAndroid Build Coastguard Worker # Values is shared, but it shouldn't matter 2231*da0073e9SAndroid Build Coastguard Worker nt = torch.nested.nested_tensor_from_jagged(values, offsets) 2232*da0073e9SAndroid Build Coastguard Worker nt2 = torch.nested.nested_tensor_from_jagged(values, offsets2) 2233*da0073e9SAndroid Build Coastguard Worker return nt, values, offsets, nt2 2234*da0073e9SAndroid Build Coastguard Worker 2235*da0073e9SAndroid Build Coastguard Worker self._validate_compile(fn, arg_fn) 2236*da0073e9SAndroid Build Coastguard Worker 2237*da0073e9SAndroid Build Coastguard Worker def test_in_graph_construction_mixed_3(self): 2238*da0073e9SAndroid Build Coastguard Worker # More involved mixed case 2239*da0073e9SAndroid Build Coastguard Worker def fn(nt, values, offsets): 2240*da0073e9SAndroid Build Coastguard Worker nt1 = torch.nested.nested_tensor_from_jagged(values * 2, offsets) 2241*da0073e9SAndroid Build Coastguard Worker nt2 = torch.nested.nested_tensor_from_jagged(values * 3, offsets) 2242*da0073e9SAndroid Build Coastguard Worker return nt1 + nt2 + nt 2243*da0073e9SAndroid Build Coastguard Worker 2244*da0073e9SAndroid Build Coastguard Worker values = torch.randn(9, 5, requires_grad=True) 2245*da0073e9SAndroid Build Coastguard Worker offsets = torch.tensor([0, 2, 6, 9], dtype=torch.int64) 2246*da0073e9SAndroid Build Coastguard Worker 2247*da0073e9SAndroid Build Coastguard Worker def arg_fn(values=values, offsets=offsets): 2248*da0073e9SAndroid Build Coastguard Worker nt = torch.nested.nested_tensor_from_jagged(values, offsets) 2249*da0073e9SAndroid Build Coastguard Worker return nt, values, offsets 2250*da0073e9SAndroid Build Coastguard Worker 2251*da0073e9SAndroid Build Coastguard Worker self._validate_compile(fn, arg_fn) 2252*da0073e9SAndroid Build Coastguard Worker 2253*da0073e9SAndroid Build Coastguard Worker def test_return_shape(self): 2254*da0073e9SAndroid Build Coastguard Worker nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) 2255*da0073e9SAndroid Build Coastguard Worker 2256*da0073e9SAndroid Build Coastguard Worker def fn(nt): 2257*da0073e9SAndroid Build Coastguard Worker return (nt * 2).shape 2258*da0073e9SAndroid Build Coastguard Worker 2259*da0073e9SAndroid Build Coastguard Worker compiled = torch.compile(fn, fullgraph=True, backend="aot_eager") 2260*da0073e9SAndroid Build Coastguard Worker compiled(nt) 2261*da0073e9SAndroid Build Coastguard Worker 2262*da0073e9SAndroid Build Coastguard Worker def test_inference_tensor(self): 2263*da0073e9SAndroid Build Coastguard Worker with torch.inference_mode(): 2264*da0073e9SAndroid Build Coastguard Worker nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) 2265*da0073e9SAndroid Build Coastguard Worker 2266*da0073e9SAndroid Build Coastguard Worker def fn(n): 2267*da0073e9SAndroid Build Coastguard Worker return n * 2 2268*da0073e9SAndroid Build Coastguard Worker 2269*da0073e9SAndroid Build Coastguard Worker torch.compile(fn, backend="eager")(nt) 2270*da0073e9SAndroid Build Coastguard Worker 2271*da0073e9SAndroid Build Coastguard Worker # TODO: cannot parametrize this test class with device for some reason 2272*da0073e9SAndroid Build Coastguard Worker def _test_autograd(self, backend): 2273*da0073e9SAndroid Build Coastguard Worker a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64) 2274*da0073e9SAndroid Build Coastguard Worker b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64) 2275*da0073e9SAndroid Build Coastguard Worker c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64) 2276*da0073e9SAndroid Build Coastguard Worker nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged) 2277*da0073e9SAndroid Build Coastguard Worker # TODO: Switch to public API when it exists 2278*da0073e9SAndroid Build Coastguard Worker nt2, _ = jagged_from_list([a, b, c], nt.offsets()) 2279*da0073e9SAndroid Build Coastguard Worker 2280*da0073e9SAndroid Build Coastguard Worker def fn1(nt1, nt2): 2281*da0073e9SAndroid Build Coastguard Worker return (nt1 + nt2).sin().cos() 2282*da0073e9SAndroid Build Coastguard Worker 2283*da0073e9SAndroid Build Coastguard Worker compiled_f = torch.compile(fn1, fullgraph=True, backend=backend, dynamic=True) 2284*da0073e9SAndroid Build Coastguard Worker out = compiled_f(nt, nt2) 2285*da0073e9SAndroid Build Coastguard Worker out_buffer = out.values() 2286*da0073e9SAndroid Build Coastguard Worker ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c)) 2287*da0073e9SAndroid Build Coastguard Worker 2288*da0073e9SAndroid Build Coastguard Worker out_ref = fn1(nt, nt2) 2289*da0073e9SAndroid Build Coastguard Worker out_buffer_ref = out_ref.values() 2290*da0073e9SAndroid Build Coastguard Worker ga_ref, gb_ref, gc_ref = torch.autograd.grad(out_buffer_ref.sum(), (a, b, c)) 2291*da0073e9SAndroid Build Coastguard Worker 2292*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(ga, ga_ref)) 2293*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(gb, gb_ref)) 2294*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(gc, gc_ref)) 2295*da0073e9SAndroid Build Coastguard Worker 2296*da0073e9SAndroid Build Coastguard Worker def test_basic_autograd(self): 2297*da0073e9SAndroid Build Coastguard Worker self._test_autograd("aot_eager") 2298*da0073e9SAndroid Build Coastguard Worker 2299*da0073e9SAndroid Build Coastguard Worker @requires_cuda 2300*da0073e9SAndroid Build Coastguard Worker def test_basic_autograd_inductor(self): 2301*da0073e9SAndroid Build Coastguard Worker self._test_autograd("inductor") 2302*da0073e9SAndroid Build Coastguard Worker 2303*da0073e9SAndroid Build Coastguard Worker def test_subclass_with_mutation_in_graph(self): 2304*da0073e9SAndroid Build Coastguard Worker # In this graph, we have an in-graph mutation, i.e. a mutation that is allowed 2305*da0073e9SAndroid Build Coastguard Worker # to remain in the graph. Normally this is allowed, but it's not allowed if 2306*da0073e9SAndroid Build Coastguard Worker # the graph handles subclasses at all. 2307*da0073e9SAndroid Build Coastguard Worker # Whether the mutation is allowed or not allowed in the graph alters the number 2308*da0073e9SAndroid Build Coastguard Worker # of outputs from the forward graph. Previously, a bug in this handling meant 2309*da0073e9SAndroid Build Coastguard Worker # that sometimes the expected number and actual number of outputs from the 2310*da0073e9SAndroid Build Coastguard Worker # joint graph did not match, causing assertion failures. 2311*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 2312*da0073e9SAndroid Build Coastguard Worker z = x.sin() 2313*da0073e9SAndroid Build Coastguard Worker y.sin_() 2314*da0073e9SAndroid Build Coastguard Worker return z.cos(), y.cos() 2315*da0073e9SAndroid Build Coastguard Worker 2316*da0073e9SAndroid Build Coastguard Worker fn_c = torch.compile(fn, backend="inductor") 2317*da0073e9SAndroid Build Coastguard Worker 2318*da0073e9SAndroid Build Coastguard Worker values = [torch.rand((i, 8), requires_grad=True) for i in range(1, 6)] 2319*da0073e9SAndroid Build Coastguard Worker values_copy = [x.detach().clone().requires_grad_(True) for x in values] 2320*da0073e9SAndroid Build Coastguard Worker 2321*da0073e9SAndroid Build Coastguard Worker nt, offsets = jagged_from_list(values, None) 2322*da0073e9SAndroid Build Coastguard Worker nt_copy, offsets = jagged_from_list(values_copy, offsets) 2323*da0073e9SAndroid Build Coastguard Worker y = torch.rand((4, 8)) 2324*da0073e9SAndroid Build Coastguard Worker y_copy = y.clone() 2325*da0073e9SAndroid Build Coastguard Worker 2326*da0073e9SAndroid Build Coastguard Worker ret = fn_c(nt, y)[0] 2327*da0073e9SAndroid Build Coastguard Worker ref = fn(nt_copy, y_copy)[0] 2328*da0073e9SAndroid Build Coastguard Worker 2329*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ret.values(), ref.values()) 2330*da0073e9SAndroid Build Coastguard Worker 2331*da0073e9SAndroid Build Coastguard Worker ret.values().sum().backward() 2332*da0073e9SAndroid Build Coastguard Worker ref.values().sum().backward() 2333*da0073e9SAndroid Build Coastguard Worker for ref_v, res_v in zip(values_copy, values): 2334*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref_v.grad, res_v.grad) 2335*da0073e9SAndroid Build Coastguard Worker 2336*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.config.patch({"capture_scalar_outputs": True}) 2337*da0073e9SAndroid Build Coastguard Worker def test_unbind(self): 2338*da0073e9SAndroid Build Coastguard Worker # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0). 2339*da0073e9SAndroid Build Coastguard Worker # This causes a recompile later on when it realizes the batch and last dim 2340*da0073e9SAndroid Build Coastguard Worker # should not always be equal. To avoid that, we use (3, j0, 5) here. 2341*da0073e9SAndroid Build Coastguard Worker nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) 2342*da0073e9SAndroid Build Coastguard Worker nt2, _ = self._get_jagged_tensor(((2, 3, 5), 2), None) 2343*da0073e9SAndroid Build Coastguard Worker nt3, _ = self._get_jagged_tensor(((2, 3, 4, 5), 3), None) 2344*da0073e9SAndroid Build Coastguard Worker 2345*da0073e9SAndroid Build Coastguard Worker def fn(x): 2346*da0073e9SAndroid Build Coastguard Worker return x.unbind() 2347*da0073e9SAndroid Build Coastguard Worker 2348*da0073e9SAndroid Build Coastguard Worker compiled_f = torch.compile(fn, fullgraph=True, backend="eager", dynamic=True) 2349*da0073e9SAndroid Build Coastguard Worker out = compiled_f(nt) 2350*da0073e9SAndroid Build Coastguard Worker 2351*da0073e9SAndroid Build Coastguard Worker out_ref = fn(nt) 2352*da0073e9SAndroid Build Coastguard Worker 2353*da0073e9SAndroid Build Coastguard Worker # correctness 2354*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(out), len(out_ref)) 2355*da0073e9SAndroid Build Coastguard Worker for x, x_ref in zip(out, out_ref): 2356*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(x, x_ref)) 2357*da0073e9SAndroid Build Coastguard Worker 2358*da0073e9SAndroid Build Coastguard Worker # We specialize on the length of offsets, e.g. (1) we recompile if the 2359*da0073e9SAndroid Build Coastguard Worker # length of the offsets is different. (2) we don't recompile if the 2360*da0073e9SAndroid Build Coastguard Worker # length of the offsets is the same, even if the size of the constituent 2361*da0073e9SAndroid Build Coastguard Worker # tensors are different. 2362*da0073e9SAndroid Build Coastguard Worker self._check_recompiles(fn, (nt,), (nt2,), False) 2363*da0073e9SAndroid Build Coastguard Worker self._check_recompiles(fn, (nt,), (nt3,), True) 2364*da0073e9SAndroid Build Coastguard Worker 2365*da0073e9SAndroid Build Coastguard Worker def test_inline_nested_tensor_from_jagged(self): 2366*da0073e9SAndroid Build Coastguard Worker nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None) 2367*da0073e9SAndroid Build Coastguard Worker 2368*da0073e9SAndroid Build Coastguard Worker def fn(x): 2369*da0073e9SAndroid Build Coastguard Worker return torch.nested.nested_tensor_from_jagged(x.values() * 2, x.offsets()) 2370*da0073e9SAndroid Build Coastguard Worker 2371*da0073e9SAndroid Build Coastguard Worker torch.compile(fn, fullgraph=True, backend="aot_eager")(nt) 2372*da0073e9SAndroid Build Coastguard Worker 2373*da0073e9SAndroid Build Coastguard Worker # The test here: nn.Parameters that are secretly subclasses 2374*da0073e9SAndroid Build Coastguard Worker # have a metaclass that overrides __isinstance__, 2375*da0073e9SAndroid Build Coastguard Worker # that dynamo needs to respect when it inlines the if statement. 2376*da0073e9SAndroid Build Coastguard Worker def test_param_subclass_isinstance_input(self): 2377*da0073e9SAndroid Build Coastguard Worker x_inner = torch.randn(16, 16, requires_grad=True) 2378*da0073e9SAndroid Build Coastguard Worker x = torch.nn.Parameter(TwoTensor(x_inner, x_inner)) 2379*da0073e9SAndroid Build Coastguard Worker m = torch.nn.Linear(16, 16) 2380*da0073e9SAndroid Build Coastguard Worker m.weight = x 2381*da0073e9SAndroid Build Coastguard Worker 2382*da0073e9SAndroid Build Coastguard Worker def fn(): 2383*da0073e9SAndroid Build Coastguard Worker if isinstance(m.weight, torch.nn.Parameter): 2384*da0073e9SAndroid Build Coastguard Worker return m.weight + 1 2385*da0073e9SAndroid Build Coastguard Worker else: 2386*da0073e9SAndroid Build Coastguard Worker return m.weight + 2 2387*da0073e9SAndroid Build Coastguard Worker 2388*da0073e9SAndroid Build Coastguard Worker out_ref = fn() 2389*da0073e9SAndroid Build Coastguard Worker out_test = torch.compile(fn, backend="aot_eager")() 2390*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, out_test) 2391*da0073e9SAndroid Build Coastguard Worker 2392*da0073e9SAndroid Build Coastguard Worker def _input_view_test(self, nt_view_name): 2393*da0073e9SAndroid Build Coastguard Worker nt_view = VIEW_TEST_CASES[nt_view_name]() 2394*da0073e9SAndroid Build Coastguard Worker 2395*da0073e9SAndroid Build Coastguard Worker def fn(x): 2396*da0073e9SAndroid Build Coastguard Worker return x.sin() 2397*da0073e9SAndroid Build Coastguard Worker 2398*da0073e9SAndroid Build Coastguard Worker out_ref = fn(nt_view) 2399*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 2400*da0073e9SAndroid Build Coastguard Worker compile_fn = torch.compile( 2401*da0073e9SAndroid Build Coastguard Worker fn, fullgraph=True, backend="aot_eager", dynamic=True 2402*da0073e9SAndroid Build Coastguard Worker ) 2403*da0073e9SAndroid Build Coastguard Worker out = compile_fn(nt_view) 2404*da0073e9SAndroid Build Coastguard Worker 2405*da0073e9SAndroid Build Coastguard Worker # Check metadata and values are correct 2406*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out.size() == out_ref.size()) 2407*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out.stride() == out_ref.stride()) 2408*da0073e9SAndroid Build Coastguard Worker if out.is_nested: 2409*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(out.values(), out_ref.values())) 2410*da0073e9SAndroid Build Coastguard Worker else: 2411*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(out, out_ref)) 2412*da0073e9SAndroid Build Coastguard Worker 2413*da0073e9SAndroid Build Coastguard Worker # Check that no upper/lower bound guards are incurred 2414*da0073e9SAndroid Build Coastguard Worker def backend(gm, args): 2415*da0073e9SAndroid Build Coastguard Worker context = torch._guards.TracingContext.get() 2416*da0073e9SAndroid Build Coastguard Worker guards = [str(g.expr) for g in context.fake_mode.shape_env.guards] 2417*da0073e9SAndroid Build Coastguard Worker 2418*da0073e9SAndroid Build Coastguard Worker # varies based on the type of view 2419*da0073e9SAndroid Build Coastguard Worker guard_str = "\n".join(guards) 2420*da0073e9SAndroid Build Coastguard Worker if nt_view_name == "subclass_dense": 2421*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""") 2422*da0073e9SAndroid Build Coastguard Worker elif nt_view_name == "dense_subclass_dense_subclass": 2423*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2424*da0073e9SAndroid Build Coastguard Worker guard_str, 2425*da0073e9SAndroid Build Coastguard Worker """\ 2426*da0073e9SAndroid Build Coastguard WorkerEq(s5 - 1, s2) 2427*da0073e9SAndroid Build Coastguard WorkerEq(s12 - 1, s7) 2428*da0073e9SAndroid Build Coastguard WorkerEq(s11, s9)""", 2429*da0073e9SAndroid Build Coastguard Worker ) 2430*da0073e9SAndroid Build Coastguard Worker elif nt_view_name.startswith("base_is_nt_True"): 2431*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2432*da0073e9SAndroid Build Coastguard Worker guard_str, 2433*da0073e9SAndroid Build Coastguard Worker """Eq(s3 - 1, s0)""", 2434*da0073e9SAndroid Build Coastguard Worker ) 2435*da0073e9SAndroid Build Coastguard Worker else: 2436*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2437*da0073e9SAndroid Build Coastguard Worker guard_str, 2438*da0073e9SAndroid Build Coastguard Worker """\ 2439*da0073e9SAndroid Build Coastguard WorkerEq(s4 - 1, s1) 2440*da0073e9SAndroid Build Coastguard WorkerEq(s13 - 1, s8) 2441*da0073e9SAndroid Build Coastguard WorkerEq(s12, s10)""", 2442*da0073e9SAndroid Build Coastguard Worker ) 2443*da0073e9SAndroid Build Coastguard Worker return gm 2444*da0073e9SAndroid Build Coastguard Worker 2445*da0073e9SAndroid Build Coastguard Worker torch._dynamo.reset() 2446*da0073e9SAndroid Build Coastguard Worker compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True) 2447*da0073e9SAndroid Build Coastguard Worker out = compile_fn(nt_view) 2448*da0073e9SAndroid Build Coastguard Worker 2449*da0073e9SAndroid Build Coastguard Worker @parametrize( 2450*da0073e9SAndroid Build Coastguard Worker "nt_view_name", 2451*da0073e9SAndroid Build Coastguard Worker [k for k in VIEW_TEST_CASES.keys() if k != "subclass_dense_subclass_dense"], 2452*da0073e9SAndroid Build Coastguard Worker ) 2453*da0073e9SAndroid Build Coastguard Worker def test_inputs_to_compiled_fn_are_views(self, nt_view_name): 2454*da0073e9SAndroid Build Coastguard Worker self._input_view_test(nt_view_name) 2455*da0073e9SAndroid Build Coastguard Worker 2456*da0073e9SAndroid Build Coastguard Worker def test_subclass_gives_static_shapes_when_dynamic_false(self): 2457*da0073e9SAndroid Build Coastguard Worker def check_graph(gm, *args): 2458*da0073e9SAndroid Build Coastguard Worker first_node_example_val = next(iter(gm.graph.nodes)).meta["example_value"] 2459*da0073e9SAndroid Build Coastguard Worker # We compiled with dynamic=False, expect no SymInt sizes on our placeholders 2460*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 2461*da0073e9SAndroid Build Coastguard Worker all(isinstance(x, int) for x in first_node_example_val.shape) 2462*da0073e9SAndroid Build Coastguard Worker ) 2463*da0073e9SAndroid Build Coastguard Worker return gm 2464*da0073e9SAndroid Build Coastguard Worker 2465*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=check_graph, dynamic=False) 2466*da0073e9SAndroid Build Coastguard Worker def f(x): 2467*da0073e9SAndroid Build Coastguard Worker return x + 1 2468*da0073e9SAndroid Build Coastguard Worker 2469*da0073e9SAndroid Build Coastguard Worker x_inner = torch.ones(4) 2470*da0073e9SAndroid Build Coastguard Worker x = TwoTensor(x_inner, x_inner) 2471*da0073e9SAndroid Build Coastguard Worker x_view = x.view(2, 2) 2472*da0073e9SAndroid Build Coastguard Worker out = f(x_view) 2473*da0073e9SAndroid Build Coastguard Worker 2474*da0073e9SAndroid Build Coastguard Worker # NJT1 -> Dense -> NJT2 -> Dense view 2475*da0073e9SAndroid Build Coastguard Worker # During view replay, the Dense -> NJT2 part will construct an intermediate, 2476*da0073e9SAndroid Build Coastguard Worker # symbolically-sized NJT that is immediately deconstructed to return the final dense 2477*da0073e9SAndroid Build Coastguard Worker # view. To construct this intermediate properly, we need the associated nested int 2478*da0073e9SAndroid Build Coastguard Worker # to be symbolic. This view is expected to fail compilation until symbolic nested ints 2479*da0073e9SAndroid Build Coastguard Worker # are cached onto fake offsets to solve this problem. 2480*da0073e9SAndroid Build Coastguard Worker @unittest.expectedFailure 2481*da0073e9SAndroid Build Coastguard Worker def test_subclass_dense_subclass_dense_view(self): 2482*da0073e9SAndroid Build Coastguard Worker self._input_view_test("subclass_dense_subclass_dense") 2483*da0073e9SAndroid Build Coastguard Worker 2484*da0073e9SAndroid Build Coastguard Worker 2485*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestNestedTensor) 2486*da0073e9SAndroid Build Coastguard Worker 2487*da0073e9SAndroid Build Coastguard Worker 2488*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 2489*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.test_case import run_tests 2490*da0073e9SAndroid Build Coastguard Worker 2491*da0073e9SAndroid Build Coastguard Worker run_tests() 2492