"""
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_rewrite_assert_with_msg and test_rewrite_assert_without_msg)
"""

# Owner(s): ["module: dynamo"]
import collections
import contextlib
import copy
import dataclasses
import functools
import gc
import inspect
import itertools
import os
import random
import unittest
import warnings
import weakref
from abc import ABC
from collections import namedtuple
from copy import deepcopy
from enum import Enum
from functools import wraps
from typing import Any, Dict, Iterator, List, Tuple
from unittest import mock

import numpy as np

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils
import torch._functorch.config
import torch.library
import torch.utils._pytree as pytree
from torch import nn
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import CompileCounter, rand_strided, same
from torch._inductor.utils import fresh_inductor_cache
from torch.nn import functional as F
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_utils import (
    disable_translation_validation_if_dynamic_shapes,
    instantiate_parametrized_tests,
    parametrize,
    skipIfWindows,
    TEST_WITH_ROCM,
)
from torch.testing._internal.two_tensor import TwoTensor


_orig_module_call = torch.nn.Module.__call__

# Custom operator that only supports CPU and Meta
lib = torch.library.Library("test_sample", "DEF")  # noqa: TOR901
lib.define("foo(Tensor self) -> Tensor")
lib.impl("foo", torch.sin, "CPU")


requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")


_GLOBAL_CPU_TENSOR = torch.randn(3)


def exists(val):
    return val is not None


def maybe(fn):
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)

    return inner
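

# Usage sketch for `maybe`: it turns a function into a None-passthrough, e.g.
#   safe_sin = maybe(torch.sin)
#   safe_sin(None)            # -> None
#   safe_sin(torch.zeros(2))  # == torch.sin(torch.zeros(2))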


def is_fx_tracing_test() -> bool:
    """
    Copied from the hpc trainer codebase
    """
    return torch.nn.Module.__call__ is not _orig_module_call


def has_detectron2():
    try:
        from detectron2.layers.mask_ops import _paste_masks_tensor_shape

        return _paste_masks_tensor_shape is not None
    except ImportError:
        return False


def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True):
    # from detectron2 mask_ops.py

    device = masks.device

    if skip_empty and not torch.jit.is_scripting():
        x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(
            dtype=torch.int32
        )
        x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(
            dtype=torch.int32
        )
        y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(
            dtype=torch.int32
        )
    else:
        x0_int, y0_int = 0, 0
        x1_int, y1_int = img_w, img_h
    x0, y0, x1, y1 = torch.split(boxes, 1, dim=1)  # each is Nx1

    N = masks.shape[0]

    img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
    img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
    img_x = (img_x - x0) / (x1 - x0) * 2 - 1
    # img_x, img_y have shapes (N, w), (N, h)

    gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
    gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
    grid = torch.stack([gx, gy], dim=3)

    if not torch.jit.is_scripting():
        if not masks.dtype.is_floating_point:
            masks = masks.float()
    img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)

    if skip_empty and not torch.jit.is_scripting():
        return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
    else:
        return img_masks[:, 0], ()
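

# Shape sketch for _do_paste_mask: with masks of shape (N, 1, 28, 28) and
# boxes of shape (N, 4), grid_sample resamples each mask onto the pixel grid
# of its box, so the returned img_masks has shape
# (N, y1_int - y0_int, x1_int - x0_int) when skip_empty is True.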


def global_fn(x):
    return torch.sin(x)


def cat(tensors, dim=0):
    # from detectron2 wrappers.py
    assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)


def shapes_to_tensor(x, device=None):
    # from detectron2 wrappers.py
    if torch.jit.is_scripting():
        return torch.as_tensor(x, device=device)
    if torch.jit.is_tracing():
        assert all(
            isinstance(t, torch.Tensor) for t in x
        ), "Shape should be tensor during tracing!"
        # as_tensor should not be used in tracing because it records a constant
        ret = torch.stack(x)
        if ret.device != device:  # avoid recording a hard-coded device if not necessary
            ret = ret.to(device=device)
        return ret
    return torch.as_tensor(x, device=device)
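

# Illustrative behavior of shapes_to_tensor: in eager mode it is just
# torch.as_tensor(x, device=device); under torch.jit.trace the elements must
# already be 0-dim tensors, and torch.stack keeps the result connected to its
# inputs instead of baking a constant into the trace.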


fw_graph = [None]
bw_graph = [None]


def aot_graph_capture_backend(gm, args):
    from functorch.compile import min_cut_rematerialization_partition
    from torch._functorch.aot_autograd import aot_module_simplified

    def fw_compiler(gm, _):
        fw_graph[0] = gm
        return gm

    def bw_compiler(gm, _):
        bw_graph[0] = gm
        return gm

    return aot_module_simplified(
        gm,
        args,
        fw_compiler,
        bw_compiler,
        partition_fn=min_cut_rematerialization_partition,
        keep_inference_input_mutations=True,
    )
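

# Usage sketch: pass the backend to torch.compile to capture the forward and
# backward graphs for inspection, e.g.
#   opt_fn = torch.compile(fn, backend=aot_graph_capture_backend)
#   opt_fn(torch.randn(3, requires_grad=True)).sum().backward()
#   print(fw_graph[0].graph)  # GraphModule captured by fw_compiler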


class Boxes:
    # from detectron2 poolers.py
    def __init__(self, tensor: torch.Tensor):
        """
        Args:
            tensor (Tensor[float]): a Nx4 matrix.  Each row is (x1, y1, x2, y2).
        """
        device = (
            tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
        )
        tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
        if tensor.numel() == 0:
            # Use reshape, so we don't end up creating a new tensor that does not depend on
            # the inputs (and consequently confuses jit)
            tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32, device=device)
        assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
        self.tensor = tensor

    def __len__(self) -> int:
        return self.tensor.shape[0]

    @property
    def device(self):
        return self.tensor.device


def convert_boxes_to_pooler_format(box_lists):
    # from detectron2 structures.py
    boxes = torch.cat([x.tensor for x in box_lists], dim=0)
    # __len__ returns Tensor in tracing.
    sizes = shapes_to_tensor([x.__len__() for x in box_lists], device=boxes.device)
    indices = torch.repeat_interleave(
        torch.arange(len(box_lists), dtype=boxes.dtype, device=boxes.device), sizes
    )
    return cat([indices[:, None], boxes], dim=1)
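

# Worked example: for box_lists = [Boxes(<2 x 4 tensor>), Boxes(<3 x 4 tensor>)]
# the result is a 5 x 5 tensor whose first column is the batch index
# [0, 0, 1, 1, 1] and whose remaining columns are the concatenated boxes.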


ReformerBackwardOutput = namedtuple(
    "ReformerBackwardOutput",
    ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"],
)
ReformerEncoderOutput = namedtuple(
    "ReformerEncoderOutput",
    ["hidden_states", "all_hidden_states", "all_attentions", "past_buckets_states"],
)


class _ReversibleFunction(torch.autograd.Function):
    # taken from modeling_reformer.py in huggingface
    @staticmethod
    def forward(
        ctx,
        hidden_states,
        layers,
        attention_mask,
        head_mask,
        num_hashes,
        all_hidden_states,
        all_attentions,
        past_buckets_states,
        use_cache,
        orig_sequence_length,
        output_hidden_states,
        output_attentions,
    ):
        all_buckets = ()

        # split duplicated tensor
        hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1)

        for layer_id, (layer, layer_head_mask) in enumerate(zip(layers, head_mask)):
            if output_hidden_states is True:
                all_hidden_states.append(hidden_states)

            attn_output = layer(attn_output)
            all_buckets = all_buckets + (attn_output,)

        # Add last layer
        if output_hidden_states is True:
            all_hidden_states.append(hidden_states)

        # attach params to ctx for backward
        ctx.save_for_backward(attn_output.detach(), hidden_states.detach())
        ctx.layers = layers
        ctx.all_buckets = all_buckets
        ctx.head_mask = head_mask
        ctx.attention_mask = attention_mask

        # Concatenate 2 RevNet outputs
        return torch.cat([attn_output, hidden_states], dim=-1)

    @staticmethod
    def backward(ctx, grad_hidden_states):
        grad_attn_output, grad_hidden_states = torch.chunk(
            grad_hidden_states, 2, dim=-1
        )

        # free memory
        del grad_attn_output

        # num of return vars has to match num of forward() args
        # return gradient for hidden_states arg and None for other args
        return (
            grad_hidden_states,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
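

# Note: backward returns twelve values because autograd requires one gradient
# (or None) per forward() argument after ctx; only hidden_states is
# differentiable here, so the remaining eleven are None.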


class ReformerEncoder(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.dropout = 0.5
        self.layer_norm = torch.nn.LayerNorm(512, eps=1.0e-12)
        self.layers = [torch.nn.Linear(256, 256)]

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=[None] * 6,
        num_hashes=None,
        use_cache=False,
        orig_sequence_length=64,
        output_hidden_states=False,
        output_attentions=False,
    ):
        # hidden_states and attention lists to be filled if wished
        all_hidden_states = []
        all_attentions = []
        past_buckets_states = [((None), (None)) for i in range(len(self.layers))]

        # concat same tensor for reversible ResNet
        hidden_states = torch.cat([hidden_states, hidden_states], dim=-1)
        hidden_states = _ReversibleFunction.apply(
            hidden_states,
            self.layers,
            attention_mask,
            head_mask,
            num_hashes,
            all_hidden_states,
            all_attentions,
            past_buckets_states,
            use_cache,
            orig_sequence_length,
            output_hidden_states,
            output_attentions,
        )

        # Apply layer norm to concatenated hidden states
        hidden_states = self.layer_norm(hidden_states)

        # Apply dropout
        hidden_states = torch.nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        return ReformerEncoderOutput(
            hidden_states=hidden_states,
            all_hidden_states=all_hidden_states,
            all_attentions=all_attentions,
            past_buckets_states=past_buckets_states,
        )


class ListConfig:
    class ValueNode:
        def __init__(self, value):
            self.value = value

        def _dereference_node(self):
            return self

        def _is_missing(self):
            return False

        def _value(self):
            return self.value

    # Based on an example from omegaconfig.listconfig
    class ListIterator(Iterator[Any]):
        def __init__(self, lst: Any, resolve: bool) -> None:
            self.resolve = resolve
            self.iterator = iter(lst.__dict__["_content"])
            self.index = 0

        def __next__(self) -> Any:
            x = next(self.iterator)
            if self.resolve:
                x = x._dereference_node()
                if x._is_missing():
                    raise AssertionError

            self.index = self.index + 1
            if isinstance(x, ListConfig.ValueNode):
                return x._value()
            raise AssertionError

    def __iter__(self):
        return self._iter_ex(True)

    def _iter_ex(self, resolve: bool) -> Iterator[Any]:
        try:
            return ListConfig.ListIterator(self, resolve)
        except Exception:
            raise AssertionError from None

    def __init__(self) -> None:
        self._content = [
            ListConfig.ValueNode(1),
            ListConfig.ValueNode(3),
            ListConfig.ValueNode(torch.tensor([7.0])),
        ]
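

# Iterating the config resolves each node, so list(ListConfig()) yields
# [1, 3, tensor([7.])] -- a mix of constants and a tensor.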


def longformer_chunk(hidden_states, window_overlap=256):
    """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""

    # non-overlapping chunks of size = 2w
    hidden_states = hidden_states.view(
        hidden_states.size(0),
        hidden_states.size(1) // (window_overlap * 2),
        window_overlap * 2,
        hidden_states.size(2),
    )

    # use `as_strided` to make the chunks overlap with an overlap size = window_overlap
    chunk_size = list(hidden_states.size())
    chunk_size[1] = chunk_size[1] * 2 - 1

    chunk_stride = list(hidden_states.stride())
    chunk_stride[1] = chunk_stride[1] // 2
    return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
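

# Worked example: for hidden_states of shape (B, 1024, D) and
# window_overlap=256, the view produces 2 non-overlapping chunks of 512, and
# as_strided (with the chunk stride halved) yields 2 * 2 - 1 = 3 chunks of
# 512 that overlap by 256.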


class PartialT5(torch.nn.Module):
    # Highly simplified T5Attention prefix
    def __init__(self) -> None:
        super().__init__()
        self.q = torch.nn.Linear(512, 512)
        self.k = torch.nn.Linear(512, 512)
        self.v = torch.nn.Linear(512, 512)

    def forward(
        self,
        hidden_states,
        key_value_states=None,
        past_key_value=None,
        query_length=None,
    ):
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
            real_seq_length += (
                past_key_value[0].shape[2] if query_length is None else query_length
            )

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, 8, 64).transpose(1, 2)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(
            self.q(hidden_states)
        )  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states,
            self.k,
            key_value_states,
            past_key_value[0] if past_key_value is not None else None,
        )
        value_states = project(
            hidden_states,
            self.v,
            key_value_states,
            past_key_value[1] if past_key_value is not None else None,
        )

        # compute scores
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        # (truncated here)
        return scores, value_states
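

# Shape sketch: shape() splits the 512 features into 8 heads of 64, so with
# hidden_states of shape (batch, seq, 512) the returned scores have shape
# (batch, 8, seq, key_len) and value_states (batch, 8, key_len, 64).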


class ChunkReformerFeedForward(torch.nn.Module):
    # simplified from HF modeling_reformer.py
    def __init__(self) -> None:
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(256, eps=1e-12)
        self.dense = torch.nn.Linear(256, 256)
        self.output = torch.nn.Linear(256, 256)

    def forward(self, attention_output):
        return apply_chunking_to_forward(
            self.forward_chunk,
            attention_output + 1,
        )

    def forward_chunk(self, hidden_states):
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dense(hidden_states)
        return self.output(hidden_states)


def apply_chunking_to_forward(forward_fn, *input_tensors):
    # simplified from HF model_utils.py
    assert len(input_tensors) > 0
    tensor_shape = input_tensors[0].shape[1]
    assert all(input_tensor.shape[1] == tensor_shape for input_tensor in input_tensors)
    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
    if num_args_in_forward_chunk_fn != len(input_tensors):
        raise ValueError

    return forward_fn(*input_tensors)
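

# Illustrative call: apply_chunking_to_forward(layer.forward_chunk, x) checks
# that all inputs share dim 1 and that the chunk fn's arity matches, then
# applies the fn directly; the full HF helper additionally splits the inputs
# into chunks along a chunk dimension before calling forward_fn.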
Worker f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the" 558*da0073e9SAndroid Build Coastguard Worker " generate arguments will also show up in this list)" 559*da0073e9SAndroid Build Coastguard Worker ) 560*da0073e9SAndroid Build Coastguard Worker 561*da0073e9SAndroid Build Coastguard Worker 562*da0073e9SAndroid Build Coastguard Workerclass FakeMamlInner(torch.nn.Module): 563*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 564*da0073e9SAndroid Build Coastguard Worker super().__init__() 565*da0073e9SAndroid Build Coastguard Worker self.linear = torch.nn.Linear(784, 5) 566*da0073e9SAndroid Build Coastguard Worker 567*da0073e9SAndroid Build Coastguard Worker def forward(self, x, ignored=None, bn_training=False): 568*da0073e9SAndroid Build Coastguard Worker return self.linear(x.view(x.shape[0], -1)) 569*da0073e9SAndroid Build Coastguard Worker 570*da0073e9SAndroid Build Coastguard Worker 571*da0073e9SAndroid Build Coastguard Workerclass PartialMaml(torch.nn.Module): 572*da0073e9SAndroid Build Coastguard Worker # Highly simplified version of maml.meta.Meta.finetuning 573*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 574*da0073e9SAndroid Build Coastguard Worker super().__init__() 575*da0073e9SAndroid Build Coastguard Worker self.net = FakeMamlInner() 576*da0073e9SAndroid Build Coastguard Worker self.update_step_test = 10 577*da0073e9SAndroid Build Coastguard Worker self.update_lr = 0.4 578*da0073e9SAndroid Build Coastguard Worker 579*da0073e9SAndroid Build Coastguard Worker def forward(self, x_spt, y_spt, x_qry, y_qry): 580*da0073e9SAndroid Build Coastguard Worker querysz = x_qry.size(0) 581*da0073e9SAndroid Build Coastguard Worker 582*da0073e9SAndroid Build Coastguard Worker corrects = [0 for _ in range(self.update_step_test + 1)] 583*da0073e9SAndroid Build Coastguard Worker 584*da0073e9SAndroid Build Coastguard Worker # in order to not ruin the state of running_mean/variance and bn_weight/bias 585*da0073e9SAndroid Build Coastguard Worker # we finetuning on the copied model instead of self.net 586*da0073e9SAndroid Build Coastguard Worker net = deepcopy(self.net) 587*da0073e9SAndroid Build Coastguard Worker 588*da0073e9SAndroid Build Coastguard Worker # 1. 


class FakeMamlInner(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(784, 5)

    def forward(self, x, ignored=None, bn_training=False):
        return self.linear(x.view(x.shape[0], -1))


class PartialMaml(torch.nn.Module):
    # Highly simplified version of maml.meta.Meta.finetuning
    def __init__(self) -> None:
        super().__init__()
        self.net = FakeMamlInner()
        self.update_step_test = 10
        self.update_lr = 0.4

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        querysz = x_qry.size(0)

        corrects = [0 for _ in range(self.update_step_test + 1)]

        # in order to not ruin the state of running_mean/variance and bn_weight/bias,
        # we fine-tune on a copy of the model instead of self.net
        net = deepcopy(self.net)

        # 1. run the i-th task and compute loss for k=0
        logits = net(x_spt)
        loss = F.cross_entropy(logits, y_spt)
        grad = torch.autograd.grad(loss, net.parameters())
        fast_weights = [
            p[1] - self.update_lr * p[0] for p in zip(grad, net.parameters())
        ]

        # this is the loss and accuracy before first update
        with torch.no_grad():
            # [setsz, nway]
            logits_q = net(x_qry, net.parameters(), bn_training=True)
            # [setsz]
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            # scalar
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[0] = corrects[0] + correct

        # this is the loss and accuracy after the first update
        with torch.no_grad():
            # [setsz, nway]
            logits_q = net(x_qry, fast_weights, bn_training=True)
            # [setsz]
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            # scalar
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[1] = corrects[1] + correct

        del net

        accs = torch.tensor(corrects) / querysz

        return accs
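

# The inner loop above is plain MAML fine-tuning: for each parameter w with
# gradient g, the fast weight is w' = w - update_lr * g, and query-set
# accuracy is recorded before (corrects[0]) and after (corrects[1]) that
# single step.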


def softmax_backward_data(parent, grad_output, output, dim, self):
    from torch import _softmax_backward_data

    return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)


class XSoftmax(torch.autograd.Function):
    # transformers.models.deberta.modeling_deberta.XSoftmax
    @staticmethod
    def forward(self, input, mask, dim):
        self.dim = dim
        rmask = ~(mask.to(torch.bool))
        output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
        output = torch.softmax(output, self.dim)
        output.masked_fill_(rmask, 0)
        self.save_for_backward(output, rmask)
        return output

    @staticmethod
    def backward(self, grad_output):
        (output, rmask) = self.saved_tensors
        inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
        return inputGrad, None, None


class ModelOutput(collections.OrderedDict):
    """based on file_utils.py in HuggingFace"""

    def __getitem__(self, k):
        if isinstance(k, str):
            inner_dict = dict(self.items())
            return inner_dict[k]
        else:
            return self.to_tuple()[k]

    def __setattr__(self, name, value):
        if name in self.keys() and value is not None:
            # Don't call self.__setitem__ to avoid recursion errors
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Will raise a KeyException if needed
        super().__setitem__(key, value)
        # Don't call self.__setattr__ to avoid recursion errors
        super().__setattr__(key, value)

    def to_tuple(self):
        return tuple(self[k] for k in self.keys())
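

# Illustrative access patterns: a ModelOutput can be read as a dict
# (out["loss"]), by attribute (out.loss), or by position (out[0], via
# to_tuple), which is why __setitem__ and __setattr__ keep both views in sync.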


def create_rand_mask_from_inputs(
    from_blocked_mask,
    to_blocked_mask,
    rand_attn,
    num_attention_heads,
    num_rand_blocks,
    batch_size,
    from_seq_length,
    from_block_size,
):
    """taken from HF modeling_big_bird.py"""
    num_windows = from_seq_length // from_block_size - 2
    rand_mask = torch.stack(
        [p1[i1.flatten()] for p1, i1 in zip(to_blocked_mask, rand_attn)]
    )
    rand_mask = rand_mask.view(
        batch_size, num_attention_heads, num_windows, num_rand_blocks * from_block_size
    )
    rand_mask = torch.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask)
    return rand_mask
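

# Shape sketch for the einsum above: from_blocked_mask[:, 1:-1] is
# (batch, num_windows, from_block_size) ("blq") and rand_mask is
# (batch, heads, num_windows, num_rand_blocks * from_block_size) ("bhlk"),
# giving (batch, heads, num_windows, from_block_size,
# num_rand_blocks * from_block_size) ("bhlqk").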


class SequentialAppendList(torch.nn.Sequential):
    """from timm/models/vovnet.py"""

    def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor:
        for i, module in enumerate(self):
            if i == 0:
                concat_list.append(module(x))
            else:
                concat_list.append(module(concat_list[-1]))
        x = torch.cat(concat_list, dim=1)
        return x, concat_list


class BatchNormAct2d(torch.nn.BatchNorm2d):
    """Taken from timm"""

    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        act_layer=torch.nn.ReLU,
        inplace=True,
    ):
        super().__init__(
            num_features,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )
        self.act = act_layer(inplace=inplace)

    @torch.jit.ignore
    def _forward_python(self, x):
        return super().forward(x)

    def forward(self, x):
        if torch.jit.is_scripting():
            # the scripting branch assumes a _forward_jit method as in the
            # original timm code; eager-mode tests never take it
            x = self._forward_jit(x)
        else:
            x = self._forward_python(x)
        x = self.act(x)
        return x


def get_parameter_dtype(parameter):
    """from huggingface model_utils.py"""
    try:
        return next(parameter.parameters()).dtype
    except StopIteration:
        # For nn.DataParallel compatibility in PyTorch 1.5

        def find_tensor_attributes(module):
            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
            return tuples

        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].dtype


class DummyConfig:
    attn_layers = ["local", "lsh", "local", "lsh", "local", "lsh"]
    lsh_attn_chunk_length = 64
    local_attn_chunk_length = 64


def _get_min_chunk_len(config):
    """from hf_Reformer"""
    attn_types = config.attn_layers
    attn_types_set = set(attn_types)
    if len(attn_types_set) == 1 and attn_types[0] == "lsh":
        return config.lsh_attn_chunk_length
    elif len(attn_types_set) == 1 and attn_types[0] == "local":
        return config.local_attn_chunk_length
    elif len(attn_types_set) == 2 and attn_types_set == {"lsh", "local"}:
        return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length)
    else:
        raise NotImplementedError(
            f"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {config.attn_layers}. Select "
            "attn layer types from ['lsh', 'local'] only."
        )


def _stable_argsort(vector, dim):
    """from hf_Reformer"""
    # this function scales the vector so that torch.argsort is stable.
    # torch.argsort is not stable on its own
    scale_offset = torch.arange(vector.shape[dim], device=vector.device).view(1, 1, -1)
    scale_offset = scale_offset.expand(vector.shape)
    scaled_vector = vector.shape[dim] * vector + (scale_offset % vector.shape[dim])
    return torch.argsort(scaled_vector, dim=dim)
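

# Worked example: for a row [2, 1, 2] (n = 3) the scaled values are
# [3 * 2 + 0, 3 * 1 + 1, 3 * 2 + 2] = [6, 4, 8], so equal entries keep their
# left-to-right order and the argsort is the stable [1, 0, 2].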


def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(buckets):
    """from hf_Reformer"""
    # no gradients are needed
    with torch.no_grad():
        # hash-based sort
        sorted_bucket_idx = _stable_argsort(buckets, dim=-1)

        # create simple indices to scatter to, to have undo sort
        indices = (
            torch.arange(sorted_bucket_idx.shape[-1], device=buckets.device)
            .view(1, 1, -1)
            .expand(sorted_bucket_idx.shape)
        )

        # get undo sort
        undo_sorted_bucket_idx = sorted_bucket_idx.new(*sorted_bucket_idx.size())
        undo_sorted_bucket_idx.scatter_(-1, sorted_bucket_idx, indices)

    return sorted_bucket_idx, undo_sorted_bucket_idx
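

# The scatter builds the inverse permutation: if sorted_bucket_idx along the
# last dim is [2, 0, 1], scattering [0, 1, 2] into those positions gives
# undo_sorted_bucket_idx = [1, 2, 0], so gathering with it restores the
# original order.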
class CustomList1(list):
    def __call__(self, x):
        for processor in self:
            x = processor(x)
        return x

    def clear(self):
        pass  # this prevents RestrictedListSubclassVariable from kicking in


class CustomList2(list):
    def __call__(self, x):
        for processor in self:
            x = processor(x)
        return x

    def length_times_10(self):
        return len(self) * 10

    def append_twice(self, x):
        self.extend([x, x])

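
# Editorial note: a minimal usage sketch (not part of the original suite) for
# the callable-list subclasses above, which act as processor pipelines:
def _custom_list_example():
    pipeline = CustomList2([torch.sin, torch.cos])
    x = torch.randn(3)
    assert same(pipeline(x), torch.cos(torch.sin(x)))
    assert pipeline.length_times_10() == 20
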
def _merge_criteria_processor_list(default_list, custom_list):
    # simplified transformers/generation/utils.py
    if len(custom_list) == 0:
        return default_list
    for default in default_list:
        for custom in custom_list:
            if type(custom) is type(default):
                raise ValueError
    default_list.extend(custom_list)
    return default_list

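
# Editorial note: a minimal usage sketch (not part of the original suite) for
# _merge_criteria_processor_list; duplicates are rejected by exact type, so
# the hypothetical ProcA/ProcB classes below stand in for distinct processor
# classes:
def _merge_processor_list_example():
    class ProcA:
        def __call__(self, x):
            return x + 1

    class ProcB:
        def __call__(self, x):
            return x * 2

    merged = _merge_criteria_processor_list(CustomList2([ProcA()]), [ProcB()])
    assert [type(p).__name__ for p in merged] == ["ProcA", "ProcB"]
    # merging a second ProcA instead would raise ValueError
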
class FeedForwardLayer(nn.Module):
    def __init__(self, d_model, dim_feedforward, activation, dropout) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = activation
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout2(
            self.linear2(self.dropout1(self.activation(self.linear1(x))))
        )


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU(),
        layer_norm_eps=1e-5,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout = nn.Dropout(dropout)
        self.ff_block = FeedForwardLayer(d_model, dim_feedforward, activation, dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        x = src
        x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
        x = self.norm2(x + self._ff_block(x))
        return x

    # self-attention block
    def _sa_block(self, x, attn_mask, key_padding_mask):
        x = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout(x)

    # feed forward block
    def _ff_block(self, x):
        return self.ff_block(x)

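
# Editorial note: a minimal shape sketch (not part of the original suite) for
# the encoder layer above, mirroring how test_issue175 drives it: an
# unbatched (seq_len, d_model) input round-trips its shape through the
# post-norm residual blocks:
def _transformer_layer_example():
    layer = TransformerEncoderLayer(d_model=64, nhead=2)
    out = layer(torch.randn(1, 64))
    assert out.shape == (1, 64)
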
class MockModule(torch.nn.Module):
    def inner_fn(self, left, right):
        return tuple(left) == tuple(right)

    def fn(self, tensor):
        if type(tensor) is int:
            return False

        torch.add(tensor, tensor)
        return self.inner_fn(tensor.shape, (1, 2, 3))


class IncByOne:
    def __init__(self, x):
        self.x = x + 1


class IncByTwo:
    def __init__(self, x):
        self.x = x + 2


class ReproTests(torch._dynamo.test_case.TestCase):
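    # NOTE (editorial, hedged): most tests below compile through a
    # CompileCounter backend, where frame_count counts Dynamo-compiled frames
    # (graph breaks show up as extra frames) and op_count counts the ops
    # captured across those frames.
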
    def test_do_paste_mask(self):
        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()
        opt__do_paste_mask = torch.compile(_do_paste_mask, backend=cnt)
        opt__do_paste_mask(
            torch.randn(1, 1, 28, 28),
            torch.tensor([[0.0, 1, 2, 4]]) * 1,
            427,
            640,
            True,
        )
        opt__do_paste_mask(
            torch.randn(1, 1, 28, 28),
            torch.tensor([[0.0, 1, 2, 4]]) * 2,
            427,
            640,
            True,
        )
        opt__do_paste_mask(
            torch.randn(1, 1, 28, 28),
            torch.tensor([[0.0, 1, 2, 4]]) * 3,
            612,
            612,
            True,
        )
        opt__do_paste_mask(
            torch.randn(1, 1, 28, 28),
            torch.tensor([[0.0, 1, 2, 4]]) * 4,
            612,
            612,
            True,
        )
        opt__do_paste_mask(
            torch.randn(1, 1, 28, 28),
            torch.tensor([[0.0, 1, 2, 4]]) * 2,
            427,
            640,
            False,
        )
        # (dynamic shapes, static shapes)
        self.assertIn(cnt.frame_count, (5, 7))
        self.assertIn(cnt.op_count, (92, 106, 119))

    def test_convert_boxes_to_pooler_format(self):
        boxes1 = [
            Boxes(torch.arange(0, 8).reshape((2, 4))),
            Boxes(torch.arange(8, 16).reshape((2, 4))),
        ]
        boxes2 = [
            Boxes(torch.arange(16, 20).reshape((1, 4))),
            Boxes(torch.arange(20, 24).reshape((1, 4))),
        ]
        correct1 = convert_boxes_to_pooler_format(boxes1)
        correct2 = convert_boxes_to_pooler_format(boxes2)
        fn = convert_boxes_to_pooler_format
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        self.assertTrue(same(opt_fn(boxes1), correct1))
        self.assertTrue(same(opt_fn(boxes2), correct2))

        # repeat_interleave is a dynamic shape operator we do not execute.
        # In the future, we could reduce the frame_count down to 1
        # by guarding on the exact values of the `Tensor repeats` arg
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnt.frame_count, """4""")
            self.assertExpectedInline(cnt.op_count, """10""")
        else:
            self.assertExpectedInline(cnt.frame_count, """4""")
            self.assertExpectedInline(cnt.op_count, """14""")

    def test_boxes_len(self):
        def fn(boxes):
            return len(boxes) + boxes.__len__() + boxes.tensor

        boxes1 = Boxes(torch.arange(0, 8).reshape((2, 4)))
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
        self.assertTrue(same(opt_fn(boxes1), boxes1.tensor + 4.0))

        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """1""")
        else:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """2""")

    def _reformer(self, nopython):
        input = torch.randn([1, 64, 256])
        model = ReformerEncoder()
        torch.manual_seed(1337)
        correct = copy.deepcopy(model)(input)
        cnt = torch._dynamo.testing.CompileCounter()
        torch.manual_seed(1337)
        opt_model = torch._dynamo.optimize(cnt, nopython=nopython)(model)
        self.assertTrue(same(opt_model(input), correct))
        return cnt

    @requires_cuda
    def test_sub_alpha_scalar_repro(self):
        @torch.compile(backend="aot_eager")
        def f(x):
            return x.sub(1, alpha=2)

        f(torch.ones(2, device="cuda", dtype=torch.float64))

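    # NOTE (editorial, hedged): Tensor.sub(other, alpha=a) computes
    # self - a * other, so f above returns x - 2; the repro exercises that
    # overload through aot_eager on a CUDA float64 input.
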
    # https://github.com/pytorch/pytorch/issues/113010
    def test_out_overload_non_contiguous(self):
        def f(x, y):
            return torch.abs(x, out=y.T)

        f_compiled = torch.compile(f, backend="aot_eager")

        x_ref = torch.arange(4, dtype=torch.float32).reshape(2, 2)
        y_ref = torch.arange(4, dtype=torch.float32).reshape(2, 2)
        x_test = torch.arange(4, dtype=torch.float32).reshape(2, 2)
        y_test = torch.arange(4, dtype=torch.float32).reshape(2, 2)

        out_ref = f(x_ref, y_ref)
        out_test = f_compiled(x_test, y_test)
        self.assertEqual(out_ref, out_test)
        self.assertEqual(y_ref, y_test)

    # https://github.com/pytorch/pytorch/issues/109053
    def test_view_dtype_overload(self):
        def f(x):
            return x.view(torch.int32)

        f_compiled = torch.compile(f, backend="aot_eager")

        x1 = torch.ones(4, requires_grad=True)
        out_ref = f(x1)
        out_test = f_compiled(x1)
        self.assertEqual(out_ref, out_test)

        x2 = torch.ones(4, requires_grad=False)
        out_ref = f(x2)
        out_test = f_compiled(x2)
        self.assertEqual(out_ref, out_test)

    # https://github.com/pytorch/pytorch/issues/90552
    def test_intermediate_leaf_requires_grad(self):
        def f(x):
            leaf = torch.ones(2, requires_grad=True)
            return leaf, leaf * 2

        f_compiled = torch.compile(f, backend="aot_eager")
        x = torch.arange(4, dtype=torch.float32).reshape(2, 2)

        leaf, out = f(x)
        leaf_test, out_test = f_compiled(x)
        out.sum().backward()
        out_test.sum().backward()
        self.assertEqual(leaf.grad, leaf_test.grad)

    # https://github.com/pytorch/pytorch/issues/113263
    def test_unpack_hooks_dont_run_during_tracing(self):
        def f(x, y):
            return x * y

        f_compiled = torch.compile(f, backend="aot_eager")

        pack_count = 0
        unpack_count = 0

        def pack_hook(x):
            nonlocal pack_count
            pack_count += 1
            return x

        # unpack hook shouldn't run during compilation, while we trace the forward
        def unpack_hook(x):
            nonlocal unpack_count
            unpack_count += 1
            return x

        x = torch.ones(4, requires_grad=True)
        y = torch.ones(4, requires_grad=False)
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
            out_test = f_compiled(x, y)
            self.assertEqual(pack_count, 1)
            self.assertEqual(unpack_count, 0)
        out_test.sum().backward()
        self.assertEqual(pack_count, 1)
        self.assertEqual(unpack_count, 1)

    # https://github.com/pytorch/pytorch/issues/113263
    def test_unpack_hooks_can_be_disabled(self):
        def f(x, y):
            return x * y

        f_compiled = torch.compile(f, backend="aot_eager")

        x = torch.ones(4, requires_grad=True)
        y = torch.ones(4, requires_grad=False)
        with torch.autograd.graph.disable_saved_tensors_hooks("hooks are disabled"):
            out_test = f_compiled(x, y)
            out_test.sum().backward()

    # https://github.com/pytorch/pytorch/issues/113263
    def test_disabling_unpack_hooks_within_compiled_region(self):
        def g(z):
            with torch.autograd.graph.disable_saved_tensors_hooks("hooks are disabled"):
                return z + 5

        def f(x, y):
            z = x * y
            return g(z)

        f_compiled = torch.compile(f, backend="aot_eager")

        x = torch.ones(4, requires_grad=True)
        y = torch.ones(4, requires_grad=False)
        out_test = f_compiled(x, y)
        out_test.sum().backward()

    # See https://github.com/pytorch/pytorch/issues/97745
    def test_gan_repro_trying_to_backward_through_the_graph_a_second_time(self):
        def f(a, b):
            c = torch.ones(2, 2)
            d = torch.ones(2, 2)
            e = torch.matmul(a, c)
            g_loss = torch.abs(e - d).mean()
            g_loss.backward()
            fake_d_pred = torch.matmul(b, e.detach())
            d_loss = fake_d_pred.mean()
            d_loss.backward()

        a_ref = torch.randn(2, 2, requires_grad=True)
        b_ref = torch.randn(2, 2, requires_grad=True)
        out_ref = f(a_ref, b_ref)

        a_test = a_ref.clone().detach().requires_grad_(True)
        b_test = b_ref.clone().detach().requires_grad_(True)
        out_test = torch.compile(f, backend="aot_eager")(a_test, b_test)

        self.assertEqual(out_ref, out_test)
        self.assertEqual(a_ref.grad, a_test.grad)
        self.assertEqual(b_ref.grad, b_test.grad)

    # https://github.com/pytorch/pytorch/issues/111603
    def test_tuple_enum_as_key_dict(self):
        class MyEnum(Enum):
            A = "a"

        class SomeModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(1, 1)

            def forward(self, x) -> torch.Tensor:
                return self.linear(x[MyEnum.A])

        x = {MyEnum.A: torch.rand(8, 1)}
        model_pytorch = SomeModel()
        model = torch.compile(model_pytorch)
        # Executing twice works
        model(x)
        y = model(x)
        self.assertEqual(y, model_pytorch(x))

    def test_embedding_backward_broadcasting_decomp(self):
        def f(grad_output, indices):
            num_weights = 10
            padding_idx = 1
            scale_grad_by_freq = True
            return torch.ops.aten.embedding_dense_backward(
                grad_output, indices, num_weights, padding_idx, scale_grad_by_freq
            )

        f_compiled = torch.compile(f, backend="aot_eager")

        grad_output = torch.ones(2, 4, 3, dtype=torch.float16)
        indices = torch.ones(2, 4, dtype=torch.int64)

        out_ref = f(grad_output, indices)
        out_test = f_compiled(grad_output, indices)

        self.assertEqual(out_ref, out_test)

    def test_reformer_eval(self):
        with torch.no_grad():
            cnt = self._reformer(nopython=True)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 11)

    def test_reformer_train(self):
        with torch.enable_grad():
            cnt = self._reformer(nopython=False)
        expected_op_count = (
            """11""" if torch._dynamo.config.inline_inbuilt_nn_modules else """5"""
        )

        self.assertExpectedInline(cnt.frame_count, """1""")
        self.assertExpectedInline(cnt.op_count, expected_op_count)

    @disable_translation_validation_if_dynamic_shapes
    def test_longformer_chunk(self):
        input1 = torch.randn([1, 4096, 1])
        input2 = torch.randn([12, 4096, 64])
        correct1 = longformer_chunk(input1)
        correct2 = longformer_chunk(input2)
        fn = longformer_chunk
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
        self.assertTrue(same(opt_fn(input1), correct1))
        self.assertTrue(same(opt_fn(input2), correct2))
        self.assertTrue(same(opt_fn(input1), correct1))
        self.assertTrue(same(opt_fn(input2), correct2))

        if torch._dynamo.config.assume_static_by_default:
            if torch._dynamo.config.automatic_dynamic_shapes:
                self.assertExpectedInline(cnt.frame_count, """2""")
                self.assertExpectedInline(cnt.op_count, """8""")
            else:
                self.assertExpectedInline(cnt.frame_count, """2""")
                self.assertExpectedInline(cnt.op_count, """4""")
        else:
            self.assertExpectedInline(cnt.frame_count, """2""")
            self.assertExpectedInline(cnt.op_count, """19""")

    def test_hf_t5_forward(self):
        input = torch.randn([1, 2048, 512])
        model = PartialT5()
        correct = model(input)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_model = torch._dynamo.optimize_assert(cnt)(model)
        self.assertTrue(same(opt_model(input), correct))

        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """11""")
        else:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """11""")

    def test_module_in_skipfiles(self):
        model = nn.Linear(10, 10)
        cnt = torch._dynamo.testing.CompileCounter()
        torch.compile(model, backend=cnt, fullgraph=True)(torch.randn([5, 10]))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    def test_function_in_skipfiles(self):
        cnt = torch._dynamo.testing.CompileCounter()
        torch.compile(torch.sin, backend=cnt, fullgraph=True)(torch.randn([5, 10]))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    def test_slicing_dynamic_shape(self):
        def fn(y):
            x = torch.ones(8)
            idx = y[0]
            out = x[idx:]
            return (out + 3) * 5

        counter = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(counter)(fn)
        out = opt_fn(torch.ones(10, dtype=torch.long))
        # idx should be 1 -> slicing off [1:] of 8 elem tensor
        self.assertEqual(list(out.shape), [7])

        self.assertEqual(counter.op_count, 2)
        self.assertEqual(counter.frame_count, 1)

        self.assertEqual(list(opt_fn(torch.tensor([4])).shape), [4])

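    # NOTE (editorial, hedged): the slice start above is read out of the
    # input tensor's data, so the output length is data-dependent; the final
    # call (idx == 4) checks the compiled slice tracks the value instead of
    # baking in the first-seen length.
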
    def test_slicing_dynamic_shape_setitem(self):
        def fn(input_lengths: torch.Tensor, new_ones_1):
            getitem_13 = input_lengths[3]
            new_ones_1[(3, slice(getitem_13, None, None))] = 0
            setitem_13 = new_ones_1
            return (setitem_13,)

        x = torch.randn(10).to(dtype=torch.int64)
        y = torch.randn(10, 204)
        ref = fn(x, y)
        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
        res = opt_fn(x, y)
        self.assertTrue(same(ref, res))

    @torch._dynamo.config.patch(error_on_recompile=True)
    @torch.fx.experimental._config.patch(use_duck_shape=False)
    def test_dynamic_shape_disable_duck_size(self):
        class TestModel(nn.Module):
            def __init__(
                self,
            ):
                super().__init__()

            def forward(self, x: torch.Tensor, val: int) -> torch.Tensor:
                return x + val

        main_model = TestModel().to(memory_format=torch.channels_last)
        opt_model = torch.compile(main_model, backend="eager", dynamic=True)

        x1 = torch.rand(2, 5, 10, 10).to(memory_format=torch.channels_last)
        x2 = torch.rand(2, 5, 4, 8).to(memory_format=torch.channels_last)

        o1_ref = main_model(x1, 4)
        o1 = opt_model(x1, 4)

        o2_ref = main_model(x2, 20)
        o2 = opt_model(x2, 20)

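    # NOTE (editorial, hedged): duck shaping would give dimensions sharing a
    # value on the first call (the two 10s above) the same symbol; with
    # use_duck_shape=False they stay independent, so the second call with
    # different sizes must not recompile (enforced by error_on_recompile).
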
    def test_chunk_reformer_ff(self):
        input = torch.randn([1, 4096, 256])
        model = ChunkReformerFeedForward()
        correct = model(input)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_model = torch._dynamo.optimize_assert(cnt)(model)
        self.assertTrue(same(opt_model(input), correct))

        self.assertEqual(cnt.frame_count, 1)
        self.assertLessEqual(cnt.op_count, 10)

    # see: https://github.com/pytorch/pytorch/issues/80067
    # NB: When you remove the expectedFailure, don't forget to
    # uncomment/adjust the assertEqual below
    @unittest.expectedFailure
    @torch._dynamo.config.patch(
        fake_tensor_propagation=True, capture_scalar_outputs=True
    )
    def test_maml_item_capture(self):
        a = torch.randn(5, 1, 28, 28)
        b = torch.zeros(5, dtype=torch.int64)
        c = torch.randn(75, 1, 28, 28)
        d = torch.zeros(75, dtype=torch.int64)
        model = PartialMaml()
        correct = model(a, b, c, d)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_model = torch._dynamo.optimize(cnt)(model)
        for _ in range(10):
            self.assertTrue(same(opt_model(a, b, c, d), correct))

        # if torch._dynamo.config.assume_static_by_default:
        #     self.assertExpectedInline(cnt.frame_count, """2""")
        # else:
        #     self.assertExpectedInline(cnt.frame_count, """3""")
        # TODO(jansel): figure out why op count depends on imports
        self.assertIn(cnt.op_count, (36, 35, 34, 29, 28, 27))

    # see: https://github.com/pytorch/pytorch/issues/80067
    @torch._dynamo.config.patch(capture_scalar_outputs=False)
    def test_maml_no_item_capture(self):
        a = torch.randn(5, 1, 28, 28)
        b = torch.zeros(5, dtype=torch.int64)
        c = torch.randn(75, 1, 28, 28)
        d = torch.zeros(75, dtype=torch.int64)
        model = PartialMaml()
        correct = model(a, b, c, d)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_model = torch._dynamo.optimize(cnt)(model)
        for _ in range(10):
            self.assertTrue(same(opt_model(a, b, c, d), correct))

        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnt.frame_count, """4""")
        else:
            self.assertExpectedInline(cnt.frame_count, """5""")

    def test_hf_model_output(self):
        ex = ModelOutput(a=torch.randn(10), b=torch.randn(10), c=torch.randn(10))

        def fn1(x):
            return x["a"] + 1

        def fn2(x):
            return x.a + 1

        def fn3(x):
            return x.to_tuple()[0] + 1

        def fn4(x):
            return x[0] + 1

        cnt = torch._dynamo.testing.CompileCounter()
        for fn in (fn1, fn2, fn3, fn4):
            cnt.clear()
            opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
            self.assertTrue(same(opt_fn(ex), ex.a + 1))
            self.assertEqual(cnt.frame_count, 1)
            self.assertEqual(cnt.op_count, 1)

    @disable_translation_validation_if_dynamic_shapes
    def test_create_rand_mask_from_inputs(self):
        args = [
            torch.randn([1, 64, 64]),
            torch.randn([1, 64, 64]),
            torch.zeros([1, 12, 62, 3], dtype=torch.int64),
            12,
            3,
            1,
            4096,
            64,
        ]
        correct = create_rand_mask_from_inputs(*args)
        fn = create_rand_mask_from_inputs

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
        self.assertTrue(same(opt_fn(*args), correct))
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """8""")
        else:
            self.assertExpectedInline(cnt.frame_count, """1""")
            self.assertExpectedInline(cnt.op_count, """11""")

    def test_rng_state(self):
        def fn():
            state = torch.get_rng_state()
            before = torch.rand(1000)
            torch.set_rng_state(state)
            after = torch.rand(1000)
            return before, after

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)

        before, after = opt_fn()
        self.assertTrue(same(before, after))
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 2)  # rand, rand
        try:
            graph, _ = torch._dynamo.export(fn)()
            # See https://github.com/pytorch/pytorch/pull/87490
            self.fail("unexpected export success")
        except torch._dynamo.exc.Unsupported:
            pass

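    # NOTE (editorial, hedged): get_rng_state/set_rng_state graph-break, which
    # is why fn above splits into two compiled frames of one rand op each, and
    # why export (which requires a single graph) raises Unsupported.
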
1557*da0073e9SAndroid Build Coastguard Worker def test_fn(a): 1558*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(48, 4, 256, 513) 1559*da0073e9SAndroid Build Coastguard Worker b[:, 0, 1:256, 1:256] = a 1560*da0073e9SAndroid Build Coastguard Worker c = b.view(4, 12, 1024, 513) 1561*da0073e9SAndroid Build Coastguard Worker d = c.transpose(2, 1) 1562*da0073e9SAndroid Build Coastguard Worker d.add_(1) 1563*da0073e9SAndroid Build Coastguard Worker return d 1564*da0073e9SAndroid Build Coastguard Worker 1565*da0073e9SAndroid Build Coastguard Worker sh, st, dt, dev, rg = ( 1566*da0073e9SAndroid Build Coastguard Worker (48, 255, 255), 1567*da0073e9SAndroid Build Coastguard Worker (787968, 513, 1), 1568*da0073e9SAndroid Build Coastguard Worker torch.float16, 1569*da0073e9SAndroid Build Coastguard Worker "cpu", 1570*da0073e9SAndroid Build Coastguard Worker True, 1571*da0073e9SAndroid Build Coastguard Worker ) 1572*da0073e9SAndroid Build Coastguard Worker a = rand_strided(sh, st, dt, dev).requires_grad_(rg) 1573*da0073e9SAndroid Build Coastguard Worker compiled_f = torch.compile(test_fn, backend="aot_eager_decomp_partition") 1574*da0073e9SAndroid Build Coastguard Worker out1 = test_fn(a) 1575*da0073e9SAndroid Build Coastguard Worker out2 = compiled_f(a) 1576*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out1, out2) 1577*da0073e9SAndroid Build Coastguard Worker 1578*da0073e9SAndroid Build Coastguard Worker def test_indexing_with_list(self): 1579*da0073e9SAndroid Build Coastguard Worker def test_fn(): 1580*da0073e9SAndroid Build Coastguard Worker def run_test(tensor, *idx): 1581*da0073e9SAndroid Build Coastguard Worker npt = tensor.numpy() 1582*da0073e9SAndroid Build Coastguard Worker assert npt[idx].shape == tensor[idx].shape 1583*da0073e9SAndroid Build Coastguard Worker 1584*da0073e9SAndroid Build Coastguard Worker x = torch.arange(0, 10) 1585*da0073e9SAndroid Build Coastguard Worker cases = [ 1586*da0073e9SAndroid Build Coastguard Worker [None, None], 1587*da0073e9SAndroid Build Coastguard Worker [1, None], 1588*da0073e9SAndroid Build Coastguard Worker ] 1589*da0073e9SAndroid Build Coastguard Worker 1590*da0073e9SAndroid Build Coastguard Worker for case in cases: 1591*da0073e9SAndroid Build Coastguard Worker run_test(x, *case) 1592*da0073e9SAndroid Build Coastguard Worker 1593*da0073e9SAndroid Build Coastguard Worker return torch.randn(4) 1594*da0073e9SAndroid Build Coastguard Worker 1595*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 1596*da0073e9SAndroid Build Coastguard Worker opt_test_fn = torch._dynamo.optimize(cnt)(test_fn) 1597*da0073e9SAndroid Build Coastguard Worker opt_test_fn() 1598*da0073e9SAndroid Build Coastguard Worker 1599*da0073e9SAndroid Build Coastguard Worker def test_reformer_min_chunk_len(self): 1600*da0073e9SAndroid Build Coastguard Worker def fn(cfg): 1601*da0073e9SAndroid Build Coastguard Worker t = torch.empty(10) 1602*da0073e9SAndroid Build Coastguard Worker t.fill_(_get_min_chunk_len(cfg)) 1603*da0073e9SAndroid Build Coastguard Worker return t[0] 1604*da0073e9SAndroid Build Coastguard Worker 1605*da0073e9SAndroid Build Coastguard Worker cfg = DummyConfig() 1606*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 1607*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize_assert(cnt)(fn) 1608*da0073e9SAndroid Build Coastguard Worker self.assertEqual(opt_fn(cfg), 64) 1609*da0073e9SAndroid Build Coastguard Worker # With unspec int, maximum computation is preserved 
    def test_batch_norm_act(self):
        a = torch.randn(5, 1, 28, 28)
        model = BatchNormAct2d(1).eval()
        correct = model(a)
        cnt = torch._dynamo.testing.CompileCounter()
        if not torch._dynamo.config.specialize_int:
            # _local_scalar_dense causes a graph break with a 0-dim tensor
            opt_model = torch._dynamo.optimize(cnt)(model)
            self.assertTrue(same(opt_model(a), correct))
            return

        opt_model = torch._dynamo.optimize_assert(cnt)(model)
        self.assertTrue(same(opt_model(a), correct))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 2)

    def test_get_parameter_dtype(self):
        model = SequentialAppendList(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
        )

        def fn(model, x):
            return x + torch.randn(10, dtype=get_parameter_dtype(model))

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
        self.assertEqual(opt_fn(model, torch.randn(10)).dtype, torch.float32)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 2)

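    # NOTE (editorial, hedged): get_parameter_dtype resolves to a constant
    # dtype at trace time, so only the randn and the add are captured
    # (op_count == 2 above).
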
        def fcnn(in_dim, out_dim, hidden_dim, activation=torch.nn.GELU):
            layers = [
                torch.nn.Linear(in_dim, hidden_dim, device=device),
                activation(),
                torch.nn.Linear(hidden_dim, out_dim, device=device),
            ]
            return torch.nn.Sequential(*layers)

        class testmodel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.interaction_networks = torch.nn.ModuleList(
                    [fcnn(262, 1174, 400) for _ in range(4)]
                )

            def interact(self, x, cycle):
                return self.interaction_networks[cycle](x)

        model = testmodel()
        forward_aot = torch.compile(
            model.interact, fullgraph=True, dynamic=True, backend="eager"
        )

        x = torch.rand([111, 262], device=device)
        y2 = forward_aot(x, 2)  # previously failed

    def test_issue175(self):
        n_heads = 2
        d_model = 64
        model = TransformerEncoderLayer(d_model, n_heads)
        inp = torch.randn(1, d_model)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_model = torch._dynamo.optimize(cnt, nopython=True)(model)
        opt_model(inp)
        opt_model(inp)
        self.assertEqual(cnt.frame_count, 1)

        self.assertEqual(
            15 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count
        )

    def test_exec_import(self):
        def fn1():
            exec("import math")

        def fn2():
            try:
                math.sqrt(4)
                return False
            except NameError:
                return True

        def fn3():
            fn1()
            return fn2()

        self.assertTrue(fn3())
        opt_fn3 = torch._dynamo.optimize("eager")(fn3)
        self.assertTrue(opt_fn3())

    def test_exec_wildcard_import(self):
        # Test that globals are not carried over from frame to frame
        def fn1():
            exec("from torch import *")

        def fn2():
            x = torch.zeros(4)
            for i in range(5):
                x = x + i
            return x

        def fn3():
            fn1()
            return fn2()

        ref = fn3()
        opt_fn3 = torch._dynamo.optimize("eager")(fn3)
        res = opt_fn3()
        self.assertTrue(same(ref, res))

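    # Hypothetical companion check, added as an illustrative sketch (it is not
    # part of the original suite): it demonstrates the plain-CPython scoping
    # rule the two exec() tests above lean on. exec() without explicit
    # namespaces writes into a throwaway snapshot of the frame's locals, so
    # names it binds never become visible to the surrounding function.
    def test_exec_scope_sketch(self):
        def leaky():
            exec("sneaky = 1")
            try:
                return sneaky  # noqa: F821 - compiled as a global load
            except NameError:
                return None

        self.assertIsNone(leaky())
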
    def test_with_on_graph_break_inst(self):
        def reversible(x):
            print("Hello world")  # Cause graph break so inline fails
            return torch.sin(torch.cos(x))

        def fn(x):
            with torch.enable_grad():
                a = torch.sin(x)
                b = reversible(a)
                c = torch.sigmoid(b)
                c.sum().backward()
                return x.grad

        x = torch.randn(3, requires_grad=True)
        x.grad = None
        with torch.no_grad():
            ref = fn(x)

        x.grad = None
        opt_fn = torch._dynamo.optimize("eager")(fn)
        with torch.no_grad():
            res = opt_fn(x)
        self.assertTrue(same(ref, res))

    def test_with_on_graph_break_nested(self):
        def reversible(x):
            torch._dynamo.graph_break()  # Cause graph break so inline fails
            return torch.sin(torch.cos(x))

        def fn(x):
            # nested context manager failed previously
            with torch.no_grad():
                with torch.enable_grad():
                    a = torch.sin(x)
                    b = reversible(a)
                    c = torch.sigmoid(b)
                    c.sum().backward()
                    return x.grad

        x = torch.randn(3, requires_grad=True)
        x.grad = None
        with torch.no_grad():
            ref = fn(x)

        x.grad = None
        opt_fn = torch._dynamo.optimize("eager")(fn)
        with torch.no_grad():
            res = opt_fn(x)
        self.assertTrue(same(ref, res))

    # https://github.com/pytorch/torchdynamo/issues/1446
    def test_grad_mode_carrying_correct_state_after_graph_break(self):
        def fn(x):
            with torch.no_grad():
                y = x * 3
                print("Break")
                z = x + 2
            return y, z

        x = torch.randn(3, requires_grad=True)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        y, z = opt_fn(x)
        self.assertFalse(y.requires_grad)
        self.assertFalse(z.requires_grad)

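    # Eager-mode baseline for the invariant the test above checks across a
    # graph break (an illustrative addition, not from the original suite):
    # anything produced inside torch.no_grad() reports requires_grad=False.
    def test_no_grad_eager_baseline_sketch(self):
        x = torch.randn(3, requires_grad=True)
        with torch.no_grad():
            y = x * 3
            z = x + 2
        self.assertFalse(y.requires_grad)
        self.assertFalse(z.requires_grad)
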
    def test_abc_setattr(self):
        # tests that we correctly bail out of __setattr__ calls

        # TODO: does not ensure ABC classes are correctly inferred as ClassVariables
        # (doesn't test the fix for 'super()')

        class BaseModule(torch.nn.Module, ABC):
            def blah(self, x):
                return x + 1

        class Derived(BaseModule):
            def __setattr__(self, name, value) -> None:
                super().__setattr__(name, value)

            def forward(self, x):
                # expect a graph break on __setattr__
                self.foo = 0
                return self.blah(x)

            def blah(self, x):
                return super().blah(x)

        x = torch.randn(3, requires_grad=True)
        mod = Derived()
        opt_mod = torch._dynamo.optimize("eager")(mod)
        opt_mod(x)

        # Not sure what this test is testing. It used to graph break on
        # __dict__, so the counter was >= 2. With __dict__ support, there is
        # no graph break.
        self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
        self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["total"], 1)

    @torch._dynamo.config.patch("suppress_errors", True)
    def test_guard_fail_tensor_bool(self):
        @torch._dynamo.disable(recursive=False)
        def fn():
            condition_shape = (5, 5)
            dtypes = (torch.bool,)
            shapes = (
                (),
                (5,),
                (1, 5),
            )

            tensors = [
                torch.empty(shape, dtype=dtype).fill_(17)
                for shape, dtype in itertools.product(shapes, dtypes)
            ]

            x_vals = (5.0, *tensors)
            y_vals = (6.0, *tensors)

            @torch._dynamo.disable
            def get_expected(condition, x, y):
                x_np = x.cpu().numpy() if isinstance(x, torch.Tensor) else x
                y_np = y.cpu().numpy() if isinstance(y, torch.Tensor) else y
                return torch.from_numpy(
                    np.where(condition.cpu().numpy(), x_np, y_np)
                ).to(common_dtype)

            for x, y in zip(x_vals, y_vals):
                condition = torch.empty(*condition_shape, dtype=torch.bool).bernoulli_()
                common_dtype = torch.result_type(x, y)

                def check_equal(condition, x, y):
                    # NumPy aggressively promotes to double, so cast the
                    # output back to the correct dtype
                    expected = get_expected(condition, x, y)
                    result = torch.where(condition, x, y)
                    assert torch.allclose(expected, result)

                check_equal(condition, x, y)
                check_equal(condition, y, x)

        fn()
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_fn()

    def test_guard_fail_nested_tuple(self):
        def fn(args):
            return torch.ones(()), args[0] * 2

        # This adds a tensor check on args[1][0] and args[1][1]
        args1 = (torch.ones(1), (torch.ones(1), torch.ones(1)))
        args2 = (torch.ones(1), torch.ones(1))
        opt_fn = torch._dynamo.optimize("eager")(fn)
        ref = opt_fn(args1)
        res = opt_fn(args2)

        self.assertTrue(same(ref, res))

    def test_nullcontext1(self):
        @torch.compile(fullgraph=True, backend="eager")
        def fn(x, ctx):
            x = x.sin()
            with ctx:
                x = x.cos()
            x = x.sin()
            return x

        y = torch.randn(10)
        self.assertTrue(same(fn(y, contextlib.nullcontext()), y.sin().cos().sin()))

    def test_nullcontext2(self):
        @torch.compile(fullgraph=True, backend="eager")
        def fn(x, ctx):
            x = x.sin()
            with ctx():
                x = x.cos()
            x = x.sin()
            return x

        y = torch.randn(10)
        self.assertTrue(same(fn(y, contextlib.nullcontext), y.sin().cos().sin()))

    def test_no_grad_inline(self):
        @torch.no_grad()
        def a(x):
            return x.sin()

        @torch.compile(backend="eager", fullgraph=True)
        def b(x):
            return a(x).cos()

        y = torch.randn(10)
        self.assertTrue(same(b(y), y.sin().cos()))

    @skipIfWindows(
        msg="torch._dynamo.exc.TorchRuntimeError: Failed running call_function <class 'torch.LongTensor'>(*(FakeTensor(..., size=(10,), dtype=torch.int32),), **{}):"  # noqa: B950
    )
    def test_longtensor_list(self):
        for partition in [0, 5, 10]:

            @torch._dynamo.disable
            def rand_gen():
                rand_vals = [random.randint(5, 10) for _ in range(10)]
                # List of tensors mixed with np.arrays
                return list(np.array(rand_vals[:partition])) + [
                    torch.tensor(val) for val in rand_vals[partition:]
                ]

            def fn(x):
                random_list = rand_gen()
                z = torch.LongTensor(random_list)
                return x * z

            x = torch.ones(10) * 2

            random.seed(0)
            ref0 = fn(x)
            ref1 = fn(x)

            random.seed(0)
            opt_fn = torch._dynamo.optimize("eager")(fn)
            res0 = opt_fn(x)
            res1 = opt_fn(x)

            self.assertTrue(same(ref0, res0))
            self.assertTrue(same(ref1, res1))

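    # Illustrative sketch, not from the original suite: the reason
    # test_longtensor_list wraps rand_gen in torch._dynamo.disable is that a
    # disabled helper runs eagerly on every call, so Python's RNG is consumed
    # identically in eager and compiled runs.
    def test_disable_rng_sketch(self):
        @torch._dynamo.disable
        def draw():
            return random.randint(0, 9)

        @torch.compile(backend="eager")
        def fn(x):
            return x * draw()

        x = torch.ones(3)
        random.seed(0)
        ref = x * random.randint(0, 9)
        random.seed(0)
        res = fn(x)
        self.assertTrue(same(ref, res))
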
    def test_primtorch(self):
        @torch._dynamo.optimize("eager")
        def fn(x):
            torch._refs.abs(x)

        fn(torch.randn(3))

    @unittest.expectedFailure
    # inline_call [('inline in skipfiles: bind ...python3.10/inspect.py', 1)]
    def test_primtorch_no_graph_break(self):
        @torch._dynamo.optimize("eager", nopython=True)
        def fn(x):
            torch._refs.abs(x)

        fn(torch.randn(3))

    def test_torch_tensor_ops_no_graph_break(self):
        @torch._dynamo.optimize("eager", nopython=True)
        def fn(x):
            torch.Tensor.abs_(x)

        fn(torch.randn(3))

    @unittest.skipIf(
        not isinstance(torch.ops.aten.abs, torch._ops.OpOverloadPacket),
        "old pt doesn't work",
    )
    def test_torch_ops_aten(self):
        # Picked an op that doesn't show up in the default list
        @torch._dynamo.optimize("eager", nopython=True)
        def fn(x):
            return torch.ops.aten.absolute(x)

        fn(torch.randn(3))

    def test_hf_gelu_inline(self):
        class GELUActivation(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.act = nn.functional.gelu

            def forward(self, input):
                return self.act(input)

        @torch._dynamo.optimize("eager", nopython=True)
        def fn(x):
            return GELUActivation()(x)

        y = torch.randn(10)
        self.assertTrue(same(fn(y), nn.functional.gelu(y)))

        @torch._dynamo.optimize("eager", nopython=True)
        def fn_returns(x):
            return GELUActivation(), x + 1

        act, _ = fn_returns(y)
        self.assertIsInstance(act, GELUActivation)
        self.assertIs(act.act, nn.functional.gelu)
        self.assertTrue(hasattr(act, "_buffers"))  # check that __init__ got called

    def test_dropout_inline(self):
        @torch._dynamo.optimize("eager")
        def fn(x):
            return torch.nn.Dropout(0.1)(x)

        y = torch.randn(10)
        torch.manual_seed(1337)
        ref = nn.functional.dropout(y, 0.1)
        torch.manual_seed(1337)
        res = fn(y)
        self.assertTrue(same(ref, res))

    def test_setitem_boolean_mask_diff(self):
        def fn(x, b, y):
            x = x.clone()
            x[b] = y
            return x

        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
        x = torch.randn(4, requires_grad=True)
        b = torch.tensor([True, False, True, False])
        y = torch.randn(2, requires_grad=True)
        opt_fn(x, b, y)

    def test_setitem_tuple_boolean_mask_diff(self):
        def fn(x, b, y):
            x = x.clone()
            x[:, b] = y
            return x

        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
        x = torch.randn(8, 4, requires_grad=True)
        b = torch.tensor([True, False, True, False])
        y = torch.randn(2, requires_grad=True)
        opt_fn(x, b, y)

    def test_torch_tensor_ops(self):
        def fn(x):
            return torch.Tensor.abs_(x)

        x = torch.randn(3)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        y = fn(x)
        y_ = opt_fn(x)
        self.assertTrue(same(y, y_))

    def test_guard_ordering_shape_fail(self):
        # If a function that takes a tensor has an inner function that is
        # compiled and generates a guard on its shape, the guards are
        # evaluated in the wrong order: if an int is passed instead of a
        # tensor on a subsequent call, guard evaluation will crash with a
        # "no attribute: shape" error.
        m = MockModule()
        opt_m = torch._dynamo.optimize("eager")(m)
        opt_m.fn(torch.ones((5, 5)))
        opt_m.fn(-3)

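    # Companion sketch (illustrative addition, not from the original suite):
    # the same tensor-then-int calling pattern on a plain function, checking
    # values only; correct guard ordering recompiles instead of crashing.
    def test_guard_ordering_sketch(self):
        @torch._dynamo.optimize("eager")
        def fn(x):
            if isinstance(x, torch.Tensor):
                return x * 2
            return x - 1

        self.assertTrue(same(fn(torch.ones(5, 5)), torch.full((5, 5), 2.0)))
        self.assertEqual(fn(-3), -4)
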
    def test_tensor_isinstance_tuple(self):
        @torch._dynamo.optimize("eager")
        def fn():
            t = torch.ones(5, 5)
            if not isinstance(t, (int, torch.Tensor)):
                msg = str.format(
                    "{0} is not an instance of {1}",
                    type(t),
                    (int, torch.Tensor),
                )
                raise ValueError(msg)
            return True

        fn()

    def test_isinstance_dtype(self):
        @torch._dynamo.optimize("eager", nopython=True)
        def fn(x):
            isinstance(torch.bfloat16, torch.dtype)
            return x

        fn(torch.randn(3))

    def test_isinstance_storage(self):
        @torch._dynamo.optimize("eager")
        def fn(x):
            f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40])
            bools = torch.BoolStorage.from_buffer(f, "big")
            assert isinstance(bools, torch.BoolStorage)
            return x

        fn(torch.randn(3))

    def test_issue111522(self):
        @torch.compile(backend="eager", fullgraph=True)
        def f(x, y):
            return x + y.a

        class A:
            a = 2

        self.assertEqual(f(torch.zeros(2), A()), torch.full([2], 2.0))

        del A.a

        # graph break on missing attr
        with self.assertRaises(torch._dynamo.exc.Unsupported):
            f(torch.zeros(2), A())

    def test_dict_list_values(self):
        def inner_fn(args):
            return [x[1].shape for x in args]

        @torch._dynamo.optimize("eager")
        def fn(tensors):
            return inner_fn(zip(itertools.count(), tensors["args"]))

        fn({"args": [torch.ones(5, 5), torch.ones(5, 6), torch.ones(5, 7)]})
        fn({"args": [torch.ones(5, 5)]})

    def test_dict_iter(self):
        class MyMod(torch.nn.Module):
            def forward(self, x):
                z = {"my": 1, "const": 2, "dict": 3, "variable": 4}
                tot = 0
                for key in z:
                    tot += z[key]

                return tot

        x = torch.tensor([0])
        model = MyMod()
        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
        y = opt_model(x)

        self.assertEqual(y, 10)

    def test_sort_out(self):
        dtype = torch.float32
        device = "cpu"

        def fn():
            tensor = torch.randn((3, 5), dtype=dtype, device=device)[:, 0]
            values1 = torch.tensor(0, dtype=dtype, device=device)
            indices1 = torch.tensor(0, dtype=torch.long, device=device)
            torch.sort(tensor, out=(values1, indices1))
            self.assertEqual(values1.stride(), (1,))
            self.assertEqual(indices1.stride(), (1,))

        fn()
torch._dynamo.optimize("eager")(fn) 2170*da0073e9SAndroid Build Coastguard Worker opt_fn() 2171*da0073e9SAndroid Build Coastguard Worker 2172*da0073e9SAndroid Build Coastguard Worker def test_sort_out2(self): 2173*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 2174*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2175*da0073e9SAndroid Build Coastguard Worker super().__init__() 2176*da0073e9SAndroid Build Coastguard Worker self.sorted = torch.nn.Buffer(torch.ones(4, 4)) 2177*da0073e9SAndroid Build Coastguard Worker self.indices = torch.nn.Buffer(torch.ones(4, 4, dtype=torch.long)) 2178*da0073e9SAndroid Build Coastguard Worker 2179*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2180*da0073e9SAndroid Build Coastguard Worker torch.sort(x, out=(self.sorted, self.indices)) 2181*da0073e9SAndroid Build Coastguard Worker return (x + 1, self.sorted, self.indices) 2182*da0073e9SAndroid Build Coastguard Worker 2183*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 4) 2184*da0073e9SAndroid Build Coastguard Worker m = MyModule() 2185*da0073e9SAndroid Build Coastguard Worker ref = m(x) 2186*da0073e9SAndroid Build Coastguard Worker opt_m = torch._dynamo.optimize("eager")(m) 2187*da0073e9SAndroid Build Coastguard Worker res = opt_m(x) 2188*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 2189*da0073e9SAndroid Build Coastguard Worker 2190*da0073e9SAndroid Build Coastguard Worker def test_sigmoid_out(self): 2191*da0073e9SAndroid Build Coastguard Worker dtype = torch.float32 2192*da0073e9SAndroid Build Coastguard Worker device = "cpu" 2193*da0073e9SAndroid Build Coastguard Worker 2194*da0073e9SAndroid Build Coastguard Worker def fn(): 2195*da0073e9SAndroid Build Coastguard Worker inp = torch.randn((3, 5), dtype=dtype, device=device) 2196*da0073e9SAndroid Build Coastguard Worker out1 = torch.tensor(0, dtype=dtype, device=device) 2197*da0073e9SAndroid Build Coastguard Worker torch.sigmoid(inp, out=out1) 2198*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out1.numel(), 15) 2199*da0073e9SAndroid Build Coastguard Worker 2200*da0073e9SAndroid Build Coastguard Worker fn() 2201*da0073e9SAndroid Build Coastguard Worker opt_fn = torch._dynamo.optimize("eager")(fn) 2202*da0073e9SAndroid Build Coastguard Worker opt_fn() 2203*da0073e9SAndroid Build Coastguard Worker 2204*da0073e9SAndroid Build Coastguard Worker def test_sigmoid_out2(self): 2205*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 2206*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2207*da0073e9SAndroid Build Coastguard Worker super().__init__() 2208*da0073e9SAndroid Build Coastguard Worker self.base = torch.nn.Buffer(torch.ones(4, 4)) 2209*da0073e9SAndroid Build Coastguard Worker 2210*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2211*da0073e9SAndroid Build Coastguard Worker torch.sigmoid(x, out=self.base) 2212*da0073e9SAndroid Build Coastguard Worker return x + self.base 2213*da0073e9SAndroid Build Coastguard Worker 2214*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 4) 2215*da0073e9SAndroid Build Coastguard Worker m = MyModule() 2216*da0073e9SAndroid Build Coastguard Worker ref = m(x) 2217*da0073e9SAndroid Build Coastguard Worker opt_m = torch._dynamo.optimize("eager")(m) 2218*da0073e9SAndroid Build Coastguard Worker res = opt_m(x) 2219*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ref, res)) 2220*da0073e9SAndroid Build Coastguard Worker 2221*da0073e9SAndroid 
    def test_slice_into_list_mutable(self):
        class Mod(torch.nn.Module):
            def forward(self, listy):
                x = listy[3:5]
                for i in range(10):
                    z = torch.abs(torch.randn(10)) + 1
                    x[0] = z
                return x

        m = Mod()
        listy = [torch.randn(10)] * 10

        cnt = torch._dynamo.testing.CompileCounter()
        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
        opt_m.forward(listy)

        self.assertEqual(cnt.frame_count, 1)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_issue111918(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, dynamic=True)
        def fn(x):
            x = x + 1
            y = x.item()
            if y > 2:
                return x * 2
            else:
                return x * 3

        x = torch.tensor([3.0])
        fn(x)
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 4)

        torch._dynamo.reset()
        fn = torch.compile(fn, fullgraph=True, backend="eager")
        with self.assertRaises(torch._dynamo.exc.UserError):
            fn(x)

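    # Added sketch: torch._dynamo.config.patch, used as a decorator above,
    # should also work as a context manager, restoring the flag on exit. This
    # is illustrative and assumes only the public config-patch helper.
    def test_config_patch_sketch(self):
        prior = torch._dynamo.config.capture_scalar_outputs
        with torch._dynamo.config.patch(capture_scalar_outputs=True):
            self.assertTrue(torch._dynamo.config.capture_scalar_outputs)
        self.assertEqual(torch._dynamo.config.capture_scalar_outputs, prior)
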
    def test_vdd_duplicate_error(self):
        def fn(a, dt):
            keys = list(dt._jt_dict.keys())
            p = torch.cos(dt._jt_dict[keys[0]]._value)
            q = torch.sin(a)
            r = torch.sigmoid(dt._jt_dict[keys[0]]._value)
            return p + q + r

        class Value:
            def __init__(self) -> None:
                self._value = torch.randn(4)

        class Sample:
            def __init__(self) -> None:
                self._jt_dict = {}
                self._jt_dict["POSITION_ID"] = Value()

        a = torch.randn(4)
        sample = Sample()

        ref = fn(a, sample)

        optimized_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        res = optimized_fn(a, sample)

        self.assertTrue(same(ref, res))

    def test_specialized_stride(self):
        def f():
            e = torch.empty(4)
            x = e[::2]
            return x.stride()

        self.assertEqual(f(), torch._dynamo.optimize("eager")(f)())

    def test_out_none(self):
        # https://github.com/pytorch/pytorch/issues/92814
        def fn(input):
            return torch.nn.functional.normalize(input, dim=0, out=None)

        x = torch.rand([1])
        self.assertEqual(fn(x), torch._dynamo.optimize("eager")(fn)(x))

    def test_multi_import(self):
        if not has_detectron2():
            raise unittest.SkipTest("requires detectron2")

        @torch._dynamo.optimize("eager", nopython=True)
        def to_bitmasks(boxes):
            from detectron2.layers.mask_ops import (
                _paste_masks_tensor_shape,
                paste_masks_in_image,
            )

            if (
                paste_masks_in_image is not None
                and _paste_masks_tensor_shape is not None
            ):
                return boxes + 1

        self.assertTrue((to_bitmasks(torch.zeros(10)) == torch.ones(10)).all())

    def test_multi_dot_import(self):
        def fn1(x):
            return torch.sin(x)

        def fn(x):
            import torch.fx

            _ = torch.fx.symbolic_trace(fn1)
            return x * 2

        x = torch.randn(10)
        fn(x)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        opt_fn(x)
        self.assertEqual(cnt.frame_count, 1)

    def test_relative_import(self):
        try:
            from . import utils as _  # noqa: F401

            def fn(x):
                from .utils import tensor_for_import_testing

                return x * 2 * tensor_for_import_testing

        except ImportError:

            def fn(x):
                from utils import tensor_for_import_testing

                return x * 2 * tensor_for_import_testing

        x = torch.randn(10)
        fn(x)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
        opt_fn(x)
        self.assertEqual(cnt.frame_count, 1)

    def test_relative_import_no_modulename(self):
        try:
            from . import utils as _  # noqa: F401

            def fn(x):
                from . import utils

                return x * 2 * utils.tensor_for_import_testing

        except ImportError:

            def fn(x):
                import utils

                return x * 2 * utils.tensor_for_import_testing

        x = torch.randn(10)
        fn(x)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
        opt_fn(x)
        self.assertEqual(cnt.frame_count, 1)

    def test_bigbird_unsqueeze_inplace(self):
        def fn(reshape_2):
            view_2 = reshape_2.clone()
            view_2.unsqueeze_(2)
            cat_11 = torch.cat([view_2], dim=2)
            view_13 = cat_11.view((2, 12, 64, -1))
            return (view_13,)

        x = torch.randn(2, 12, 64, 64, requires_grad=True)
        ref = fn(x)
        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))

    def test_issue1466_size_aot_autograd(self):
        def fn(x):
            # do a tensor op and a size compute
            y = x * 2
            x_size = x.size()
            # trigger a graph break
            print("arf")
            # use the tensor op and size compute
            z = y.view(x_size) + 1
            return z

        x = torch.randn(2, 3, requires_grad=True)
        ref = fn(x)
        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))

    def test_ellipsis(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lnorm = torch.nn.LayerNorm(
                    (256,), eps=1e-06, elementwise_affine=True
                )
                self.linear = torch.nn.Linear(
                    in_features=256, out_features=256, bias=True
                )

            def forward(self, cat_10):
                lnorm = self.lnorm(cat_10)
                getitem_64 = lnorm[
                    (slice(None, None, None), slice(0, 1, None), Ellipsis)
                ]
                linear = self.linear(getitem_64)
                return (linear,)

        args = [torch.randn(2, 197, 256)]

        mod = Repro()
        opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod)

        self.assertTrue(same(mod(*args), opt_mod(*args)))

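    # Added eager-mode sketch of the indexing identity the Repro above relies
    # on: the explicit (slice, slice, Ellipsis) tuple is exactly the sugar
    # form x[:, 0:1, ...].
    def test_ellipsis_equivalence_sketch(self):
        x = torch.randn(2, 197, 256)
        explicit = x[(slice(None, None, None), slice(0, 1, None), Ellipsis)]
        sugar = x[:, 0:1, ...]
        self.assertTrue(same(explicit, sugar))
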
    def test_reinplacing(self):
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.self_layoutlm_embeddings_x_position_embeddings = (
                    torch.nn.Embedding(1024, 768)
                )
                self.self_layoutlm_embeddings_y_position_embeddings = (
                    torch.nn.Embedding(1024, 768)
                )

            def forward(self, getitem_1, getitem_2, add):
                self_layoutlm_embeddings_x_position_embeddings = (
                    self.self_layoutlm_embeddings_x_position_embeddings(getitem_1)
                )
                self_layoutlm_embeddings_y_position_embeddings = (
                    self.self_layoutlm_embeddings_y_position_embeddings(getitem_2)
                )
                add_1 = add + self_layoutlm_embeddings_x_position_embeddings
                add_2 = add_1 + self_layoutlm_embeddings_y_position_embeddings
                return (add_2,)

        mod = MockModule()
        opt_mod = torch._dynamo.optimize("aot_eager_decomp_partition")(mod)

        args = [
            ((2, 512), (2048, 4), torch.int64, "cpu", False),
            ((2, 512), (2048, 4), torch.int64, "cpu", False),
            ((2, 512, 768), (393216, 768, 1), torch.float32, "cpu", True),
        ]
        args = [
            rand_strided(sh, st, dt, dev).requires_grad_(rg)
            for (sh, st, dt, dev, rg) in args
        ]
        self.assertTrue(same_two_models(mod, opt_mod, args))

    def test_optimized_deepcopy(self):
        # See https://github.com/pytorch/pytorch/pull/88629
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(in_features=2, out_features=3, bias=True)

            def forward(self, x):
                return self.fc(x)

        mod = Foo()
        opt_mod = torch._dynamo.optimize("eager")(mod)
        args = [torch.randn(1, 2)]
        self.assertTrue(same_two_models(mod, opt_mod, args))

    def test_class_member(self):
        class Foo(torch.nn.Module):
            a = 4
            b = torch.ones(3, 4)

            def __init__(self) -> None:
                super().__init__()
                self.c = 4

            def forward(self, x):
                return x.cos() + self.a + self.b + self.c

        mod = Foo()
        opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod)
        args = (torch.randn(3, 4),)
        self.assertTrue(same(mod(*args), opt_mod(*args)))


    def test_named_buffers(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.nn.Buffer(torch.ones(3))
                self.y = torch.nn.Buffer(torch.ones(3))

            def forward(self, inp):
                res = 0
                for name, buffer in self.named_buffers():
                    res += buffer.sum()

                return inp.cos() + res

        mod = Foo()
        opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod)
        args = (torch.randn(3, 4),)
        self.assertTrue(same(mod(*args), opt_mod(*args)))

    def test_requires_grad_guards_with_grad_mode1(self):
        def f(x):
            if x.requires_grad:
                return x + 1
            else:
                return x + 2

        x = torch.ones(2, requires_grad=True)

        f_compiled = torch.compile(f)
        with torch.no_grad():
            # compile an inference graph
            f_compiled(x)

        # Test: we should fail guards and recompile (even though it's still an inference graph)
        out_ref = f(x.detach())
        out = f_compiled(x.detach())

        self.assertEqual(out_ref, out)
        self.assertEqual(out_ref.requires_grad, out.requires_grad)

    def test_requires_grad_guards_with_grad_mode2(self):
        x = torch.ones(2, requires_grad=True)
        x_ref = x.clone().detach().requires_grad_(True)

        m = torch.nn.Linear(2, 2)
        m_compiled = torch.compile(m)

        with torch.no_grad():
            # compile an inference graph
            m_compiled(x)

        # Test: we should fail guards and recompile a training graph
        out_ref = m(x_ref)
        out = m_compiled(x)
        self.assertEqual(out_ref, out)
        self.assertEqual(out_ref.requires_grad, out.requires_grad)
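
    # NB: is_fx_tracing_test() (defined at module scope) compares
    # torch.nn.Module.__call__ against the saved original; the next test
    # checks that Dynamo can evaluate that probe at trace time without
    # graph breaking (nopython=True would raise otherwise).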

    def test_is_symbolic_tracing(self):
        # Ensure no graph break here
        def fn(x):
            if is_fx_tracing_test():
                return x * 2
            return x * 4

        a = torch.randn(4)
        ref = fn(a)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        res = opt_fn(a)
        self.assertTrue(same(ref, res))

    def test_tokenization(self):
        from collections import UserDict

        class BatchEncoding(UserDict):
            """
            Copied from tokenization
            """

            def __init__(
                self,
                data,
            ):
                super().__init__(data)

            def __getattr__(self, item: str):
                try:
                    return self.data[item]
                except KeyError as e:
                    raise AttributeError from e

        def tokenization(x):
            encoding = BatchEncoding({"key": x})
            return encoding["key"]

        opt_fn = torch._dynamo.optimize("eager")(tokenization)
        x = torch.rand((1, 4))
        ref = tokenization(x)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))
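
    # NB: in test_modules below, self.modules() yields Foo itself and its
    # Linear child, so the loop unrolls to two fc(inp) calls; the expected
    # op_count of 5 is the zeros plus 2 x (linear, add) in a single frame.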

    def test_modules(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc = torch.nn.Linear(4, 3)

            def forward(self, inp):
                res = torch.zeros(3, 3)
                for mod in self.modules():
                    res += self.fc(inp)
                return res

        mod = Foo()
        args = (torch.ones(3, 4),)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_mod = torch._dynamo.optimize(cnt, nopython=True)(mod)
        self.assertTrue(same(mod(*args), opt_mod(*args)))
        self.assertEqual(cnt.op_count, 5)
        self.assertEqual(cnt.frame_count, 1)

    def test_omegaconf_listconfig_iter(self):
        obj = ListConfig()
        x = torch.zeros(2)

        def fn():
            y = x
            for i in obj:
                y += i
            return y

        expected = fn()
        actual = torch.compile(fn, fullgraph=True, backend="eager")()
        self.assertEqual(actual, expected)

    def test_user_defined_iter(self):
        class MyIter:
            def __init__(self) -> None:
                self.i = 0

            def __iter__(self):
                return self

            def __next__(self):
                if self.i < 3:
                    self.i += 1
                    return self.i
                raise StopIteration

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            for i in MyIter():
                x += i
            return x

        self.assertEqual(fn(torch.zeros(1)), torch.full([1], 6.0))
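
    # NB: in test_user_defined_iter above, Dynamo inlines __next__ until the
    # StopIteration is raised, unrolling the loop to x + 1 + 2 + 3.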

    def test_stop_iteration_reconstruct(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return x.sin(), StopIteration(1, 2, 3)

        _, res = fn(torch.ones(1))
        self.assertEqual(str(res), str(StopIteration(1, 2, 3)))

    def test_tensor_data_kwarg(self):
        # https://github.com/pytorch/pytorch/issues/96278
        def f():
            return torch.tensor(data=[[1.0, -1.0]])

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(f)
        self.assertTrue(same(f(), opt_fn()))
        self.assertEqual(cnt.frame_count, 1)

    @requires_cuda
    def test_norm_dtype(self):
        def foo(_stack0):
            getitem = _stack0[(slice(None, None, None), -1)]
            _stack0 = None
            normalize = torch.nn.functional.normalize(getitem, p=2, dim=1)
            getitem = None
            return (normalize,)

        args = [((2, 50, 256), (1, 256, 1), torch.float16, "cuda", False)]
        args = [
            rand_strided(sh, st, dt, dev).requires_grad_(rg)
            for (sh, st, dt, dev, rg) in args
        ]

        opt_foo = torch._dynamo.optimize("aot_eager_decomp_partition")(foo)
        with torch.cuda.amp.autocast(enabled=True):
            ref = foo(*args)[0]
            res = opt_foo(*args)[0]
            self.assertEqual(ref.dtype, res.dtype)

            self.assertTrue(same(res, ref))

    def test_for_loop_graph_break(self):
        def inner(x):
            return torch.sin(x)

        def fn(x):
            for _ in range(100):
                inner(x)
                torch._dynamo.graph_break()
            return x

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        x = torch.randn(4)
        opt_fn(x)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)
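
    # NB: in test_for_loop_graph_break above, the break inside the loop body
    # means `fn` cannot be resumed mid-loop, so it runs eagerly; only `inner`
    # gets compiled (one frame, the single sin op) and its cache entry is
    # reused for the remaining iterations.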

    def test_for_loop_graph_break_before(self):
        # Checks that the backedge is calculated correctly
        def inner(x):
            return torch.sin(x)

        def fn(x):
            torch._dynamo.graph_break()
            for _ in range(100):
                inner(x)
            return x

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        x = torch.randn(4)
        opt_fn(x)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 100)

    def test_avoid_dupe_specialization(self):
        def f(x, y):
            return (x + y) * 1

        opt_f = torch._dynamo.optimize("aot_eager")(f)

        for b in [True, False]:
            x = torch.randn(4, requires_grad=b)
            y = torch.randn(4, requires_grad=b)
            self.assertEqual(f(x, x), opt_f(x, x))
            self.assertEqual(f(x, y), opt_f(x, y))
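
    # NB: in test_avoid_dupe_specialization above, calling opt_f(x, x)
    # exercises AOTAutograd's handling of duplicated inputs; the graph
    # compiled for the aliased call must not be specialized in a way that
    # breaks the later f(x, y) call.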

    def test_validate_model_kwargs(self):
        cnt = CompileCounter()

        def f1(a, b):
            return torch.sin(a) + torch.cos(b)

        @torch.compile(backend=cnt, fullgraph=True)
        def f2(**kwargs):
            _validate_model_kwargs(f1, kwargs)
            return f1(**kwargs)

        x = torch.randn(10)
        y = torch.randn(10)

        self.assertEqual(f2(a=x, b=y), f1(x, y))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 3)

    def test_swin_base_tensor_attr(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # NB: not a parameter or buffer
                self.t = torch.randn(3)

            def forward(self, x):
                return x + torch.cat((self.t, self.t))

        mod = Foo()
        opt_mod = torch._dynamo.optimize("eager")(mod)
        args = [torch.randn(6)]
        self.assertTrue(same_two_models(mod, opt_mod, args))
        opt_mod(*args)

    def test_pointless_graph_removal(self):
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def fn(x):
            with torch.no_grad():
                torch._dynamo.graph_break()
                return x + 1

        fn(torch.randn(4))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 3)
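
    # NB: in test_pointless_graph_removal above, the pre-break graph holds
    # no real compute (it only enters no_grad) and is dropped, hence
    # frame_count == 1; op_count == 3 presumably covers the add plus the
    # two grad-mode toggles recorded in the resumed frame.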

    def test_output_aliases_intermediate(self):
        def f(x):
            intermediate = x.mul(2)
            return intermediate.view(-1), intermediate

        opt_f = torch._dynamo.optimize("aot_eager")(f)

        for b in [True, False]:
            x = torch.randn(4, requires_grad=b)
            out = f(x)
            out_test = opt_f(x)
            self.assertEqual(out[0], out_test[0])
            self.assertEqual(out[1], out_test[1])
            self.assertEqual(out[0].requires_grad, out_test[0].requires_grad)
            self.assertEqual(out[1].requires_grad, out_test[1].requires_grad)
            # test that the aliasing relationship of outputs is preserved
            out[0].mul_(2)
            out_test[0].mul_(2)
            self.assertEqual(out[0], out_test[0])
            self.assertEqual(out[1], out_test[1])

    def test_while_loop_graph_break(self):
        # Repro of tacotron2 cache_size_recompilation
        def inner(x):
            return torch.sin(x)

        def fn(x):
            i = 20
            while i > 10:
                x = inner(x)
                i -= 1
                torch._dynamo.graph_break()
            return x

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        x = torch.randn(4)
        opt_fn(x)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    def test_nested_while_loop_graph_break(self):
        def inner_loop(x):
            i = 3
            while i > 0:
                i -= 1
                x += 1
                torch._dynamo.graph_break()
            return x

        def inner(x):
            inner_loop(x)
            return torch.sin(x)

        def fn(x):
            i = 20
            while i > 10:
                x = inner(x)
                i -= 1
                torch._dynamo.graph_break()
            return x

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        x = torch.randn(4)
        opt_fn(x)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    def test_while_loop_graph_break_inside_call_function(self):
        # Repro of huggingface graph break inside loop in `get_parameter_dtype`.
        # Skip only the inner frame whose loop contains the graph break.
        def inner(x):
            for i in range(3):
                x += 1
                torch._dynamo.graph_break()
            return x

        def fn(x):
            x += 2
            inner(x)
            x += 3
            return x

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        x = torch.randn(4)
        opt_fn(x)
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 2)

    def test_exception_in_dynamo_handling(self):
        hit_handler = False

        # See https://github.com/pytorch/pytorch/pull/96488
        @contextlib.contextmanager
        def ctx():
            try:
                yield
            except RuntimeError:
                nonlocal hit_handler
                hit_handler = True

        @torch._dynamo.optimize("eager")
        def f():
            with ctx():
                h()

        def h():
            raise RuntimeError("boof")

        # Should not error
        f()
        self.assertTrue(hit_handler)
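
    # NB: in test_exception_in_dynamo_handling above, the RuntimeError from
    # h() has to unwind through Dynamo's frame handling and still reach the
    # except block inside the context manager; see the PR linked there.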

    def test_generator_dealloc(self):
        # See https://github.com/pytorch/pytorch/pull/96488
        #
        # NB: yes, [(...)] is intentional, this is a list containing a
        # generator
        generator_box = [(x for x in [1, 2, 3])]

        counter = torch._dynamo.testing.CompileCounter()

        def g(x):
            return x + 2

        # TODO: This test is pretty delicate. To test if it's actually doing
        # anything, rebuild eval_frame.c with '#define TORCHDYNAMO_DEBUG 1'
        # and then look at the logs for:
        #
        # TRACE[_custom_eval_frame:650] begin <genexpr> test_repros.py 2276 -1 0 0
        # TRACE[_custom_eval_frame:664] throw <genexpr>
        #
        # This means we're actually hitting the relevant codepath

        # NB: Make sure we don't actually Dynamo this frame; if we do Dynamo
        # this frame, Dynamo actually DOES understand list.clear and will
        # arrange for the generator deallocation to happen when the eval frame
        # handler is disabled, which will prevent the bug from happening (we
        # specifically want to trigger the generator deallocation WHILE the
        # dynamo eval frame handler is active), as that will cause the
        # generator to become exhausted and trigger the throw_flag == TRUE
        # case.
        @torch._dynamo.disable(recursive=False)
        def f(x):
            generator_box.clear()
            return g(x)

        self.assertNoUnraisable(
            lambda: torch._dynamo.optimize(counter)(f)(torch.randn(3))
        )

        # Make sure the x + 2 is captured (a previous incorrect implementation
        # of this fix would have disabled the eval frame callback, which means
        # g wouldn't get traced)
        self.assertEqual(counter.op_count, 1)

    def test_error_return_without_exception_set(self):
        # https://github.com/pytorch/pytorch/issues/93781
        @torch.compile
        def f():
            _generator_type = type(_ for _ in ())

        self.assertNoUnraisable(f)

    def common_merge_criteria_processor_list(self, list_cls, fullgraph):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=fullgraph)
        def f(x, left, right):
            combined = _merge_criteria_processor_list(left, right)
            return combined(x)

        l1 = list_cls([torch.nn.ReLU(), torch.nn.Sigmoid()])
        l2 = list_cls([])
        input = torch.randn(16)
        result = f(input, l1, l2)
        self.assertEqual(result, l1(input))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 2)

        cnt.clear()
        l3 = list_cls([torch.nn.SiLU()])
        expected = l3(l1(input))
        result = f(input, l1, l3)
        self.assertEqual(len(l1), 3)
        self.assertEqual(result, expected)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 3)

    def test_merge_criteria_processor_list1(self):
        self.common_merge_criteria_processor_list(CustomList1, False)

    def test_merge_criteria_processor_list2(self):
        self.common_merge_criteria_processor_list(CustomList2, True)

    def test_restricted_list_subclass1(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(a, b):
            l = CustomList2()
            l.extend([True])
            l.append(a)
            l.extend([b])
            l.pop(0)
            l.append(l.length_times_10())
            return sum(l)

        x = torch.randn(10)
        y = torch.randn(10)
        self.assertEqual(fn(x, y), x + y + 20)
        self.assertEqual(cnt.op_count, 3)
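
    # NB: in test_restricted_list_subclass1 above, the list ends up as
    # [a, b, 20] (length_times_10() is inlined at trace time), and sum()
    # lowers to the three adds counted by op_count.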

    def test_restricted_list_subclass2(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(a, b):
            l1 = CustomList2([a + 1])
            l2 = CustomList2([b + 2])
            l1.extend(l2)
            return l1

        x = torch.randn(10)
        y = torch.randn(10)
        z = fn(x, y)
        self.assertEqual(type(z), CustomList2)
        self.assertEqual(len(z), 2)
        self.assertEqual(z.length_times_10(), 20)
        self.assertEqual(list(z), [x + 1, y + 2])

    def test_restricted_list_subclass3(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(a: CustomList2, b: CustomList2):
            a.extend(b)
            a.append_twice(b[2] + 1)
            a.append(b[3] + 2)
            return b

        x = torch.randn(10)
        y = torch.randn(10)
        l = CustomList2([x, y])
        self.assertIs(fn(l, l), l)
        self.assertEqual(len(l), 7)
        self.assertIs(l[0], x)
        self.assertIs(l[1], y)
        self.assertIs(l[2], x)
        self.assertIs(l[3], y)
        self.assertEqual(l[4], x + 1)
        self.assertIs(l[5], l[4])
        self.assertEqual(l[6], y + 2)

    def test_rewrite_assert_with_msg(self):
        def f(x):
            b = x.sin()
            assert x[0] == 3, "First dim needs to be 3"
            return x.cos() + b

        args = (torch.Tensor([3, 4, 5]),)
        cnt = torch._dynamo.testing.CompileCounter()

        opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
        self.assertTrue(same(f(*args), opt_f(*args)))
        self.assertEqual(cnt.op_count, 6)
        self.assertEqual(cnt.frame_count, 1)

        exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
        self.assertTrue(same(exported(*args), f(*args)))
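
    # NB: the rewrite-assert tests rely on Dynamo rewriting `assert` into a
    # graph-friendly check, which is why test_rewrite_assert_with_msg above
    # compiles under nopython=True and the exported graph still raises when
    # the condition fails at runtime.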

    def test_list_aliasing(self):
        cnt = CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(a):
            a.append(torch.sin(a[0]))
            return a

        x = torch.randn(10)
        l = [x]
        self.assertIs(fn(l), l)
        self.assertEqual(len(l), 2)
        self.assertIs(l[0], x)
        self.assertEqual(l[1], torch.sin(x))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

    def test_not_rewrite_assert_for_other_errors(self):
        def f(x):
            b = x.sin()
            if not x.sum() <= 3:
                raise ValueError("input sum needs to be 3")
            return x.cos() + b

        args = (torch.Tensor([3, 4, 5]),)
        opt_fn = torch._dynamo.optimize("eager")(f)
        with self.assertRaisesRegex(ValueError, "input sum needs to be 3"):
            opt_fn(*args)

    def test_rewrite_assert_dont_change_bytecode(self):
        def fn(x):
            with torch.no_grad():
                assert x.max() < 5, f"invalid max {x.max()}"
                x = torch.sin(x)
            return x

        x = torch.ones(4)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        self.assertTrue(same(fn(x), opt_fn(x)))

    def test_rewrite_assert_without_msg(self):
        def f(x):
            b = x.sin()
            assert x[0] == 3
            return x.cos() + b

        args = (torch.Tensor([3, 4, 5]),)
        exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
        self.assertTrue(same(exported(*args), f(*args)))

        with self.assertRaisesRegex(RuntimeError, "assertion error"):
            exported(torch.Tensor([5, 6, 7]))

    def test_rewrite_assert_with_non_string_msg(self):
        def f(x):
            b = x.sin()
            assert x[0] == 2, x.size()
            return x.cos() + b

        torch._dynamo.utils.counters.clear()
        args = torch.Tensor([3, 4, 5])
        opt_f = torch._dynamo.optimize("eager")(f)
        with self.assertRaisesRegex(AssertionError, "torch.Size"):
            opt_f(args)
        self.assertEqual(
            torch._dynamo.utils.counters["graph_break"][
                "assert with non-string message"
            ],
            1,
        )

    def test_rewrite_assert_noop(self):
        def f(x):
            b = x.sin()
            assert True
            assert x.dtype == torch.float32
            return x.cos() + b

        args = (torch.Tensor([3, 4, 5]),)
        exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
        self.assertTrue(same(exported(*args), f(*args)))

        cnt = torch._dynamo.testing.CompileCounter()
        opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
        self.assertTrue(same(f(*args), opt_f(*args)))
        # torch._assert shouldn't be in the graph
        self.assertEqual(cnt.op_count, 3)
        self.assertEqual(cnt.frame_count, 1)

        exported, _ = torch._dynamo.export(f)(torch.Tensor([4, 4, 5]))
        self.assertTrue(same(exported(*args), f(*args)))
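
    # NB: in test_rewrite_assert_noop above, both assert conditions are
    # compile-time constants, so they are checked during tracing and dropped;
    # op_count == 3 is just the sin, cos, and add.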

    def test_size_typematch(self):
        def f(x, y):
            if isinstance(x, torch.Size):
                return y + 1
            else:
                return y + 2

        y = torch.zeros(1)
        x1 = torch.Size((3,))
        x2 = (3,)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
        self.assertTrue(same(f(x1, y), opt_f(x1, y)))
        self.assertTrue(same(f(x2, y), opt_f(x2, y)))
        self.assertEqual(cnt.frame_count, 2)

    def test_dict_subclass_contains(self):
        # pattern from huggingface
        class ClassInstantier(collections.OrderedDict):
            pass

        @torch.compile(fullgraph=True, backend="eager")
        def f(x, d):
            if "key1" in d:
                x = x + 2
            if "key2" in d:
                x = x + 4
            x = x + 8
            return x

        result = f(torch.ones(8), ClassInstantier({"key1": torch.ones(8)}))
        self.assertTrue(same(result, torch.full([8], 11.0)))

        result = f(torch.ones(8), ClassInstantier({"key2": torch.ones(8)}))
        self.assertTrue(same(result, torch.full([8], 13.0)))
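
    # NB: in test_dict_subclass_contains above, the two calls see different
    # key sets, so Dynamo guards on the dict's keys and recompiles f; `in`
    # has to dispatch correctly on the OrderedDict subclass.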

    def test_hf_classinstantier(self):
        # hf activations.py
        class ClassInstantier(collections.OrderedDict):
            def __getitem__(self, key):
                content = super().__getitem__(key)
                cls, kwargs = content if isinstance(content, tuple) else (content, {})
                return cls(**kwargs)

        ACT2CLS = ClassInstantier(
            {
                "relu": (nn.ReLU, {"inplace": False}),
                "tanh": nn.Tanh,
            }
        )

        @torch.compile(fullgraph=True, backend="eager")
        def f(x, act):
            return ACT2CLS[act](x)

        y = torch.randn(10)
        self.assertTrue(same(f(y, "tanh"), torch.tanh(y)))
        self.assertTrue(same(f(y, "relu"), torch.relu(y)))

    def test_ephemeral_module(self):
        # hf activations.py
        class ReLUSquaredActivation(nn.Module):
            def forward(self, input):
                relu_applied = torch.nn.functional.relu(input)
                squared = torch.square(relu_applied)
                return squared

        @torch.compile(fullgraph=True, backend="eager")
        def f(x):
            x = x + 0.2
            x = ReLUSquaredActivation()(x)
            x = x + 1
            return x

        y = torch.randn(10)
        self.assertTrue(same(f(y), ReLUSquaredActivation()(y + 0.2) + 1))

    def test_inplace_unsqueeze_input(self):
        def backend(gm, example_inputs):
            self.assertEqual(example_inputs[-1].size(), torch.Size([1, 3, 4]))
            return gm

        @torch.compile(backend=backend)
        def fn(x):
            x.unsqueeze_(0)
            return x + 1

        inputs = [torch.randn(3, 4)]
        self.assertEqual(fn(*inputs).size(), torch.Size([1, 3, 4]))
        self.assertEqual(inputs[0].size(), torch.Size([1, 3, 4]))
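
    # NB: in test_inplace_unsqueeze_input above, the in-place unsqueeze_ on
    # a graph input must be reflected on the caller's tensor, and the
    # backend already observes the post-mutation shape on its example input.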

    def test_batchnorm_e2e(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bn = torch.nn.BatchNorm2d(
                    64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
                )
                self.conv1 = torch.nn.Conv2d(
                    64,
                    64,
                    kernel_size=(3, 3),
                    stride=(1, 1),
                    padding=(1, 1),
                    bias=False,
                )

            def forward(self, x):
                x1 = self.bn(x)
                x2 = self.conv1(x1)
                out = torch.nn.functional.relu(x2)
                return (out,)

        torch.manual_seed(1337)

        m_ref = Repro()
        m_test = deepcopy(m_ref)

        @torch._dynamo.optimize("aot_eager_decomp_partition")
        def compiled_fn(x):
            return m_test(x)

        x_ref = torch.randn(2, 64, 32, 32, requires_grad=True)
        x_test = x_ref.clone()

        # Loop multiple times: each iteration the running_mean/var on batchnorm will update,
        # which changes the output of the next iteration
        for _ in range(3):
            ref = m_ref(x_ref)
            res = compiled_fn(x_test)

            self.assertTrue(same(ref, res))

            for r in ref:
                if r.requires_grad:
                    r.sum().backward()
            for r in res:
                if r.requires_grad:
                    r.sum().backward()

            for param_ref, param_test in zip(m_ref.parameters(), m_test.parameters()):
                self.assertTrue(same(param_ref, param_test))
            # Assert running_mean/var
            for buffer_ref, buffer_test in zip(m_ref.buffers(), m_test.buffers()):
                self.assertTrue(same(buffer_ref, buffer_test))

    @torch._dynamo.config.patch("assume_static_by_default", False)
    def test_dynamic_shapes_right_side(self):
        def f(x):
            return torch.ones(5 * x.shape[0])

        inp = torch.randn(6, 5)

        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))
        self.assertEqual(gm(inp).shape, f(inp).shape)
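
    # NB: in test_dynamic_shapes_right_side above, with
    # assume_static_by_default off the exported aten graph keeps the batch
    # dimension symbolic, so a graph traced at size 4 also runs at size 6.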

    @torch._dynamo.config.patch("specialize_int", False)
    def test_maybe_multiply_symint(self):
        # https://github.com/pytorch/pytorch/issues/97346
        from torch._functorch.aot_autograd import aot_module_simplified

        def my_aot_compiler(gm, example_inputs):
            def my_compiler(gm, example_inputs):
                return gm.forward

            # Invoke AOTAutograd
            return aot_module_simplified(gm, example_inputs, fw_compiler=my_compiler)

        def my_example(t1, t2, d):
            out = torch.add(t1, t2, alpha=d)
            return out

        compiled_fn = torch.compile(backend=my_aot_compiler, dynamic=True)(my_example)

        t1 = torch.arange(3, dtype=torch.float32).requires_grad_(True)
        t2 = torch.arange(3, dtype=torch.float32).requires_grad_(True)

        ra = compiled_fn(t1, t2, 5)
        self.assertEqual(ra, torch.tensor([0.0, 6.0, 12.0]))

        ra = compiled_fn(t1, t2, 6)
        self.assertEqual(ra, torch.tensor([0.0, 7.0, 14.0]))

    def test_build_map_unpack_with_call(self):
        def forward_with_cond_scale(x, t, cond_scale, self_cond, other1, other2):
            return x.sin() + t + cond_scale + self_cond + other1 + other2

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            d1 = dict(other1=5)
            d2 = dict(other2=4)
            text_cond = {**d1, **d2}
            return forward_with_cond_scale(x, 1, cond_scale=2, self_cond=3, **text_cond)

        self.assertTrue(same(fn(torch.ones(4)), torch.ones(4).sin() + 15))
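
    # NB: in test_maybe_multiply_symint above, specialize_int=False plus
    # dynamic=True lets the scalar `d` be traced symbolically instead of
    # being baked in as a constant (see the linked issue).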

    @torch._dynamo.config.patch(verbose=True)
    def test_graph_break_unsupported_fake(self):
        counter = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(counter)
        def f(x):
            return torch.ops.test_sample.foo(x + 1) + 1

        f(torch.randn(3))

        self.assertEqual(counter.op_count, 2)
        self.assertEqual(counter.frame_count, 2)

    def test_delattr(self):
        class MyObj:
            def __init__(self, a, b):
                self.a = a
                self.b = b

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x, obj):
            del obj.a
            obj.c = x + 1
            del obj.c
            tmp = MyObj(x + 2, x + 3)
            del tmp.b
            if hasattr(obj, "a"):
                return x + 1
            return tmp

        x = torch.zeros([])
        obj1 = MyObj(x, x)
        obj2 = fn(x, obj1)
        self.assertFalse(hasattr(obj1, "a"))
        self.assertFalse(hasattr(obj1, "c"))
        self.assertFalse(hasattr(obj2, "b"))
        self.assertEqual(obj1.b.item(), 0)
        self.assertEqual(obj2.a.item(), 2)

    def test_delattr_raises(self):
        class MyObj:
            def __init__(self, a, b):
                self.a = a
                self.b = b

        @torch.compile(backend="eager")
        def fn(x, obj):
            del obj.a
            x = x + 1
            obj.a  # will raise
            return x

        x = torch.zeros([])
        obj1 = MyObj(x, x)
        self.assertRaises(AttributeError, lambda: fn(x, obj1))
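
    # Illustrative sketch (not part of the original suite; the method name is
    # made up and deliberately lacks the test_ prefix): setting and then
    # deleting an attribute inside a compiled region is a side effect that
    # dynamo replays on the real object after the graph runs, mirroring the
    # obj.c handling asserted in test_delattr above.
    def _sketch_delattr_side_effect_visible(self):
        class Holder:
            pass

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x, obj):
            obj.stash = x + 1
            del obj.stash
            return x.sin()

        h = Holder()
        fn(torch.ones(3), h)
        self.assertFalse(hasattr(h, "stash"))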

    def test_delsubscr(self):
        @torch.compile(backend="eager")
        def fn(x):
            del x["a"]
            y = x["b"] + 1
            return y

        x = {"a": torch.tensor([1]), "b": torch.tensor([1])}
        result = fn(x)
        self.assertNotIn("a", x)  # key deletion must be replayed on the real dict
        self.assertEqual(result.item(), 2)

    def test_delsubscr_raises(self):
        @torch.compile(backend="eager")
        def fn(x):
            del x["a"]
            y = x["a"] + 1  # should raise KeyError
            return y

        x = {"a": torch.tensor([1]), "b": torch.tensor([1])}
        self.assertRaises(KeyError, lambda: fn(x))

    def test_attached_attribute_in_dir(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(16, 16)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(self.linear(x))

        mod = torch.compile(MyModule(), backend="eager")
        mod.is_compiled = True
        self.assertTrue("is_compiled" in dir(mod))

    @torch._dynamo.config.patch("automatic_dynamic_shapes", False)
    def test_dynamic_shapes_implicit_guard(self):
        def f(x):
            y = x * x.size(x.shape[0])
            torch.sum(y, [y.shape[0]])
            return y

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(f)
        opt_fn(torch.randn(3, 1, 1, 1, 1))
        self.assertEqual(cnt.frame_count, 1)
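
    # Illustrative sketch (not part of the original suite; the method name is
    # made up and deliberately lacks the test_ prefix): the module-level
    # maybe() helper used by test_dalle2_maybe below short-circuits when the
    # wrapped value is None, so the wrapped function is never called; dynamo
    # is expected to trace the early-return branch and thread the None
    # through the output.
    def _sketch_maybe_skips_none_input(self):
        def normalize(x):
            return x.cos()

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            skipped = maybe(normalize)(None)
            return skipped, x.sin()

        skipped, out = fn(torch.ones([]))
        self.assertIsNone(skipped)
        self.assertEqual(out, torch.ones([]).sin())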

    def test_dalle2_maybe(self):
        def normalize(x):
            return x.cos()

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x, normalize_img):
            lowres_cond_img = x.sin()
            lowres_cond_img = maybe(normalize_img)(lowres_cond_img)
            return lowres_cond_img

        self.assertEqual(fn(torch.ones([]), normalize), torch.ones([]).sin().cos())

    def test_functools_wraps(self):
        def cool_name(x):
            return x.sin()

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            y = x.cos()

            @functools.wraps(cool_name)
            def uncool_name():
                return cool_name(y)

            return uncool_name

        result = fn(torch.ones([]))
        self.assertEqual(result.__name__, "cool_name")
        self.assertEqual(result(), torch.ones([]).cos().sin())

    def test_dynamic_shapes_float_guard(self):
        def f(x):
            return torch.nn.functional.dropout(x, x.shape[0] / 6)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(f)
        opt_fn(torch.randn(3))
        self.assertEqual(cnt.frame_count, 1)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_tensor_item(self):
        def f(x, y):
            val = y.item()
            return x.sum() + val

        gm, _ = torch._dynamo.export(
            f,
            aten_graph=True,
        )(
            torch.zeros(6, 4),
            torch.tensor(1),
        )
        self.assertEqual(
            f(torch.zeros(6, 4), torch.tensor(1)),
            gm(torch.zeros(6, 4), torch.tensor(1)),
        )
        self.assertEqual(
            f(torch.zeros(6, 4), torch.tensor(2)),
            gm(torch.zeros(6, 4), torch.tensor(2)),
        )
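
    # Illustrative sketch (not part of the original suite; the method name is
    # made up and deliberately lacks the test_ prefix): with
    # capture_scalar_outputs=True, .item() is captured as a data-dependent
    # value instead of forcing a graph break, so the same pattern as
    # test_tensor_item is expected to compile fullgraph without export.
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def _sketch_item_fullgraph(self):
        @torch.compile(backend="eager", fullgraph=True)
        def f(x, y):
            return x.sum() + y.item()

        self.assertEqual(
            f(torch.ones(3), torch.tensor(2)),
            torch.ones(3).sum() + 2,
        )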

    def test_dataclass_init_with_default_factory_with_inputs(self):
        @dataclasses.dataclass
        class DClass:
            sharding_contexts: Any = dataclasses.field(default_factory=list)
            a: int = 1

        def fn(x, inp_list):
            d = DClass(inp_list)
            d.sharding_contexts.append(x.sin() + d.a)
            return d

        x = torch.randn(4)
        inp_list1 = [1, 2, 3]
        inp_list2 = [2, 3, 4]
        inp_list3 = [1, 2]
        ref1 = fn(x, inp_list1)
        ref2 = fn(x, inp_list2)
        ref3 = fn(x, inp_list3)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, fullgraph=True)

        opt_ret1 = opt_fn(x, inp_list1)
        opt_ret2 = opt_fn(x, inp_list2)
        opt_ret3 = opt_fn(x, inp_list3)
        self.assertEqual(ref1.sharding_contexts, opt_ret1.sharding_contexts)
        self.assertEqual(ref2.sharding_contexts, opt_ret2.sharding_contexts)
        self.assertEqual(ref3.sharding_contexts, opt_ret3.sharding_contexts)
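
    # Illustrative sketch (not part of the original suite; the method name and
    # the Point type are made up, and the method lacks the test_ prefix): the
    # dataclass above is constructed inside the compiled region; a dataclass
    # built from tensor values is expected to trace the same way.
    def _sketch_dataclass_in_graph(self):
        @dataclasses.dataclass
        class Point:
            x: Any
            y: Any

        @torch.compile(backend="eager", fullgraph=True)
        def fn(t):
            p = Point(t.sin(), t.cos())
            return p.x + p.y

        t = torch.randn(4)
        self.assertEqual(fn(t), t.sin() + t.cos())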

    def test_list_index(self):
        for i, list_type in enumerate(
            (
                list,
                tuple,
                torch.Size,
                collections.deque,
                namedtuple("FourElems", "one two three four", defaults=[0, 0, 0, 0]),
            )
        ):
            torch._dynamo.reset()
            for index in ([], [2], [0, 3]):

                def f(t):
                    if i == 4:  # namedtuple
                        xs = list_type(1, 2, 3, 4)
                    else:
                        xs = list_type([1, 2, 3, 4])
                    res = xs.index(3, *index)
                    return t + res

                res = torch._dynamo.optimize(backend="eager", nopython=True)(f)(
                    torch.zeros(1)
                )

                self.assertEqual(res, torch.tensor([2.0]))

    def test_list_index_not_found(self):
        def f(t):
            xs = ["bar", "foo", "baz", "buzz"]
            res = xs.index("non-existent")
            return t + res

        # Raising ValueError from item not found is unsupported
        with self.assertRaises(
            torch._dynamo.exc.Unsupported,
        ):
            torch._dynamo.optimize(backend="eager", nopython=True)(f)(torch.zeros(1))

    def test_list_index_tensor_unsupported(self):
        for index in ([], [2], [0, 3]):

            def f(t):
                xs = [torch.tensor([i]) for i in range(4)]
                res = xs.index(torch.tensor([2]), *index)
                return t + res

            with self.assertRaisesRegex(
                torch._dynamo.exc.UserError, "Dynamic control flow is not supported"
            ):
                torch._dynamo.optimize(backend="eager", nopython=True)(f)(
                    torch.zeros(1)
                )

    def test_hf_xsoftmax_inference(self):
        def fn(input, mask):
            return XSoftmax.apply(input + 1, mask, 1) + 2

        fn_opt = torch.compile(fn, backend="eager", fullgraph=True)

        inputs = [
            torch.randn(4, 10),
            torch.randn(4, 10) < 0,
        ]
        expected = fn(*inputs)
        actual = fn_opt(*inputs)
        self.assertTrue(same(actual, expected))

    @mock.patch("torch._dynamo.config.guard_nn_modules", True)
    def test_hf_xsoftmax_training(self):
        from torch._dynamo.utils import counters

        counters.clear()

        def fn(input, mask):
            return XSoftmax.apply(input, mask, 1)

        cnt = torch._dynamo.testing.CompileCounter()
        fn_opt = torch.compile(fn, backend=cnt, fullgraph=False)

        torch.manual_seed(1234)
        inputs1 = [
            torch.randn(4, 10, requires_grad=True),
            torch.randn(4, 10) < 0,
        ]
        torch.manual_seed(1234)
        inputs2 = [
            torch.randn(4, 10, requires_grad=True),
            torch.randn(4, 10) < 0,
        ]

        expected = fn(*inputs1)
        actual = fn_opt(*inputs2)
        self.assertTrue(same(actual, expected))
        self.assertEqual(dict(counters["frames"]), {"total": 1, "ok": 1})
        self.assertEqual(cnt.op_count, 2)
        self.assertEqual(cnt.frame_count, 1)
        cnt.clear()
        counters.clear()

        expected.sum().backward()
        actual.sum().backward()
        self.assertTrue(same(inputs1[0].grad, inputs2[0].grad))

        # currently we don't capture the backwards frame
        self.assertEqual(cnt.frame_count, 0)
        self.assertEqual(cnt.op_count, 0)
        self.assertEqual(dict(counters["frames"]), {})
        self.assertEqual(dict(counters["graph_break"]), {})
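
    # Illustrative sketch (not part of the original suite; the method name is
    # made up and deliberately lacks the test_ prefix): a custom
    # autograd.Function whose forward is plainly traceable is expected to
    # compile fullgraph, in contrast to the graph-break variant exercised by
    # test_autograd_function_graph_break below.
    def _sketch_autograd_function_fullgraph(self):
        class MyCos(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return x.cos()

            @staticmethod
            def backward(ctx, gx):
                (x,) = ctx.saved_tensors
                return -gx * x.sin()

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return MyCos.apply(x)

        x = torch.randn([], requires_grad=True)
        y = fn(x)
        self.assertEqual(y, x.cos())
        (gx,) = torch.autograd.grad(y, x)
        self.assertEqual(gx, -x.sin())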

    def test_autograd_function_graph_break(self):
        class MySin(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                torch._dynamo.graph_break()
                ctx.save_for_backward(x)
                return x.sin()

            @staticmethod
            def backward(ctx, gx):
                (x,) = ctx.saved_tensors
                return gx * x.cos()

        x = torch.randn([], requires_grad=True)

        @torch.compile(backend="eager")
        def fn(x):
            return MySin.apply(x)

        y = fn(x)
        self.assertEqual(y, x.sin())

        (gx,) = torch.autograd.grad(y, x)
        self.assertEqual(gx, x.cos())

    def test_jit_trace_errors(self):
        @torch.compile(backend="eager", dynamic=True)
        def f(x):
            return x + 1

        with self.assertRaises(RuntimeError):
            torch.jit.trace(f, torch.randn(3))

        with torch._dynamo.config.patch(error_on_nested_jit_trace=False):
            torch.jit.trace(f, torch.randn(3))

    @torch._dynamo.config.patch("assume_static_by_default", False)
    def test_tensor_split(self):
        def f(x):
            return torch.split(x, x.shape[0] // 2, dim=0)[0]

        gm, _ = torch._dynamo.export(
            f,
            aten_graph=True,
        )(
            torch.zeros(6, 4),
        )

        self.assertEqual(f(torch.ones(8, 4)), gm(torch.ones(8, 4)))

    def test_optim_state_references_cleared(self):
        model = torch.nn.Linear(2048, 2048, bias=False)
        x = torch.ones(2048)
        state_ref = 0

        optimizer = torch.optim.Adadelta(model.parameters(), lr=0.01)

        def opt_step():
            optimizer.step()

        compiled_opt_step = torch._dynamo.optimize("eager")(opt_step)

        def compiled_model_step(x):
            optimizer.zero_grad()
            y = model(x)
            torch.sum(y).backward()
            compiled_opt_step()

        compiled_model_step(x)

        # Picked "square_avg" arbitrarily to check that
        # optimizer state tensors are deallocated
        state_ref = weakref.ref(
            optimizer.state[optimizer.param_groups[0]["params"][0]]["square_avg"]
        )
        optimizer = None

        self.assertIsNone(state_ref())

    def test_grad_references_cleared(self):
        model = torch.nn.Linear(2048, 2048, bias=False)
        x = torch.ones(2048)
        optimizer = torch.optim.Adadelta(model.parameters(), lr=0.01)

        def opt_step():
            optimizer.step()

        compiled_opt_step = torch._dynamo.optimize("eager")(opt_step)

        def compiled_model_step(x):
            optimizer.zero_grad(True)
            y = model(x)
            torch.sum(y).backward()
            compiled_opt_step()

        compiled_model_step(x)
        param_grad_ref = weakref.ref(next(iter(model.parameters())).grad)
        optimizer.zero_grad(True)
        self.assertIsNone(param_grad_ref())
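
    # Illustrative sketch (not part of the original suite; the method name is
    # made up and deliberately lacks the test_ prefix): the weakref pattern
    # used by the two tests above, reduced to its core - a weak reference goes
    # dead once the last strong reference is dropped, which is how those tests
    # detect that compiled wrappers do not keep tensors alive.
    def _sketch_weakref_leak_check(self):
        t = torch.randn(8)
        ref = weakref.ref(t)
        self.assertIsNotNone(ref())
        del t  # CPython refcounting frees the tensor immediately
        self.assertIsNone(ref())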

    def test_batch_encoding_clone_inputs(self):
        class BatchEncoding(dict):
            """
            Copied from test_tokenization
            """

            def __init__(
                self,
                data,
            ):
                super().__init__(data)

            def __getattr__(self, item: str):
                try:
                    return self.data[item]
                except KeyError as e:
                    raise AttributeError from e

        encoding = BatchEncoding({"key": torch.rand((1, 4))})
        cloned_encoding = torch._dynamo.utils.clone_inputs(encoding)
        self.assertTrue(type(cloned_encoding) is not dict)

    def test_iadd_graph_break(self):
        def fn(x):
            a = ()
            x = torch.sin(x)
            a += (x,)
            return a

        x = torch.randn(4)
        ref = fn(x)

        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))

    def test_odict_get_item_index_name(self):
        d = {float: torch.float32, np.float16: torch.float16}

        @torch.compile(backend="eager")
        def f(x, y1, y2):
            return torch.zeros(5, dtype=d[y1]), torch.zeros(5, dtype=d[y2])

        f(torch.zeros(4), float, np.float16)

    def test_dedup_global(self):
        @torch.compile()
        def f():
            return _GLOBAL_CPU_TENSOR + _GLOBAL_CPU_TENSOR

        self.assertEqual(f(), _GLOBAL_CPU_TENSOR + _GLOBAL_CPU_TENSOR)

    def test_randint_out_dynamic(self):
        def randint_fn(high, size, out):
            return torch.randint(high, size, out=out)

        opt_model = torch.compile(randint_fn)

        out1 = torch.empty(10, dtype=torch.int32)
        opt_model(17, (10,), out1)

        out2 = torch.empty(12, dtype=torch.int32)
        opt_model(17, (12,), out2)

    @requires_cuda
    def test_guard_default_device(self):
        try:
            torch.set_default_device("cuda")

            counter = torch._dynamo.testing.CompileCounter()

            @torch._dynamo.optimize(counter)
            def f():
                x = torch.randn(3)
                return x * 2

            self.assertEqual(f().device.type, "cuda")
            self.assertEqual(counter.frame_count, 1)

            torch.set_default_device("cpu")

            self.assertEqual(f().device.type, "cpu")
            self.assertEqual(counter.frame_count, 2)

        finally:
            torch.set_default_device(None)

    def test_list_self_reference(self):
        # Issue - https://github.com/pytorch/pytorch/issues/100150
        root = []
        root[:] = [root, root, None, None]

        @torch._dynamo.optimize("eager")
        def test_bug():
            return root

        test_bug()
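
    # Illustrative sketch (not part of the original suite; the method name is
    # made up and deliberately lacks the test_ prefix): an explicit
    # torch._dynamo.graph_break() splits a function into separate compiled
    # frames, which is the mechanism behind the frame_count == 3 assertion in
    # the bigbird repro below. With current dynamo, one break is expected to
    # yield two frames.
    def _sketch_explicit_graph_break_frames(self):
        def fn(x):
            x = x.sin()
            torch._dynamo.graph_break()
            return x.cos()

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        opt_fn(torch.randn(4))
        self.assertEqual(cnt.frame_count, 2)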

    def test_hf_bigbird_unsqueeze(self):
        def torch_bmm_nd(inp_1, inp_2, ndim=None):
            torch._dynamo.graph_break()
            return torch.bmm(inp_1, inp_2)

        def fn(inp1, inp2, inp3, inp4, c):
            a = torch_bmm_nd(inp1, inp2, 4)
            a.unsqueeze_(2)
            a = a * 2

            b = torch_bmm_nd(inp3, inp4, 4)
            b.unsqueeze_(2)
            l = a + b

            out = torch.cat([a, b, c], dim=2)
            return out, l

        inp1 = torch.rand(1, 64, 448)
        inp2 = torch.rand(1, 448, 64)
        inp3 = torch.rand(1, 64, 448)
        inp4 = torch.rand(1, 448, 64)
        c = torch.rand(1, 64, 1, 64)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        opt_fn(inp1, inp2, inp3, inp4, c)
        self.assertEqual(cnt.frame_count, 3)

    def test_torch_variable_type(self):
        # from torchvision
        def check_type(obj, types_or_checks):
            for type_or_check in types_or_checks:
                if (
                    isinstance(obj, type_or_check)
                    if isinstance(type_or_check, type)
                    else type_or_check(obj)
                ):
                    return True
            return False

        opt_check_type = torch._dynamo.optimize("eager")(check_type)
        ref = check_type(torch.randn(4), [torch.Tensor])
        res = opt_check_type(torch.randn(4), [torch.Tensor])
        self.assertEqual(ref, res)

    # Test for https://github.com/pytorch/pytorch/issues/103132
    @torch._dynamo.config.patch("assume_static_by_default", False)
    def test_inference_mode_dynamic_shapes(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, param):
                z = torch.matmul(param, param)
                return z

        model = Repro()
        # Need a 3d tensor to actually cause the error:
        # we go down a path of the C++ matmul decomp that calls sizes().
        inp = torch.randn(4, 4, 4, requires_grad=True)
        model = torch.compile(model, backend="aot_eager", dynamic=True)
        with torch.inference_mode():
            model(inp)

    def test_kwargs_out_list_variable(self):
        class Repro(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, param):
                z = torch.frexp(**param)
                return z

        model = Repro()
        params = {"input": torch.tensor([[0.0, 1, 2, 4]])}
        params["out"] = [
            torch.empty(0, dtype=torch.float32),  # mantissa
            torch.empty(0, dtype=torch.int32),  # exponent
        ]

        model = torch.compile(model, backend="eager")
        mantissa, exponent = model(params)
        ref_mantissa = torch.tensor([[0.0000, 0.5000, 0.5000, 0.5000]])
        ref_exponent = torch.tensor([[0, 1, 2, 3]], dtype=torch.int32)
        self.assertEqual(ref_mantissa, mantissa)
        self.assertEqual(ref_exponent, exponent)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_split_with_sizes_aot_autograd(self):
        def fn(result, split_sizes):
            rs = torch.ops.aten.split_with_sizes(result, split_sizes.tolist())
            return rs

        example_inputs = (
            torch.randn(32, requires_grad=True),
            torch.tensor((7, 16, 9)),
        )
        actual = torch.compile(fn, fullgraph=True, backend="aot_eager")(*example_inputs)
        expected = fn(*example_inputs)
        self.assertEqual(actual, expected)
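
    # Illustrative sketch (not part of the original suite; the method name is
    # made up and deliberately lacks the test_ prefix): the same tolist()
    # pattern as test_split_with_sizes_aot_autograd above, via the public
    # torch.split API; with capture_scalar_outputs=True the elements of the
    # size tensor are expected to become data-dependent symints instead of
    # causing graph breaks.
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def _sketch_split_sizes_from_tensor(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x, sizes):
            return torch.split(x, sizes.tolist())

        parts = fn(torch.randn(10), torch.tensor([4, 6]))
        self.assertEqual([p.numel() for p in parts], [4, 6])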
3971*da0073e9SAndroid Build Coastguard Worker When it's not a TorchVariable, dynamo tries to trace through and fails. 3972*da0073e9SAndroid Build Coastguard Worker This makes sure that the self.fn is handled as a TorchVariable. 3973*da0073e9SAndroid Build Coastguard Worker """ 3974*da0073e9SAndroid Build Coastguard Worker 3975*da0073e9SAndroid Build Coastguard Worker class UserModule(torch.nn.Module): 3976*da0073e9SAndroid Build Coastguard Worker torchdynamo_force_dynamic = True # forced to be a UnspecializedNNModule 3977*da0073e9SAndroid Build Coastguard Worker 3978*da0073e9SAndroid Build Coastguard Worker def __init__(self, fn): 3979*da0073e9SAndroid Build Coastguard Worker super().__init__() 3980*da0073e9SAndroid Build Coastguard Worker self.fn = fn 3981*da0073e9SAndroid Build Coastguard Worker 3982*da0073e9SAndroid Build Coastguard Worker def forward(self, **inp): 3983*da0073e9SAndroid Build Coastguard Worker return self.fn(**inp) 3984*da0073e9SAndroid Build Coastguard Worker 3985*da0073e9SAndroid Build Coastguard Worker inputs = { 3986*da0073e9SAndroid Build Coastguard Worker "input": torch.randn([2, 9]).uniform_(0, 1), 3987*da0073e9SAndroid Build Coastguard Worker "target": torch.randn([2, 9]).uniform_(0, 1), 3988*da0073e9SAndroid Build Coastguard Worker "reduction": "mean", 3989*da0073e9SAndroid Build Coastguard Worker } 3990*da0073e9SAndroid Build Coastguard Worker 3991*da0073e9SAndroid Build Coastguard Worker mod = UserModule(torch.nn.functional.binary_cross_entropy) 3992*da0073e9SAndroid Build Coastguard Worker ref = mod(**inputs) 3993*da0073e9SAndroid Build Coastguard Worker res = torch._dynamo.optimize("eager", nopython=True)(mod)(**inputs) 3994*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, res) 3995*da0073e9SAndroid Build Coastguard Worker 3996*da0073e9SAndroid Build Coastguard Worker def test_call_finally_python_3_8(self): 3997*da0073e9SAndroid Build Coastguard Worker # Issue - https://github.com/pytorch/pytorch/issues/97811 3998*da0073e9SAndroid Build Coastguard Worker def make_fn(g): 3999*da0073e9SAndroid Build Coastguard Worker def fn(): 4000*da0073e9SAndroid Build Coastguard Worker while True: 4001*da0073e9SAndroid Build Coastguard Worker try: 4002*da0073e9SAndroid Build Coastguard Worker print(g) 4003*da0073e9SAndroid Build Coastguard Worker break 4004*da0073e9SAndroid Build Coastguard Worker except Exception as _: 4005*da0073e9SAndroid Build Coastguard Worker break 4006*da0073e9SAndroid Build Coastguard Worker 4007*da0073e9SAndroid Build Coastguard Worker return torch.compile(fn, backend="eager") 4008*da0073e9SAndroid Build Coastguard Worker 4009*da0073e9SAndroid Build Coastguard Worker make_fn(None)() 4010*da0073e9SAndroid Build Coastguard Worker 4011*da0073e9SAndroid Build Coastguard Worker def test_call_finally_python_3_8_2(self): 4012*da0073e9SAndroid Build Coastguard Worker def f(x): 4013*da0073e9SAndroid Build Coastguard Worker while x: 4014*da0073e9SAndroid Build Coastguard Worker try: 4015*da0073e9SAndroid Build Coastguard Worker pass 4016*da0073e9SAndroid Build Coastguard Worker except Exception as _: 4017*da0073e9SAndroid Build Coastguard Worker continue 4018*da0073e9SAndroid Build Coastguard Worker 4019*da0073e9SAndroid Build Coastguard Worker torch.compile(f, backend="eager")(0) 4020*da0073e9SAndroid Build Coastguard Worker 4021*da0073e9SAndroid Build Coastguard Worker def test_call_finally_opcode_python_3_8(self): 4022*da0073e9SAndroid Build Coastguard Worker def fn(): 4023*da0073e9SAndroid Build Coastguard Worker try: 4024*da0073e9SAndroid 

    def test_call_finally_opcode_python_3_8(self):
        def fn():
            try:
                return torch.zeros(4)
            finally:
                return torch.ones(4)  # noqa: SIM107, B012

        result = torch.compile(fn, backend="aot_eager")()
        self.assertEqual(result, torch.ones(4))

    def test_string_format(self):
        s = "temp{i}"

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            if s.format(i=4) == "temp4":
                return torch.sin(x)
            return torch.cos(x)

        x = torch.randn(4)
        self.assertEqual(fn(x), torch.sin(x))

    # Repro of torch._dynamo.exc.InternalTorchDynamoError: 'NoneType' object has no attribute 'guards'
    # due to bad empty list handling
    def test_empty_list_contains_with_jump(self):
        def fn(x, l):
            if x in l:
                return x.cos()
            return x.sin()

        counter = CompileCounter()
        compiled_fn = torch._dynamo.optimize(counter)(fn)(torch.randn([2, 2]), [])
        self.assertEqual(counter.frame_count, 1)

    def test_graph_break_on_jit_isinstance(self):
        @torch.compile(backend="eager")
        def fn(x):
            if torch.jit.isinstance(x, List[str]):
                return x * 2
            return x

        opt_fn = torch.compile(fn, backend="eager")
        x = torch.rand(4)
        self.assertTrue(same(fn(x), opt_fn(x)))

    def test_add_sub_alpha_out(self):
        inp = torch.randn(2, 3, 4)
        other = 1
        alpha = 2
        for op in [torch.add, torch.sub]:
            out = torch.zeros(2, 3, 4)
            compile_out = torch.zeros(2, 3, 4)
            op(inp, other, alpha=alpha, out=out)
            compiled_fn = torch.compile(op, dynamic=True)
            compiled_fn(inp, other, alpha=alpha, out=compile_out)
            self.assertTrue(same(out, compile_out))

    def test_negative_shape_guard(self):
        def fn(x):
            if x.size() != (5, 1, 2, 3):
                return x.cos()
            return x.sin()

        counter = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=counter, dynamic=True)

        x = torch.ones(5, 1, 3, 4)
        x2 = torch.ones(5, 1, 2, 3)
        self.assertEqual(fn(x), opt_fn(x))
        self.assertEqual(fn(x2), opt_fn(x2))
        self.assertEqual(counter.frame_count, 2)

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_deferred_runtime_asserts(self):
        @torch.compile(fullgraph=True)
        def f(x):
            y = x.item()
            torch._check_is_size(y)
            if y >= 0:
                return x * 2
            else:
                return x * 3

        f(torch.tensor([3]))
        self.assertRaises(RuntimeError, lambda: f(torch.tensor([-2])))

    def test_addr_alpha_beta_out(self):
        inp = torch.randn(2, 3)
        vec1 = torch.randn(2)
        vec2 = torch.randn(3)
        alpha = 2
        beta = 5

        out = torch.zeros(2, 3)
        compile_out = torch.zeros(2, 3)

        torch.addr(inp, vec1, vec2, alpha=alpha, beta=beta, out=out)
        compiled_fn = torch.compile(torch.addr, dynamic=True)
        compiled_fn(inp, vec1, vec2, alpha=alpha, beta=beta, out=compile_out)
        self.assertTrue(same(out, compile_out))
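
    # Illustrative sketch (not part of the original suite; the method name is
    # made up and deliberately lacks the test_ prefix): the same eager-vs-
    # compiled out= comparison as the two tests above, reduced to a unary op
    # compiled directly.
    def _sketch_unary_out(self):
        inp = torch.randn(2, 3)

        out = torch.zeros(2, 3)
        compile_out = torch.zeros(2, 3)

        torch.sin(inp, out=out)
        compiled_fn = torch.compile(torch.sin, dynamic=True)
        compiled_fn(inp, out=compile_out)
        self.assertTrue(same(out, compile_out))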

    def test_setattr_requires_grad_graph_breaks(self):
        def fn(x):
            z = x + 4
            x.requires_grad = True
            y = x * z
            return y

        for backend in ["count", "eager", "aot_eager"]:
            if backend == "count":
                backend = CompileCounter()
            opt_fn = torch.compile(fn, backend=backend)

            eager = torch.zeros(5)
            compiled = eager.clone()

            out_eager = fn(eager)
            out_opt = opt_fn(compiled)

            self.assertEqual(out_eager, out_opt)

            out_eager.sum().backward()
            out_opt.sum().backward()

            self.assertEqual(eager, compiled)
            if isinstance(backend, CompileCounter):
                self.assertEqual(backend.frame_count, 2)  # graph breaks

    def test_dynamic_shapes_double_not_equal(self):
        # https://github.com/pytorch/pytorch/issues/113393
        def fn(x):
            if x.size() != (5, 1, 2, 3):
                return x.cos()
            return x.sin()

        opt_fn = torch.compile(fn, backend="eager")

        x = torch.ones(5, 1, 2, 3)
        x2 = torch.ones(5, 1, 3, 4)
        self.assertEqual(fn(x), opt_fn(x))
        self.assertEqual(fn(x2), opt_fn(x2))

    def test_inductor_no_recursionerror_on_for_loops(self):
        def forward(x):
            for _ in range(1000):
                x = 1.0 * x
            return x

        self.assertTrue(
            same(torch.compile(forward)(torch.tensor([1.0])), torch.tensor([1.0]))
        )
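
    # Illustrative sketch (not part of the original suite; the method name is
    # made up and deliberately lacks the test_ prefix): dynamo unrolls plain
    # Python loops over constant ranges, which is why the 1000-iteration loop
    # above compiles without recursion; a small loop plus a frame count makes
    # the single-graph behavior visible.
    def _sketch_constant_loop_single_frame(self):
        def forward(x):
            for _ in range(10):
                x = x + 1
            return x

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(forward, backend=cnt)
        self.assertTrue(same(opt_fn(torch.zeros(3)), torch.full((3,), 10.0)))
        self.assertEqual(cnt.frame_count, 1)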

    def test_user_defined_object_callable(self):
        # https://github.com/pytorch/pytorch/issues/114019
        class MyCallable:
            def __call__(self, x):
                return x + 1

        def fn(x):
            # Create in graph - will not have source
            return MyCallable()(x)

        fn_opt = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn_opt(torch.zeros(1)), fn(torch.zeros(1)))

    @torch._dynamo.config.patch(log_compilation_metrics=True)
    def test_many_views_with_mutation(self):
        # When symbolic storage offsets were added in #113734, tensors_definitely_do_not_overlap
        # began adding shape guards - a quadratic amount relative to the number of inputs.
        # Test this configuration, and test that a reasonable number of guards are added.
        # Note, when dynamic shapes are turned on, this test fails and we still get quadratic guards.
        def fn(x):
            x[0].relu_()
            return torch.cat(x).sum()

        AMT = 32
        src = torch.rand(16 * (AMT + 1))

        x = [src.as_strided((4, 4), (4, 1), 3 + 16 * i) for i in range(AMT)]

        torch._dynamo.reset()
        torch._dynamo.utils.clear_compilation_metrics()

        res = torch.compile(fn, backend="aot_eager")(x)

        all_metrics = torch._dynamo.utils.get_compilation_metrics()

        total_guards = sum(metric.guard_count for metric in all_metrics)
        self.assertLess(total_guards, AMT * 8)

        total_shape_env_guards = sum(
            metric.shape_env_guard_count for metric in all_metrics
        )
        self.assertLess(total_shape_env_guards, AMT * 8)
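
    # Illustrative sketch (not part of the original suite; the method name is
    # made up and deliberately lacks the test_ prefix):
    # torch._dynamo.allow_in_graph, used by the subclass repro below, makes
    # dynamo emit a call to the decorated function instead of tracing its
    # body, so its internals cannot cause a graph break.
    def _sketch_allow_in_graph_no_break(self):
        @torch._dynamo.allow_in_graph
        def opaque(x):
            return x + 1

        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x):
            return opaque(x).sin()

        self.assertTrue(same(fn(torch.zeros(3)), torch.ones(3).sin()))
        self.assertEqual(cnt.frame_count, 1)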

    # https://github.com/pytorch/pytorch/issues/118799
    def test_subclass_graph_output_repro(self):
        @torch._dynamo.allow_in_graph
        def to_subclass(x):
            return TwoTensor(x.clone(), x.clone())

        def f(x):
            tmp_subclass = to_subclass(x)
            return tmp_subclass.view(-1)

        x = torch.ones(2)
        out_ref = f(x)
        out_test = torch.compile(f, backend="aot_eager")(x)
        self.assertEqual(out_ref, out_test)

    def test_numpy_tobytes_no_error(self):
        def fn(x):
            x += 1
            z = x.tobytes()
            x += 1
            return z

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        opt_arg, arg = np.array([1, 2]), np.array([1, 2])
        self.assertEqual(opt_fn(opt_arg), fn(arg))
        self.assertEqual(cnt.frame_count, 2)

    def test_numpy_not_ndarray_recompiles(self):
        import torch

        def fn(x=None):
            if x is None:
                x = np.ones(3)
            elif isinstance(x, int):
                x = np.ones(6)
            elif isinstance(x, str):
                x = np.ones(9)
            return x**2

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)

        x = np.zeros((2, 2))

        self.assertEqual(opt_fn(x), fn(x))
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(opt_fn(), fn())
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(opt_fn(10), fn(10))
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(opt_fn("10"), fn("10"))
        self.assertEqual(cnt.frame_count, 4)
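
    # NOTE: each call above reaches fn with a different Python type
    # (ndarray, None, int, str), and dynamo guards on argument types, which
    # is why frame_count climbs from 1 to 4 across the four calls.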

    @parametrize(
        "backend",
        ["eager", "aot_eager", "inductor"],
    )
    @parametrize(
        "func_name",
        ["func1", "func2", "func3"],
    )
    def test_tensor_set_data(self, backend, func_name):
        # https://github.com/pytorch/pytorch/issues/113030
        def func1(x, y):
            x.data = y
            x.add_(1)
            return x

        def func2(x, y):
            x.data = y
            y.data = torch.zeros([0])
            return x

        def func3(x, y):
            z = x
            x.data = y
            y.data = torch.zeros([0])
            return torch.tensor(x is z)

        funcs = {"func1": func1, "func2": func2, "func3": func3}
        func = funcs[func_name]

        if backend != "eager" and func is func1:
            # add_ not working w/ aot_autograd?
            return

        torch._dynamo.reset()
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

        compiled_fn = torch.compile(func, backend=cnt, fullgraph=True)
        requires_grad = func is not func1
        for i in range(0, 5):
            # Inputs
            eager_a = torch.ones([6], requires_grad=requires_grad)
            compiled_a = torch.ones([6], requires_grad=requires_grad)

            eager_b = torch.ones([6], requires_grad=requires_grad)
            compiled_b = torch.ones([6], requires_grad=requires_grad)

            # Eager
            out_eager = func(eager_a, eager_b)
            # Compiled
            out_compiled = compiled_fn(compiled_a, compiled_b)
            self.assertEqual(eager_a, compiled_a)
            self.assertEqual(eager_b, compiled_b)
            self.assertTrue(torch.equal(out_eager, out_compiled))

            # func1 would fail here with "a leaf Variable that requires grad
            # is being used in an in-place operation", so grads are only
            # checked for func2/func3
            if requires_grad:
                bwd_inp_eager = torch.randn([6])
                bwd_inp_compiled = torch.clone(bwd_inp_eager)
                eager_a.backward(bwd_inp_eager)
                compiled_a.backward(bwd_inp_compiled)
                self.assertEqual(eager_a.grad, compiled_a.grad)

        # Prove guarding works - we run the compiled_fn 5 times
        # frame_count should stay at 1.
        self.assertEqual(cnt.frame_count, 1)
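
    # NOTE: `x.data = y` swaps a tensor's underlying data without autograd
    # tracking, which dynamo has to model explicitly; the five fresh-input
    # runs above also prove the installed guards are stable, since
    # frame_count never moves past 1.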

    @unittest.skipIf(
        TEST_WITH_ROCM or not PLATFORM_SUPPORTS_FLASH_ATTENTION,
        "flash attention not supported",
    )
    def test_flash_attn_backward_mixed_strides(self):
        # in this repro, "grad_out" and "value" are transposed tensors,
        # but "query" and "key" are contiguous
        def gen_inputs(device):
            return (
                torch.randn(
                    2, 513, 16, 64, dtype=torch.float16, device=device
                ).transpose(1, 2),
                torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
                torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
                torch.randn(
                    2, 513, 16, 64, dtype=torch.float16, device=device
                ).transpose(1, 2),
                torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
                torch.randn(2, 16, 513, device=device),
                None,
                None,
                513,
                513,
                0.0,
                False,
                torch.tensor(1, dtype=torch.int64),
                torch.tensor(1, dtype=torch.int64),
            )

        inps_cuda = gen_inputs("cuda")
        inps_meta = gen_inputs("meta")
        (
            out1_ref,
            out2_ref,
            out3_ref,
        ) = torch.ops.aten._scaled_dot_product_flash_attention_backward(
            *inps_cuda, scale=0.125
        )
        from torch._meta_registrations import meta__scaled_dot_product_flash_backward

        out1_test, out2_test, out3_test = meta__scaled_dot_product_flash_backward(
            *inps_meta, scale=0.125
        )

        self.assertEqual(out1_ref.shape, out1_test.shape)
        self.assertEqual(out1_ref.stride(), out1_test.stride())
        self.assertEqual(out2_ref.shape, out2_test.shape)
        self.assertEqual(out2_ref.stride(), out2_test.stride())
        self.assertEqual(out3_ref.shape, out3_test.shape)
        self.assertEqual(out3_ref.stride(), out3_test.stride())
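
    # NOTE: the meta registration is checked against the real CUDA kernel
    # on stride, not just shape; with transposed grad_out/value inputs this
    # is exactly where a meta kernel that fabricates contiguous outputs
    # would diverge.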

    def test_user_ctor_ctx_manager(self):
        class UserCtxManager:
            def __enter__(self):
                return 1

            def __exit__(self, exc_type, exc_val, exc_tb):
                pass

        def fn(x, y):
            ucm = UserCtxManager()
            return x * x

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
        x = torch.rand([2, 2])
        opt_fn(x, x)
        self.assertExpectedInline(cnt.frame_count, """1""")

    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_unbacked_arange_in_bounds(self):
        # see https://github.com/pytorch/pytorch/issues/113002
        class PaddingNet(nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, lengths):
                max_seq_len = lengths.max().item()
                row_vector = torch.arange(0, max_seq_len, 1)
                matrix = torch.unsqueeze(lengths, dim=-1)
                mask = row_vector < matrix
                mask = mask.type(torch.float32)
                mask_3d_btd = mask[:, :, None]
                return mask_3d_btd

        model = PaddingNet()
        lengths = torch.tensor([5, 4, 4, 4], dtype=torch.int32)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(model)
        opt_fn(lengths)
        self.assertEqual(cnt.frame_count, 1)
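
    # NOTE: with capture_scalar_outputs=True, `lengths.max().item()` becomes
    # an unbacked SymInt, and the data-dependent arange above must still
    # compile as a single frame instead of failing on a guard that cannot
    # be evaluated.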

    def test_overlapping_inputs_with_dynamic_shapes_error(self):
        @torch.compile(backend="aot_eager")
        def fn(a, b, c, d, e, f):
            a.mul_(2)
            b.mul_(2)
            c.mul_(2)
            d.mul_(2)
            e.mul_(2)
            f.mul_(2)

        base = torch.ones(2, 20)
        a = base[:, 0:2]
        b = base[:, 2:4]
        c = base[:, 4:6]
        d = base[:, 6:8]
        e = base[:, 8:10]
        f = base[:, 10:12]
        f2 = base[:, 10:14]
        out = fn(a, b, c, d, e, f)
        with self.assertRaisesRegex(
            AssertionError, "is being compiled with dynamic shapes"
        ):
            out2 = fn(a, b, c, d, e, f2)
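
    # NOTE: all six inputs alias slices of the same base storage, and the
    # second call changes one slice's size, which marks the inputs dynamic;
    # overlap checking for mutated aliased inputs is not supported under
    # dynamic shapes, so the compile is expected to fail loudly rather than
    # risk miscompiling the aliasing.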

    def test_user_ctor_ctx_manager_custom_init(self):
        class UserCtxManager:
            def __init__(self, x):
                x[0] = 10

            def __enter__(self):
                return 1

            def __exit__(self, exc_type, exc_val, exc_tb):
                pass

        def fn(x, y):
            ucm = UserCtxManager(y)
            return x * y[0]

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
        x = torch.rand([2, 2])
        self.assertEqual(opt_fn(x, [5]), fn(x, [5]))
        self.assertExpectedInline(cnt.frame_count, """1""")

    def test_user_ctor_ctx_manager_custom_init_graph_break(self):
        counter = [0]

        class UserCtxManager:
            def __init__(self, k):
                k[0] += 1

            def __enter__(self):
                return 1

            def __exit__(self, exc_type, exc_val, exc_tb):
                pass

        def fn(x, counter):
            x = x * x
            ucm = UserCtxManager(counter)
            return x * x

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        x = torch.rand([2, 2])
        self.assertEqual(opt_fn(x, counter), fn(x, counter))
        self.assertEqual(counter[0], 2)
        for i in range(0, 10):
            opt_fn(x, counter)
        self.assertEqual(counter[0], 12)
        if torch._dynamo.config.assume_static_by_default:
            self.assertExpectedInline(cnt.frame_count, """2""")
        else:
            self.assertExpectedInline(cnt.frame_count, """1""")
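
    # NOTE: the side-effecting __init__ still runs exactly once per call
    # even with the graph break: one bump each from the first opt_fn/fn
    # pair, then ten more from the loop, giving counter[0] == 12.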

    @unittest.expectedFailure
    def test_many_overlapping_inputs_does_not_explode_guards(self):
        from torch._dynamo.backends.common import aot_autograd

        # Before, this was (9702, 0)
        num_shape_guards = None
        num_aot_guards = None
        num_compiles = 0

        def guard_count_backend(gm, *args):
            nonlocal num_shape_guards
            nonlocal num_aot_guards
            nonlocal num_compiles
            num_shape_guards = len(
                torch._guards.TracingContext.try_get().fake_mode.shape_env.guards
            )
            num_aot_guards = len(
                torch._guards.TracingContext.try_get().guards_context.aotautograd_guards
            )
            num_compiles += 1
            return gm

        aot_guard_counter = aot_autograd(fw_compiler=guard_count_backend)

        @torch.compile(backend=aot_guard_counter, dynamic=True)
        def f(*args):
            for a in args:
                a.add_(1)

        x = torch.ones(1000, requires_grad=True)
        args = x.split(10)

        with torch.no_grad():
            f(*args)
        # In this example, there were 4950 guards (roughly (# tensors) ^ 2 // 2),
        # because every pair of aliased inputs needs a guard.
        self.assertTrue(num_aot_guards < 5000)
        # But there are no dynamic shape guards.
        self.assertEqual(num_shape_guards, 0)
        # don't recompile
        with torch.no_grad():
            f(*args)
        self.assertEqual(num_compiles, 1)

    def test_invalid_seq_unpack(self):
        def myfn(arg):
            (a, b) = arg

        def fn():
            return myfn((1, 2, 3))

        try:
            torch.compile(fn)()
        except ValueError:
            pass
        else:
            self.fail("expected exception")

    def test_megablocks_moe(self):
        try:
            from megablocks.layers import moe
            from megablocks.layers.arguments import Arguments
        except ImportError as e:
            raise unittest.SkipTest("requires megablocks") from e
        bs, sl, hs, num_experts, top_k = (16, 1024, 512, 1, 1)
        args = Arguments(
            hidden_size=hs,
            ffn_hidden_size=hs * 2,
            moe_num_experts=num_experts,
            moe_capacity_factor=1,
            moe_top_k=top_k,
        )
        moe_mlp = moe.MoE(args)
        moe_mlp.cuda(torch.cuda.current_device()).half()
        x = torch.randn(sl, bs, hs).cuda().half()
        out1, _ = moe_mlp(x)
        out2, _ = torch.compile(moe_mlp, backend="eager")(x)
        self.assertEqual(out1, out2)

    def test_udf_classes_reconstruction(self):
        def fn(x):
            o = T(5)
            return o.x + x

        opt_fn = torch.compile(fn, backend="eager")
        T = IncByOne

        x = torch.randn(4)
        self.assertEqual(fn(x), opt_fn(x))

        # This should recompile
        T = IncByTwo
        self.assertEqual(fn(x), opt_fn(x))

    def test_contains_range_constprop(self):
        def fn(x):
            # dynamo should const prop `3 in range(0, 10)` to True
            if 3 in range(0, 10):
                return x + 1
            else:
                return x + 2

        opt_fn = torch.compile(fn, backend="eager")
        x = torch.zeros(4)
        self.assertEqual(fn(x), opt_fn(x))

    # https://github.com/pytorch/pytorch/issues/104505
    def test_as_strided_on_base_with_mutation_works(self):
        def foo(a):
            f = a.as_strided((2,), (1,), 0)
            f.add_(1.0)
            return a

        a = torch.randn(2, 4)
        a_ref = a.clone()
        out_ref = foo(a_ref)
        f_compiled = torch.compile(foo, backend="aot_eager")
        out = f_compiled(a)
        self.assertEqual(out_ref, out)
        self.assertEqual(a_ref, a)
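
    # NOTE: the mutation above is allowed because as_strided was taken
    # directly on a graph input, so functionalization can replay it against
    # the base tensor; the companion test below shows the same pattern is
    # rejected when as_strided is layered on top of another view.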

    # https://github.com/pytorch/pytorch/issues/104505
    def test_as_strided_on_existing_view_banned(self):
        def foo(a):
            e = a.diagonal()
            f = e.as_strided((2,), (1,), 0)
            f.add_(1.0)
            return a

        a = torch.randn(2, 4)
        a_ref = a.clone()
        out_ref = foo(a_ref)
        f_compiled = torch.compile(foo, backend="aot_eager")
        with self.assertRaisesRegex(
            RuntimeError,
            "encountered a mutation on a view chain of length 2, where view 1 was an as_strided",
        ):
            out = f_compiled(a)
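
    # NOTE: here the as_strided sits on top of a diagonal() view; a
    # mutation through that two-level chain cannot be replayed faithfully
    # by functionalization, hence the RuntimeError instead of silently
    # wrong results.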

    def test_dont_aggressively_write_assert(self):
        record_graph = torch._dynamo.testing.EagerAndRecordGraphs()

        @torch.compile(dynamic=True, backend=record_graph)
        def f(x):
            assert x.shape[0] > 3
            assert x[0].sum() > 0
            assert 1 % (x.shape[0] // 2) != 0
            assert 32 * (x.shape[0] // 2) ** 2 - 16 * (x.shape[0] // 2) != 0
            return x.cos()

        f(torch.ones(6, 4))
        graph = record_graph.graphs[0]
        # It is a bit annoying that we generate useless statements for
        # shape guards, but DCE should be able to remove them, since
        # there is no backed assert on them. The reason this is ok is
        # that dynamo will only skip the assert statement, but not
        # the instructions before it.
        self.assertExpectedInline(
            str(graph.code).strip(),
            """\
def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
    l_x_ = L_x_
    getitem_2 = l_x_[0]
    sum_1 = getitem_2.sum();  getitem_2 = None
    gt_1 = sum_1 > 0;  sum_1 = None
    _assert_async = torch._assert_async(gt_1, 'assertion error');  gt_1 = _assert_async = None
    cos = l_x_.cos();  l_x_ = None
    return (cos,)""",
        )
        for node in graph.graph.nodes:
            if "example_value" in node.meta and isinstance(
                node.meta["example_value"], torch._subclasses.fake_tensor.FakeTensor
            ):
                shape_env = node.meta["example_value"].fake_mode.shape_env
                lower_ranges = [val.lower for val in shape_env.var_to_range.values()]
                self.assertTrue(lower_ranges == [4, 2])

        @torch.compile(dynamic=True, backend=record_graph)
        def f_fail(x):
            assert x.shape[0] < 3

        # We graph-break here, so the failure should be eager
        with self.assertRaisesRegex(AssertionError, ""):
            f_fail(torch.ones(6, 4))

    def test_detectron2_instances_cat(self):
        class Instances:
            def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
                self._image_size = image_size
                self._fields: Dict[str, Any] = {}
                for k, v in kwargs.items():
                    self.set(k, v)

            @property
            def image_size(self) -> Tuple[int, int]:
                return self._image_size

            def __setattr__(self, name: str, val: Any) -> None:
                if name.startswith("_"):
                    super().__setattr__(name, val)
                else:
                    self.set(name, val)

            def __getattr__(self, name: str) -> Any:
                if name == "_fields" or name not in self._fields:
                    raise AttributeError(
                        f"Cannot find field '{name}' in the given Instances!"
                    )
                return self._fields[name]

            def __len__(self) -> int:
                for v in self._fields.values():
                    # use __len__ because len() has to be int and is not friendly to tracing
                    return v.__len__()
                raise NotImplementedError("Empty Instances does not support __len__!")

            def set(self, name: str, value: Any) -> None:
                with warnings.catch_warnings(record=True):
                    data_len = len(value)
                if len(self._fields):
                    assert (
                        len(self) == data_len
                    ), f"Adding a field of length {data_len} to an Instances of length {len(self)}"
                self._fields[name] = value

            def get(self, name: str) -> Any:
                return self._fields[name]

            @staticmethod
            def cat(instance_lists: List["Instances"]) -> "Instances":
                assert all(isinstance(i, Instances) for i in instance_lists)
                assert len(instance_lists) > 0
                if len(instance_lists) == 1:
                    return instance_lists[0]

                image_size = instance_lists[0].image_size
                if not isinstance(
                    image_size, torch.Tensor
                ):  # could be a tensor in tracing
                    for i in instance_lists[1:]:
                        assert i.image_size == image_size
                ret = Instances(image_size)
                for k in instance_lists[0]._fields.keys():
                    values = [i.get(k) for i in instance_lists]
                    v0 = values[0]
                    if isinstance(v0, torch.Tensor):
                        values = torch.cat(values, dim=0)
                    elif isinstance(v0, list):
                        values = list(itertools.chain(*values))
                    elif hasattr(type(v0), "cat"):
                        values = type(v0).cat(values)
                    else:
                        raise ValueError(
                            f"Unsupported type {type(v0)} for concatenation"
                        )
                    ret.set(k, values)
                return ret

        instances = [
            Instances((16, 16), a=torch.randn(16, 16), b=torch.randn(16, 16))
            for _ in range(3)
        ]

        @torch.compile(backend="eager", fullgraph=True)
        def fn(instances):
            return instances[0].cat(instances)

        actual = fn(instances)
        expected = instances[0].cat(instances)
        self.assertEqual(type(actual), type(expected))
        self.assertEqual(actual.__dict__, expected.__dict__)
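
    # NOTE: Instances routes attribute access through custom __setattr__ /
    # __getattr__ and concatenates via a staticmethod, so this exercises
    # tracing a fairly dynamic user-defined class end to end under
    # fullgraph=True, including reconstructing an equivalent object.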

    def test_weakref(self):
        def fn(x_weak, weight, y):
            if x_weak is not None and x_weak() is not weight:
                return torch.sin(y)
            return torch.cos(y)

        weight = torch.randn(4)
        y = torch.randn(4)
        x_weak = weakref.ref(weight)

        ref = fn(x_weak, weight, y)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x_weak, weight, y)
        self.assertEqual(ref, res)

    def test_weakref_reconstruct(self):
        def fn(x_weak, weight, y):
            y = torch.sin(y)
            referent = x_weak()
            torch._dynamo.graph_break()
            if referent is not weight:
                return torch.sin(y)
            return torch.cos(y)

        weight = torch.randn(4)
        y = torch.randn(4)
        x_weak = weakref.ref(weight)

        ref = fn(x_weak, weight, y)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnt)
        res = opt_fn(x_weak, weight, y)
        self.assertEqual(ref, res)
        self.assertEqual(cnt.frame_count, 2)
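
    # NOTE: dereferencing a weakref inside the compiled region is supported
    # and guarded on the referent; in the reconstruct variant, the explicit
    # graph break means the weakref has to be rebuilt and handed to a
    # second frame, which is why frame_count lands at 2.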

    def test_weakref_del(self):
        def fn(x_weak, y):
            x = x_weak()
            if x is not None:
                return torch.sin(y)
            return torch.cos(y)

        weight = torch.randn(4)
        x_weak = weakref.ref(weight)
        y = torch.randn(4)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        ref = fn(x_weak, y)
        res = opt_fn(x_weak, y)
        self.assertEqual(ref, res)

        del weight
        gc.collect()
        ref = fn(x_weak, y)
        res = opt_fn(x_weak, y)
        self.assertEqual(ref, res)

    # @torch._functorch.config.patch(
    #     recompute_views=True,
    # )
    # def test_storage_resize_forward_full_graph(self):
    #     class TestModule(torch.nn.Module):
    #         def __init__(self) -> None:
    #             super().__init__()
    #             self.param = torch.nn.Parameter(torch.randn(4, 4))

    #         def forward(self, x):
    #             self.param.untyped_storage().resize_(
    #                 self.param.numel() * self.param.itemsize
    #             )
    #             with torch.no_grad():
    #                 torch._foreach_copy_([self.param], [x])
    #             out = torch.matmul(self.param, self.param)
    #             self.param.untyped_storage().resize_(0)
    #             return out

    #     def post_accumulate_grad_hook(param):
    #         param.untyped_storage().resize_(0)

    #     # Beginning of backward, resize and put data into the param
    #     def pre_backward_hook(module, grad) -> None:
    #         module.param.untyped_storage().resize_(
    #             self.param.numel() * self.param.itemsize
    #         )
    #         with torch.no_grad():
    #             # simulates loading data into param from allgather
    #             module.param.fill_(2)

    #     def post_forward_hook(module, args, output):
    #         output.register_hook(functools.partial(pre_backward_hook, module))

    #     x = torch.randn(4, 4)

    #     mod_ref = TestModule()
    #     mod_test = deepcopy(mod_ref)

    #     # Start the param off with zero storage size to mimic fsdp
    #     mod_ref.param.untyped_storage().resize_(0)
    #     mod_test.param.untyped_storage().resize_(0)

    #     # Resize storage at beginning of backward
    #     # Free storage at end of backward
    #     mod_ref.register_forward_hook(post_forward_hook, prepend=False)
    #     mod_ref.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)
    #     mod_test.register_forward_hook(post_forward_hook, prepend=False)
    #     mod_test.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)

    #     mod_test = torch.compile(mod_test, backend=aot_graph_capture_backend)

    #     out_ref = mod_ref(x)
    #     out_test = mod_test(x)
    #     self.assertExpectedInline(
    #         str(fw_graph[0].code.strip()),
    #         """\
    # def forward(self, primals_1, primals_2):
    #     _foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]);  primals_1 = primals_2 = None
    #     getitem = _foreach_copy[0];  _foreach_copy = None
    #     mm = torch.ops.aten.mm.default(getitem, getitem)
    #     return [mm, getitem]""",
    #     )
    #     self.assertEqual(out_ref, out_test)

    def test_super_in_staticmethod(self):
        class A:
            @staticmethod
            def foo():
                return super().__init__()

        def fn(obj):
            return obj.foo()

        obj = A()

        try:
            fn(obj)
        except Exception as e:
            orig_str = str(e)
        self.assertIn("no arguments", orig_str)

        try:
            torch.compile(backend="eager")(fn)(obj)
        except Exception as e:
            compiled_str = str(e)
        self.assertEqual(orig_str, compiled_str)

    def test_super_staticmethod(self):
        class Parent:
            @staticmethod
            def greet():
                return 5

        class Child(Parent):
            @staticmethod
            def greet(x):
                return x * super(Child, Child).greet()

        child = Child()

        def fn(x):
            return child.greet(x)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.ones(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_super_diamond(self):
        class A:
            def __init__(self):
                super().__init__()
                self.a = 5

        class Nothing:
            pass

        class B(Nothing, A):
            def __init__(self):
                super().__init__()
                self.b = 10

            def run(self, x):
                return self.a * self.b * x

        def fn(x):
            b = B()
            return b.run(x)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
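
    # NOTE: the three super() tests above cover zero-argument super() in a
    # staticmethod (whose eager "no arguments" error must be reproduced
    # verbatim), explicit super(Child, Child) from a staticmethod, and
    # __init__ chaining across B -> Nothing -> A; dynamo has to follow the
    # same MRO walk the interpreter would.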

    def test_vc_bumped_in_inference_graph(self):
        @torch.compile
        def f(x):
            return x.mul_(2)

        x = torch.randn(4)
        vc_before = x._version
        f(x)
        vc_after = x._version
        self.assertTrue(vc_after > vc_before)

    def test_nn_module_callable(self):
        class M(nn.Module):
            def forward(self, x):
                return x.sin()

        def f(m):
            return callable(m)

        res = torch.compile(f, fullgraph=True)(M())
        self.assertTrue(res)

    def test_stk_sdd_is_transposed(self):
        trigger_graph_break = False

        def _is_transposed(x):
            return (
                not x.is_contiguous()
                and x.stride()[0] == 1
                and x.stride()[1] == x.size()[0]
            )

        class SDD(torch.autograd.Function):
            @staticmethod
            def forward(ctx, lhs, rhs):
                ctx.save_for_backward(lhs, rhs)
                out = torch.full_like(lhs, 1.0, dtype=lhs.dtype, device=lhs.device)
                return out

            @staticmethod
            def backward(ctx, dy):
                saved_tensors = ctx.saved_tensors
                lhs, rhs = saved_tensors[:2]
                trans_a = _is_transposed(lhs)
                trans_b = _is_transposed(rhs)
                dlhs = None
                if ctx.needs_input_grad[0]:
                    dlhs = torch.full_like(lhs, 1.0 if trans_a else 2.0)
                drhs = None
                if ctx.needs_input_grad[1]:
                    drhs = torch.full_like(rhs, 1.0 if trans_b else 2.0)
                if trigger_graph_break:
                    if _is_transposed(dy):
                        return dlhs + 1, drhs + 1, None, None
                return dlhs, drhs, None, None

        x1 = torch.randn((8, 8), requires_grad=True)
        y1 = torch.randn((8, 8)).transpose(0, 1).requires_grad_(True)
        x2 = torch.randn((8, 8), requires_grad=True)
        y2 = torch.randn((8, 8)).transpose(0, 1).requires_grad_(True)

        SDD.apply(x1, y1).sum().backward()

        @torch.compile(backend="eager", fullgraph=True)
        def fn():
            return SDD.apply(x2, y2)

        fn().sum().backward()

        self.assertEqual(x1.grad, x2.grad)
        self.assertEqual(y1.grad, y2.grad)

        trigger_graph_break = True
        with self.assertRaises(torch._dynamo.exc.Unsupported):
            fn().sum().backward()
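
    # NOTE: compiling SDD.apply means dynamo also speculates the backward;
    # once trigger_graph_break makes backward branch on a property of dy,
    # that speculation fails and fullgraph compilation surfaces
    # torch._dynamo.exc.Unsupported instead of silently splitting the graph.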

    def test_partially_initialized_module_property(self):
        class Matrix(torch.nn.Module):
            def __init__(self, data):
                super().__init__()
                self._data = data
                self.foo = 10 * self.blocking

            @property
            def data(self):
                return self._data

            @property
            def blocking(self):
                return self.data.shape[1]

        @torch.compile(backend="eager", fullgraph=True)
        def fn():
            return Matrix(torch.randn(10, 20))

        v = fn()
        self.assertEqual(v.foo, 200)
        self.assertEqual(v.data.shape, (10, 20))
        self.assertEqual(type(v), Matrix)

    def test_classmethod_with_slots(self):
        class Mock:
            __slots__ = ("_a",)

            def __init__(self):
                self._a = 2

            @classmethod
            def _m(cls):
                return 3

            def run(self, x):
                return torch.sin(x) * self._a * self._m()

        def fn(x):
            mock = Mock()
            return mock.run(x)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(fn(x), opt_fn(x))

    def test_nn_parametrize(self):
        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.randn(10, 10))

            def forward(self, x):
                return self.param @ x

        class Parametrization(torch.nn.Module):
            def forward(self, x):
                return torch.sin(x)

        m = Module()
        torch.nn.utils.parametrize.register_parametrization(
            m, "param", Parametrization()
        )

        sin_found = False

        def backend(gm, _):
            nonlocal sin_found
            for node in gm.graph.nodes:
                if node.target is torch.sin:
                    sin_found = True
            return gm

        opt_m = torch.compile(m, backend=backend, fullgraph=True)
        inp = torch.randn(10, 10)
        self.assertEqual(m(inp), opt_m(inp))
        self.assertTrue(sin_found)

        torch.nn.utils.parametrize.remove_parametrizations(m, "param")
        sin_found = False
        self.assertEqual(m(inp), opt_m(inp))
        self.assertFalse(sin_found)

    def test_nn_module_property_closure(self):
        x = torch.randn(10, 10)

        class Mod(torch.nn.Module):
            @property
            def y(self):
                return torch.ones(10, 10) + x

            def forward(self, x):
                return x @ self.y

        mod = Mod()

        def fn(x):
            return mod(x)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        inp = torch.randn(10, 10)
        self.assertEqual(fn(inp), opt_fn(inp))

    def test_global_fn_mutation(self):
        def foo(x, y):
            return global_fn(x) + y

        x = torch.ones(1)
        y = torch.ones(1)

        opt = torch.compile(foo, fullgraph=True, backend="eager")
        self.assertEqual(opt(x, y), foo(x, y))

        # Change global_fn
        global global_fn

        def new_fn(x):
            return torch.cos(x)

        global_fn = new_fn
        self.assertEqual(opt(x, y), foo(x, y))
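
    # Sketch of the guard behavior behind test_global_fn_mutation, assuming
    # CompileCounter.frame_count increments once per (re)compiled frame (not
    # collected by unittest: no "test_" prefix). Rebinding a global that the
    # compiled frame reads invalidates an identity guard and forces a recompile
    # instead of returning stale results:
    def _sketch_global_guard_recompile(self):
        global global_fn
        cnt = CompileCounter()

        @torch.compile(backend=cnt)
        def foo(x):
            return global_fn(x)

        global_fn = torch.sin
        foo(torch.ones(1))
        first = cnt.frame_count
        global_fn = torch.cos
        foo(torch.ones(1))
        self.assertTrue(cnt.frame_count > first)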

    # ref https://github.com/pytorch/pytorch/issues/123974
    def test_list_reverse(self):
        def ladder(x):
            trail = x.size(-1)
            assert trail > 2
            weights = []
            for s in [trail, trail - 1, trail - 2]:
                weights.append(torch.ones(s, s - 1))

            for w in weights:
                x = x @ w

            weights.reverse()

            for w in weights:
                x = x @ w.t()

            return x

        data = torch.randn(3, 4)
        opt_ladder = torch.compile(ladder, fullgraph=True, backend="eager")
        self.assertEqual(opt_ladder(data), ladder(data))

    def test_trace_functional_tensor_with(self):
        from torch._subclasses.fake_tensor import FakeTensorMode
        from torch._subclasses.functional_tensor import (
            FunctionalTensor,
            FunctionalTensorMode,
        )

        def f(a, tmp):
            a_view = a.view(-1)
            with torch.no_grad():
                a.set_(tmp)
                a_view.mul_(2)
            return a + tmp

        fake_mode = FakeTensorMode()
        with FunctionalTensorMode():
            inp = torch.ones(3, 3, requires_grad=True)
            inp = fake_mode.from_tensor(inp, static_shapes=True)
            inp = FunctionalTensor.to_functional(inp)

            tmp = torch.ones(3, 3, requires_grad=True)
            tmp = fake_mode.from_tensor(tmp, static_shapes=True)
            tmp = FunctionalTensor.to_functional(tmp)

            opt_f = torch.compile(f, backend="eager")
            with self.assertRaisesRegex(
                RuntimeError, "cannot mutate tensors with frozen storage"
            ):
                opt_f(inp, tmp)
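
    # Sketch of what the FunctionalTensor machinery above does, using the
    # public torch.func.functionalize wrapper (not collected by unittest: no
    # "test_" prefix). Functionalization rewrites in-place mutations into pure
    # ops, which is why mutating a frozen-storage functional tensor is rejected
    # outright rather than traced incorrectly:
    def _sketch_functionalize_removes_mutation(self):
        def f(x):
            y = x.clone()
            y.mul_(2)  # rewritten to an out-of-place mul during functionalization
            return y

        x = torch.ones(3)
        self.assertEqual(torch.func.functionalize(f)(x), f(x))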

    def test_const_dict_keyerror(self):
        d = {}

        def fn(x):
            try:
                y = d[0]
            except KeyError:
                y = 1
            return x + y

        opt_fn = torch.compile(fn, backend="eager")
        inp = torch.randn(3, 3)
        self.assertEqual(fn(inp), opt_fn(inp))

    def test_dict_tag_guard(self):
        class Foo:
            def __init__(self) -> None:
                self.scalar = 10

        def fn(d, x):
            return d["a"] * d["b"] * d["c"].scalar * x

        foo = Foo()

        d = {"a": 2, "b": 3, "c": foo}

        opt_fn = torch.compile(fn, backend="eager")
        inp = torch.randn(3, 3)
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

        d["a"] = 4
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

        # Check that recompilation happens
        foo.scalar = 12
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

    def test_nonconst_issubclass(self):
        def fn(x):
            if issubclass(x.__class__, np.ndarray):
                return 1
            return 0

        opt_fn = torch.compile(fn, backend="eager")
        opt_fn(np.ones([3, 3]))
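
    # Sketch (not collected by unittest: no "test_" prefix), assuming Dynamo
    # resolves non-constant type tests against the concrete input it traced
    # with: each input kind gets its own specialized frame, so both branches of
    # the issubclass check above remain reachable.
    def _sketch_type_check_specialization(self):
        def fn(x):
            if issubclass(x.__class__, np.ndarray):
                return 1
            return 0

        opt_fn = torch.compile(fn, backend="eager")
        self.assertEqual(opt_fn(np.ones([3, 3])), 1)
        self.assertEqual(opt_fn(torch.ones(3, 3)), 0)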

    def test_issue126128(self):
        def fn():
            x = torch.randn(1, 10)
            y = torch.randn(10, 1)
            return torch.mm(x, y).sum()

        def fn2():
            x = torch.randn(10, 100)
            y = torch.randn(100, 10)
            return torch.mm(x, y).sum()

        with fresh_inductor_cache():
            torch.compile(fn)()

        torch.compile(fn2)()

    def test_jit_script_defaults(self):
        @torch.jit.script
        def fast_cos(x, c: float = 2.0):
            return torch.cos(x) * c

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fast_cos = fast_cos

            def forward(self, x):
                return self.fast_cos(x)

        mod = Mod()
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(mod(x), opt_mod(x))

    def test_enum(self):
        class ExplicitEnum(str, Enum):
            @classmethod
            def _missing_(cls, value):
                raise ValueError(
                    f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
                )

        class PaddingStrategy(ExplicitEnum):
            LONGEST = "longest"
            MAX_LENGTH = "max_length"
            DO_NOT_PAD = "do_not_pad"

        def fn(x):
            a = PaddingStrategy("longest")
            if a == PaddingStrategy.LONGEST:
                return torch.sin(x)
            return torch.cos(x)

        x = torch.randn(3, 3)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))
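
    # Sketch of the plain-Python enum semantics Dynamo must model in test_enum
    # (not collected by unittest: no "test_" prefix): Enum's call syntax is a
    # by-value lookup, and _missing_ only fires for unknown values.
    def _sketch_enum_value_lookup(self):
        class Color(str, Enum):
            RED = "red"

        self.assertIs(Color("red"), Color.RED)
        with self.assertRaises(ValueError):
            Color("blue")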

    def test_hasattr_builtin(self):
        class MyClass:
            foo: int = 1

        def func(x, m):
            if getattr(type(m), "foo", 0):
                return x + MyClass.foo
            return x

        opt_func = torch.compile(func, backend="eager", fullgraph=True)
        m = MyClass()
        x = torch.zeros(())
        self.assertEqual(func(x, m), opt_func(x, m))
        self.assertEqual(func(x, 0), opt_func(x, 0))

    def test_grad(self):
        def fn(x, y):
            x._grad = y
            return x.grad.data

        x = torch.randn(4, requires_grad=True)
        y = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager")
        self.assertEqual(fn(x, y), opt_fn(x, y))

    def test_nn_module_stack_bc(self):
        from torch._dynamo.mutation_guard import GenerationTracker

        def compiler(gm, *args):
            module_stacks = [
                node.meta.get("nn_module_stack", None) for node in gm.graph.nodes
            ]
            module_stacks, _ = pytree.tree_flatten(module_stacks)
            module_stacks = [x for x in module_stacks if isinstance(x, str)]
            for stack in module_stacks:
                self.assertTrue("_module" not in stack)
            return gm.forward

        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                return self.linear(x)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod1 = SubMod()
                self.submod2 = SubMod()

            def forward(self, x):
                return self.submod1(x) + self.submod2(x)

        mod = Mod()
        opt_mod = torch.compile(mod, backend=compiler)
        opt_mod(torch.randn(2, 2))

        with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True):
            mod = Mod()
            opt_mod = torch.compile(mod, backend=compiler)
            opt_mod(torch.randn(2, 2))

        # an example similar to the Pippy use case
        mod = Mod()
        GenerationTracker.tag(mod.submod1)
        GenerationTracker.mark_class_dynamic(type(mod.submod1))
        mod = Mod()
        opt_mod = torch.compile(mod, backend=compiler)
        opt_mod(torch.randn(2, 2))

    def test_is_make_fx_tracing(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            torch.nn.modules.activation._is_make_fx_tracing()
            return torch.sin(x)

        fn(torch.rand(4))

    def test_negative_floor_div_solve(self):
        class CompiledClass(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nums = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
                self.t = 5

            def forward(self):
                self.num = self.nums[self.t // 12]
                self.t += 1
                return self.num

        m = CompiledClass()
        m = torch.compile(m, backend="eager")

        # the first call works
        m()
        # the second call used to cause a failure; it must also succeed
        m()
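
    # Worked example of the arithmetic behind test_negative_floor_div_solve
    # (not collected by unittest: no "test_" prefix): solving guards such as
    # "t // 12 == 0" must use Python floor-division semantics, which round
    # toward negative infinity, not C-style truncation toward zero.
    def _sketch_floor_div_semantics(self):
        self.assertEqual(5 // 12, 0)
        self.assertEqual(-5 // 12, -1)  # truncation would give 0
        self.assertEqual(-13 // 12, -2)  # truncation would give -1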

    # https://github.com/pytorch/pytorch/issues/121621
    def test_tensor_random(self):
        def random_op(tensor, params):
            res = tensor.random_(**params)
            return res

        random_op = torch.compile(random_op)
        params = {"from": -10, "to": 10}
        tensor = torch.randn([2, 3])
        res = random_op(tensor, params)

    # https://github.com/pytorch/pytorch/issues/131019
    def test_tensor_uniform(self):
        def uniform_op(tensor, params):
            res = tensor.uniform_(**params)
            return res

        uniform_op = torch.compile(uniform_op)
        params = {"from": -10, "to": 10}
        tensor = torch.randn([2, 3])
        res = uniform_op(tensor, params)

    def test_data_attr_mutation_after_saved_for_bw(self):
        def f(x):
            out = x.sin()
            x.data.mul_(2)
            return out

        x = torch.randn(4, requires_grad=True)
        x_test = x.clone().detach().requires_grad_(True)

        out = f(x)
        out_test = torch.compile(f, backend="aot_eager")(x_test)
        self.assertEqual(out, out_test)

        out.sum().backward()
        out_test.sum().backward()
        self.assertEqual(x.grad, x_test.grad)
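
    # Eager-mode sketch of why the test above is subtle (not collected by
    # unittest: no "test_" prefix). Writing through x.data gives autograd no
    # version-counter bump (x.data carries a fresh version counter), so
    # backward() neither raises nor replays the original values; it silently
    # consumes the mutated storage, and the compiled run must match that.
    def _sketch_data_mutation_bypasses_version_counter(self):
        x = torch.ones(2, requires_grad=True)
        out = x.sin()  # sin saves x for its backward
        version_before = x._version
        x.data.mul_(2)
        self.assertEqual(x._version, version_before)  # no bump recorded
        out.sum().backward()  # no "modified by an inplace operation" error
        self.assertEqual(x.grad, torch.cos(x.detach()))  # grad sees mutated x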

    # https://github.com/pytorch/pytorch/issues/128072
    def test_map_with_multiple_args(self):
        def f(a, b):
            return a[0] * b[0] + a[1] * b[1]

        def gen_inps(len_x, len_y):
            x = [torch.randn(5) for _ in range(len_x)]
            y = [torch.randn(5) for _ in range(len_y)]
            return x, y

        def g(x, y):
            return map(f, x, y)

        opt_g = torch.compile(g, fullgraph=True, backend="eager")

        inps = gen_inps(3, 3)
        self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
        self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))

        inps = gen_inps(3, 5)
        self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
        self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))

    def test_staticmethod_allow_in_graph(self):
        class MyClass:
            i = 3

            @staticmethod
            def foo_inner(x):
                return torch.mul(x, MyClass.i)

            # if dynamo inlines with fullgraph, will error
            # verify that dynamo doesn't inline
            @staticmethod
            @torch._dynamo.allow_in_graph
            def foo1(x):
                torch._dynamo.graph_break()
                return MyClass.foo_inner(x)

        @torch.compile(backend="eager", fullgraph=True)
        def f_bad(x):
            return MyClass.foo1(x)

        f_bad(torch.ones(2, 2))

    def test_guard_with_tuple_mutation(self):
        class Foo:
            def __init__(self) -> None:
                self.x = 10

        foo = Foo()
        d = {
            "a": 2,
            "b": (foo,),
        }

        def fn(x, d):
            return x * d["a"] * d["b"][0].x

        opt_fn = torch.compile(fn, backend="eager")
        inp = torch.randn(3, 3)
        self.assertEqual(fn(inp, d), opt_fn(inp, d))
        d["b"][0].x = 12
        self.assertEqual(fn(inp, d), opt_fn(inp, d))
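
    # Sketch of the guard chain above, assuming CompileCounter.frame_count
    # increments once per (re)compiled frame (not collected by unittest: no
    # "test_" prefix). The tuple itself is immutable, but the objects it holds
    # are not, so guards must watch attributes reached *through* the tuple:
    def _sketch_guard_through_tuple(self):
        class Box:
            def __init__(self) -> None:
                self.v = 1

        box = Box()
        t = (box,)
        cnt = CompileCounter()

        @torch.compile(backend=cnt)
        def fn(x):
            return x * t[0].v

        fn(torch.ones(1))
        box.v = 2
        self.assertEqual(fn(torch.ones(1)), torch.ones(1) * 2)
        self.assertEqual(cnt.frame_count, 2)  # attribute change forced a recompile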

    def test_compile_complex_conj(self):
        def f(x):
            return torch.mul(x, 2j)

        x_ref = torch.randn(4, 2, requires_grad=True)
        x_test = x_ref.clone().detach().requires_grad_(True)

        out_ref = f(torch.view_as_complex(x_ref))
        out_test = torch.compile(f, backend="aot_eager")(torch.view_as_complex(x_test))
        self.assertEqual(out_ref, out_test)

        torch.view_as_real(out_ref).sum().backward()
        torch.view_as_real(out_test).sum().backward()
        self.assertEqual(x_ref.grad, x_test.grad)

    # https://github.com/pytorch/pytorch/issues/132200
    def test_partitioner_cse_respects_mutation_boundaries(self):
        set_available = hasattr(torch.ops, "fsdp") and hasattr(torch.ops.fsdp, "set_")
        if not set_available:
            return

        @torch.compile(backend="aot_eager_decomp_partition")
        def f(x, l):
            # z0 and z1 can be CSEd
            z0 = x.sin()
            z1 = x.sin()
            y = x + 1
            torch.ops.fsdp.set_.default(x, y)
            # z2 and z3 can be CSEd with each other,
            # but *not* with z0/z1 (they cross a mutation boundary)
            z2 = x.sin()
            z3 = x.sin()
            return z0, z1, z2, z3, l**2

        x = torch.randn(3)
        x_clone = x.clone()
        l = torch.randn(3, requires_grad=True)
        z0, z1, z2, z3, _ = f(x, l)

        # the partitioner runs CSE. We expect that of the 4 sin() ops above:
        # - the first 2 are CSE'd
        # - the last 2 are CSE'd
        # - the set_() op in the middle is a mutation barrier, preventing CSE
        self.assertEqual(z0, (x_clone).sin())
        self.assertEqual(z1, (x_clone).sin())
        self.assertEqual(z2, (x_clone + 1).sin())
        self.assertEqual(z3, (x_clone + 1).sin())

    # https://github.com/pytorch/pytorch/issues/132197
    def test_fsdp_set_input_mutation_applied_when_input_gets_no_gradients(self):
        set_available = hasattr(torch.ops, "fsdp") and hasattr(torch.ops.fsdp, "set_")
        if not set_available:
            return

        @torch.compile(backend="aot_eager_decomp_partition")
        def f(x, l):
            z = x.sin()
            y = x + 1
            # graph input has its storage mutated
            torch.ops.fsdp.set_.default(x, y)
            z2 = x.sin()
            return z2, l**2

        x = torch.randn(3)
        x_test = x.clone()
        l = torch.randn(3, requires_grad=True)
        result, _ = f(x, l)
        result_test, _ = torch.compile(f, backend="aot_eager_decomp_partition")(
            x_test, l
        )

        self.assertEqual(result, result_test)
        self.assertEqual(x, x_test)

    def test_changing_stride(self):
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def fn(x, y):
            return x * y

        for i in range(1, 4):
            x = torch.randn(4, i)

            # create a view for i > 1
            if i == 1:
                x1 = x
            else:
                x1 = x[:, 0:1]

            y = torch.randn(4, 1)
            print(x1.shape, y.shape)
            fn(x1, y)

        self.assertTrue(cnt.frame_count <= 2)

    @torch._dynamo.config.patch(guard_nn_modules=False)
    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
    def test_inlining_cornercase(self):
        """
        nn.Modules can be mapped to either NNModuleVariable or UnspecializedNNModuleVariable. For NNModuleVariable,
        the tensor attributes become part of the Dynamo graph. For unspecialized, they are lifted as inputs.

        But there is a cornercase. Suppose you have an NNModuleVariable with a submodule that is an
        UnspecializedNNModuleVariable. Today, Dynamo will still consider the submodule as specialized (courtesy of
        guard.source().is_nn_module()). In retrospect, this is a mistake, but there are dependencies on export and
        also on cudagraphs which make it harder to fix the cornercase right away. The long-term solution is
        inline_inbuilt_nn_modules anyway, so we might have to live with this cornercase in the short term.

        We are starting to annotate the source of each nn module more precisely - an NNModuleVariable attribute is
        marked as NNModuleSource, an UnspecializedNNModuleVariable attribute is marked as
        UnspecializedNNModuleSource. But this changes the behavior for the cornercase and fails some tests which
        have unfortunately relied on this behavior.

        To solve this, we tag the source only when the inline_inbuilt_nn_modules flag is turned on.

        In this test, we purposely turn the flag off, testing that the tagging is disabled.
        """

        class SubMod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(1, 1)
                self.a = torch.randn(1, 1)
                self.counter = 0
                self.multipliers = [2.2, 3.3]

            def forward(self, x):
                self.counter += 1
                return (
                    self.linear(x) * self.a * self.multipliers[0] * self.multipliers[1]
                )

        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return self.submod(x)

        mod = Mod()
        opt_mod = torch.compile(mod, backend="eager")

        x = torch.randn(1, 1)
        ref = mod(x)
        res = opt_mod(x)

        mod.submod.multipliers = [3.3, 4.4]
        # Since guard_nn_modules is False, this will not recompile
        with torch._dynamo.config.patch(error_on_recompile=True):
            ref = mod(x)
            res = opt_mod(x)

    def test_optimized_module_training(self):
        mod = torch.nn.Linear(3, 3)
        mod.eval()

        opt_mod = torch.compile(mod, backend="eager")
        self.assertFalse(opt_mod.training)

        opt_mod.train()
        self.assertTrue(opt_mod.training)
        self.assertTrue(mod.training)

        mod.eval()
        self.assertFalse(opt_mod.training)
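
    # Sketch of the wrapper behavior asserted above (not collected by unittest:
    # no "test_" prefix): torch.compile on an nn.Module returns an
    # OptimizedModule that keeps the original module reachable via _orig_mod
    # and forwards train()/eval() to it.
    def _sketch_optimized_module_wrapper(self):
        mod = torch.nn.Linear(3, 3)
        opt_mod = torch.compile(mod, backend="eager")
        self.assertIsInstance(opt_mod, torch._dynamo.eval_frame.OptimizedModule)
        self.assertIs(opt_mod._orig_mod, mod)
        opt_mod.train()
        self.assertTrue(mod.training)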

    @requires_cuda
    def test_memleak_when_graph_input_has_tensor_attr(self):
        @torch.compile(backend="eager")
        def f(x):
            x.add_(1)

        mem_before = torch.cuda.memory_allocated()

        x = torch.ones(2, device="cuda")
        x.foo = torch.zeros(2, device="cuda")
        f(x)
        del x.foo
        del x
        mem_after = torch.cuda.memory_allocated()
        self.assertEqual(mem_before, mem_after)

        # check when non-tensor data structure attribute contains a tensor
        @torch.compile(backend="eager")
        def f(x):
            x.add_(1)

        mem_before = torch.cuda.memory_allocated()
        x = torch.ones(2, device="cuda")
        x.foo = [torch.zeros(2, device="cuda") for _ in range(5)]
        f(x)
        del x.foo
        del x
        mem_after = torch.cuda.memory_allocated()
        self.assertEqual(mem_before, mem_after)

        # check with tensor refcycle
        @torch.compile(backend="eager")
        def g(x, y):
            return x + y

        mem_before = torch.cuda.memory_allocated()
        x = torch.ones(2, device="cuda")
        y = torch.zeros(2, device="cuda")
        x.foo = [y]
        y.foo = [x]
        g(x, y)
        del x.foo
        del y.foo
        del x
        del y
        mem_after = torch.cuda.memory_allocated()
        self.assertEqual(mem_before, mem_after)

    def test_os_fspath(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            os.fspath(".")
            return torch.sin(x)

        fn(torch.randn(4))
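
    # Plain-Python contract of the builtin used above (not collected by
    # unittest: no "test_" prefix): os.fspath is pure on constant input, which
    # is what lets Dynamo evaluate it at trace time without a graph break.
    def _sketch_fspath_contract(self):
        import pathlib

        self.assertEqual(os.fspath("."), ".")
        self.assertEqual(os.fspath(pathlib.Path(".")), ".")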

    @requires_cuda
    # This test will fail as flip in combination with particular input lengths
    # produces weird results.
    # This is under investigation in
    # https://github.com/pytorch/pytorch/issues/131805
    @unittest.skip("Skip this flip test for the moment. It is under investigation")
    def test_flip_bad_accuracy(self):
        import torch
        import torch._dynamo.config
        import torch._functorch.config
        import torch._inductor.config
        import torch._inductor.inductor_prims
        import torch.fx.experimental._config

        class Repro(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, arg0_1):
                rev = torch.ops.prims.rev.default(arg0_1, [0])
                arg0_1 = None
                slice_1 = torch.ops.aten.slice.Tensor(rev, 0, 0, -1, 2)
                slice_2 = torch.ops.aten.slice.Tensor(rev, 0, 1, 9223372036854775807, 2)
                add_1 = torch.ops.aten.add.Tensor(slice_1, slice_2)
                slice_1 = slice_2 = None
                slice_3 = torch.ops.aten.slice.Tensor(add_1, 0, 0, -1, 2)
                slice_4 = torch.ops.aten.slice.Tensor(
                    add_1, 0, 1, 9223372036854775807, 2
                )
                add_2 = torch.ops.aten.add.Tensor(slice_3, slice_4)
                slice_3 = slice_4 = None
                slice_5 = torch.ops.aten.slice.Tensor(add_2, 0, 0, -1, 2)
                slice_6 = torch.ops.aten.slice.Tensor(
                    add_2, 0, 1, 9223372036854775807, 2
                )
                add_3 = torch.ops.aten.add.Tensor(slice_5, slice_6)
                slice_5 = slice_6 = None
                slice_9 = torch.ops.aten.slice.Tensor(add_2, 0, 0, 1)
                add_2 = None
                unsqueeze = torch.ops.aten.unsqueeze.default(slice_9, 1)
                slice_9 = None
                unsqueeze_1 = torch.ops.aten.unsqueeze.default(add_3, 1)
                add_3 = None
                cat = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 1)
                unsqueeze = unsqueeze_1 = None
                view = torch.ops.aten.view.default(cat, [2])
                cat = None
                slice_10 = torch.ops.aten.slice.Tensor(view, 0, 0, -1)
                slice_11 = torch.ops.aten.slice.Tensor(
                    add_1, 0, 2, 9223372036854775807, 2
                )
                add_5 = torch.ops.aten.add.Tensor(slice_10, slice_11)
                slice_10 = slice_11 = None
                slice_12 = torch.ops.aten.slice.Tensor(add_1, 0, 0, 1)
                add_1 = None
                cat_1 = torch.ops.aten.cat.default([slice_12, add_5])
                slice_12 = add_5 = None
                unsqueeze_2 = torch.ops.aten.unsqueeze.default(cat_1, 1)
                cat_1 = None
                unsqueeze_3 = torch.ops.aten.unsqueeze.default(view, 1)
                view = None
                cat_2 = torch.ops.aten.cat.default([unsqueeze_2, unsqueeze_3], 1)
                unsqueeze_2 = unsqueeze_3 = None
                view_1 = torch.ops.aten.view.default(cat_2, [4])
                cat_2 = None
                slice_13 = torch.ops.aten.slice.Tensor(
                    rev, 0, 2, 9223372036854775807, 2
                )
                add_6 = torch.ops.aten.add.Tensor(view_1, slice_13)
                slice_13 = None
                slice_14 = torch.ops.aten.slice.Tensor(rev, 0, 0, 1)
                rev = None
                cat_3 = torch.ops.aten.cat.default([slice_14, add_6])
                slice_14 = add_6 = None
                constant_pad_nd = torch.ops.aten.constant_pad_nd.default(
                    view_1, [0, 1], 0.0
                )
                view_1 = None
                unsqueeze_4 = torch.ops.aten.unsqueeze.default(cat_3, 1)
                cat_3 = None
                unsqueeze_5 = torch.ops.aten.unsqueeze.default(constant_pad_nd, 1)
                constant_pad_nd = None
                cat_4 = torch.ops.aten.cat.default([unsqueeze_4, unsqueeze_5], 1)
                unsqueeze_4 = unsqueeze_5 = None
                view_2 = torch.ops.aten.view.default(cat_4, [10])
                cat_4 = None
                slice_15 = torch.ops.aten.slice.Tensor(view_2, 0, 0, 9)
                view_2 = None
                rev_1 = torch.ops.prims.rev.default(slice_15, [0])
                slice_15 = None
                return (rev_1,)

        mod = Repro()
        x = torch.arange(9, device=torch.device("cuda"))

        @torch.compile
        def f(x):
            return mod(x)

        out = f(x)
        self.assertEqual(torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]), out[0])

    # https://github.com/pytorch/pytorch/issues/88813
    def test_return_value_duplication_tensor(self) -> None:
        def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            return val * 2, val * 2

        x = torch.randn(2, requires_grad=True)

        expect = fn(x)
        self.assertNotEqual(
            expect[0].untyped_storage().data_ptr(),
            expect[1].untyped_storage().data_ptr(),
        )

        actual = torch.compile(fn, backend="aot_eager")(x)
        self.assertNotEqual(
            actual[0].untyped_storage().data_ptr(),
            actual[1].untyped_storage().data_ptr(),
        )
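
    # Sketch of the invariant checked above (not collected by unittest: no
    # "test_" prefix): the two returned tensors must own distinct storage, so
    # mutating one return value cannot leak into the other.
    def _sketch_outputs_do_not_alias(self):
        a = torch.ones(2) * 2
        b = torch.ones(2) * 2
        self.assertNotEqual(
            a.untyped_storage().data_ptr(), b.untyped_storage().data_ptr()
        )
        a.add_(1)
        self.assertEqual(b, torch.full((2,), 2.0))  # b is unaffected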

    # https://github.com/pytorch/pytorch/issues/114344
    def test_return_value_duplication_mixed_grad(self) -> None:
        def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            with torch.no_grad():
                out0 = val + 1
            out1 = val + 1
            return out0, out1

        x = torch.randn(2, requires_grad=True)

        with torch.enable_grad():
            expect = fn(x)
            actual = torch.compile(fn, backend="aot_eager")(x)

        self.assertEqual(expect[0].requires_grad, actual[0].requires_grad)
        self.assertEqual(expect[1].requires_grad, actual[1].requires_grad)

    # https://github.com/pytorch/pytorch/pull/134726#discussion_r1738774371
    def test_return_value_duplication_scalar(self) -> None:
        def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            x, y = val * 2, val * 2
            return x[0], y[0]

        x = torch.randn(2, requires_grad=True)

        expect = fn(x)
        self.assertNotEqual(
            expect[0].untyped_storage().data_ptr(),
            expect[1].untyped_storage().data_ptr(),
        )

        actual = torch.compile(fn, backend="aot_eager")(x)
        self.assertNotEqual(
            actual[0].untyped_storage().data_ptr(),
            actual[1].untyped_storage().data_ptr(),
        )


instantiate_parametrized_tests(ReproTests)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()