1""" 2PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes 3with test_rewrite_assert_with_msg and test_rewrite_assert_without_msg) 4""" 5 6# Owner(s): ["module: dynamo"] 7import collections 8import contextlib 9import copy 10import dataclasses 11import functools 12import gc 13import inspect 14import itertools 15import os 16import random 17import unittest 18import warnings 19import weakref 20from abc import ABC 21from collections import namedtuple 22from copy import deepcopy 23from enum import Enum 24from functools import wraps 25from typing import Any, Dict, Iterator, List, Tuple 26from unittest import mock 27 28import numpy as np 29 30import torch 31import torch._dynamo.test_case 32import torch._dynamo.testing 33import torch._dynamo.utils 34import torch._functorch.config 35import torch.library 36import torch.utils._pytree as pytree 37from torch import nn 38from torch._dynamo.debug_utils import same_two_models 39from torch._dynamo.testing import CompileCounter, rand_strided, same 40from torch._inductor.utils import fresh_inductor_cache 41from torch.nn import functional as F 42from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION 43from torch.testing._internal.common_utils import ( 44 disable_translation_validation_if_dynamic_shapes, 45 instantiate_parametrized_tests, 46 parametrize, 47 skipIfWindows, 48 TEST_WITH_ROCM, 49) 50from torch.testing._internal.two_tensor import TwoTensor 51 52 53_orig_module_call = torch.nn.Module.__call__ 54 55# Custom operator that only supports CPU and Meta 56lib = torch.library.Library("test_sample", "DEF") # noqa: TOR901 57lib.define("foo(Tensor self) -> Tensor") 58lib.impl("foo", torch.sin, "CPU") 59 60 61requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda") 62 63 64_GLOBAL_CPU_TENSOR = torch.randn(3) 65 66 67def exists(val): 68 return val is not None 69 70 71def maybe(fn): 72 @wraps(fn) 73 def inner(x, *args, **kwargs): 74 if not exists(x): 75 return x 76 return fn(x, *args, **kwargs) 77 78 return inner 79 80 81def is_fx_tracing_test() -> bool: 82 """ 83 Copied from the hpc trainer codebase 84 """ 85 return torch.nn.Module.__call__ is not _orig_module_call 86 87 88def has_detectron2(): 89 try: 90 from detectron2.layers.mask_ops import _paste_masks_tensor_shape 91 92 return _paste_masks_tensor_shape is not None 93 except ImportError: 94 return False 95 96 97def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True): 98 # from detectron2 mask_ops.py 99 100 device = masks.device 101 102 if skip_empty and not torch.jit.is_scripting(): 103 x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to( 104 dtype=torch.int32 105 ) 106 x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to( 107 dtype=torch.int32 108 ) 109 y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to( 110 dtype=torch.int32 111 ) 112 else: 113 x0_int, y0_int = 0, 0 114 x1_int, y1_int = img_w, img_h 115 x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1 116 117 N = masks.shape[0] 118 119 img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5 120 img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5 121 img_y = (img_y - y0) / (y1 - y0) * 2 - 1 122 img_x = (img_x - x0) / (x1 - x0) * 2 - 1 123 # img_x, img_y have shapes (N, w), (N, h) 124 125 gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1)) 126 gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1)) 127 grid = torch.stack([gx, gy], dim=3) 128 129 if not torch.jit.is_scripting(): 130 if not masks.dtype.is_floating_point: 131 masks = masks.float() 132 img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False) 133 134 if skip_empty and not torch.jit.is_scripting(): 135 return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int)) 136 else: 137 return img_masks[:, 0], () 138 139 140def global_fn(x): 141 return torch.sin(x) 142 143 144def cat(tensors, dim=0): 145 # from detectron2 wrappers.py 146 assert isinstance(tensors, (list, tuple)) 147 if len(tensors) == 1: 148 return tensors[0] 149 return torch.cat(tensors, dim) 150 151 152def shapes_to_tensor(x, device=None): 153 # from detectron2 wrappers.py 154 if torch.jit.is_scripting(): 155 return torch.as_tensor(x, device=device) 156 if torch.jit.is_tracing(): 157 assert all( 158 isinstance(t, torch.Tensor) for t in x 159 ), "Shape should be tensor during tracing!" 160 # as_tensor should not be used in tracing because it records a constant 161 ret = torch.stack(x) 162 if ret.device != device: # avoid recording a hard-coded device if not necessary 163 ret = ret.to(device=device) 164 return ret 165 return torch.as_tensor(x, device=device) 166 167 168fw_graph = [None] 169bw_graph = [None] 170 171 172def aot_graph_capture_backend(gm, args): 173 from functorch.compile import min_cut_rematerialization_partition 174 from torch._functorch.aot_autograd import aot_module_simplified 175 176 def fw_compiler(gm, _): 177 fw_graph[0] = gm 178 return gm 179 180 def bw_compiler(gm, _): 181 bw_graph[0] = gm 182 return gm 183 184 return aot_module_simplified( 185 gm, 186 args, 187 fw_compiler, 188 bw_compiler, 189 partition_fn=min_cut_rematerialization_partition, 190 keep_inference_input_mutations=True, 191 ) 192 193 194class Boxes: 195 # from detectron2 poolers.py 196 def __init__(self, tensor: torch.Tensor): 197 """ 198 Args: 199 tensor (Tensor[float]): a Nx4 matrix. Each row is (x1, y1, x2, y2). 200 """ 201 device = ( 202 tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu") 203 ) 204 tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device) 205 if tensor.numel() == 0: 206 # Use reshape, so we don't end up creating a new tensor that does not depend on 207 # the inputs (and consequently confuses jit) 208 tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32, device=device) 209 assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size() 210 self.tensor = tensor 211 212 def __len__(self) -> int: 213 return self.tensor.shape[0] 214 215 @property 216 def device(self): 217 return self.tensor.device 218 219 220def convert_boxes_to_pooler_format(box_lists): 221 # from detectron2 structures.py 222 boxes = torch.cat([x.tensor for x in box_lists], dim=0) 223 # __len__ returns Tensor in tracing. 224 sizes = shapes_to_tensor([x.__len__() for x in box_lists], device=boxes.device) 225 indices = torch.repeat_interleave( 226 torch.arange(len(box_lists), dtype=boxes.dtype, device=boxes.device), sizes 227 ) 228 return cat([indices[:, None], boxes], dim=1) 229 230 231ReformerBackwardOutput = namedtuple( 232 "ReformerBackwardOutput", 233 ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"], 234) 235ReformerEncoderOutput = namedtuple( 236 "ReformerEncoderOutput", 237 ["hidden_states", "all_hidden_states", "all_attentions", "past_buckets_states"], 238) 239 240 241class _ReversibleFunction(torch.autograd.Function): 242 # taken from modeling_reformer.py in huggingface 243 @staticmethod 244 def forward( 245 ctx, 246 hidden_states, 247 layers, 248 attention_mask, 249 head_mask, 250 num_hashes, 251 all_hidden_states, 252 all_attentions, 253 past_buckets_states, 254 use_cache, 255 orig_sequence_length, 256 output_hidden_states, 257 output_attentions, 258 ): 259 all_buckets = () 260 261 # split duplicated tensor 262 hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1) 263 264 for layer_id, (layer, layer_head_mask) in enumerate(zip(layers, head_mask)): 265 if output_hidden_states is True: 266 all_hidden_states.append(hidden_states) 267 268 attn_output = layer(attn_output) 269 all_buckets = all_buckets + (attn_output,) 270 271 # Add last layer 272 if output_hidden_states is True: 273 all_hidden_states.append(hidden_states) 274 275 # attach params to ctx for backward 276 ctx.save_for_backward(attn_output.detach(), hidden_states.detach()) 277 ctx.layers = layers 278 ctx.all_buckets = all_buckets 279 ctx.head_mask = head_mask 280 ctx.attention_mask = attention_mask 281 282 # Concatenate 2 RevNet outputs 283 return torch.cat([attn_output, hidden_states], dim=-1) 284 285 @staticmethod 286 def backward(ctx, grad_hidden_states): 287 grad_attn_output, grad_hidden_states = torch.chunk( 288 grad_hidden_states, 2, dim=-1 289 ) 290 291 # free memory 292 del grad_attn_output 293 294 # num of return vars has to match num of forward() args 295 # return gradient for hidden_states arg and None for other args 296 return ( 297 grad_hidden_states, 298 None, 299 None, 300 None, 301 None, 302 None, 303 None, 304 None, 305 None, 306 None, 307 None, 308 None, 309 ) 310 311 312class ReformerEncoder(torch.nn.Module): 313 def __init__(self) -> None: 314 super().__init__() 315 self.dropout = 0.5 316 self.layer_norm = torch.nn.LayerNorm(512, eps=1.0e-12) 317 self.layers = [torch.nn.Linear(256, 256)] 318 319 def forward( 320 self, 321 hidden_states, 322 attention_mask=None, 323 head_mask=[None] * 6, 324 num_hashes=None, 325 use_cache=False, 326 orig_sequence_length=64, 327 output_hidden_states=False, 328 output_attentions=False, 329 ): 330 # hidden_states and attention lists to be filled if wished 331 all_hidden_states = [] 332 all_attentions = [] 333 past_buckets_states = [((None), (None)) for i in range(len(self.layers))] 334 335 # concat same tensor for reversible ResNet 336 hidden_states = torch.cat([hidden_states, hidden_states], dim=-1) 337 hidden_states = _ReversibleFunction.apply( 338 hidden_states, 339 self.layers, 340 attention_mask, 341 head_mask, 342 num_hashes, 343 all_hidden_states, 344 all_attentions, 345 past_buckets_states, 346 use_cache, 347 orig_sequence_length, 348 output_hidden_states, 349 output_attentions, 350 ) 351 352 # Apply layer norm to concatenated hidden states 353 hidden_states = self.layer_norm(hidden_states) 354 355 # Apply dropout 356 hidden_states = torch.nn.functional.dropout( 357 hidden_states, p=self.dropout, training=self.training 358 ) 359 360 return ReformerEncoderOutput( 361 hidden_states=hidden_states, 362 all_hidden_states=all_hidden_states, 363 all_attentions=all_attentions, 364 past_buckets_states=past_buckets_states, 365 ) 366 367 368class ListConfig: 369 class ValueNode: 370 def __init__(self, value): 371 self.value = value 372 373 def _dereference_node(self): 374 return self 375 376 def _is_missing(self): 377 return False 378 379 def _value(self): 380 return self.value 381 382 # Based on an example from omegaconfig.listconfig 383 class ListIterator(Iterator[Any]): 384 def __init__(self, lst: Any, resolve: bool) -> None: 385 self.resolve = resolve 386 self.iterator = iter(lst.__dict__["_content"]) 387 self.index = 0 388 389 def __next__(self) -> Any: 390 x = next(self.iterator) 391 if self.resolve: 392 x = x._dereference_node() 393 if x._is_missing(): 394 raise AssertionError 395 396 self.index = self.index + 1 397 if isinstance(x, ListConfig.ValueNode): 398 return x._value() 399 raise AssertionError 400 401 def __iter__(self): 402 return self._iter_ex(True) 403 404 def _iter_ex(self, resolve: bool) -> Iterator[Any]: 405 try: 406 return ListConfig.ListIterator(self, resolve) 407 except Exception: 408 raise AssertionError from None 409 410 def __init__(self) -> None: 411 self._content = [ 412 ListConfig.ValueNode(1), 413 ListConfig.ValueNode(3), 414 ListConfig.ValueNode(torch.tensor([7.0])), 415 ] 416 417 418def longformer_chunk(hidden_states, window_overlap=256): 419 """convert into overlapping chunks. Chunk size = 2w, overlap size = w""" 420 421 # non-overlapping chunks of size = 2w 422 hidden_states = hidden_states.view( 423 hidden_states.size(0), 424 hidden_states.size(1) // (window_overlap * 2), 425 window_overlap * 2, 426 hidden_states.size(2), 427 ) 428 429 # use `as_strided` to make the chunks overlap with an overlap size = window_overlap 430 chunk_size = list(hidden_states.size()) 431 chunk_size[1] = chunk_size[1] * 2 - 1 432 433 chunk_stride = list(hidden_states.stride()) 434 chunk_stride[1] = chunk_stride[1] // 2 435 return hidden_states.as_strided(size=chunk_size, stride=chunk_stride) 436 437 438class PartialT5(torch.nn.Module): 439 # Highly simplified T5Attention prefix 440 def __init__(self) -> None: 441 super().__init__() 442 self.q = torch.nn.Linear(512, 512) 443 self.k = torch.nn.Linear(512, 512) 444 self.v = torch.nn.Linear(512, 512) 445 446 def forward( 447 self, 448 hidden_states, 449 key_value_states=None, 450 past_key_value=None, 451 query_length=None, 452 ): 453 batch_size, seq_length = hidden_states.shape[:2] 454 455 real_seq_length = seq_length 456 457 if past_key_value is not None: 458 assert ( 459 len(past_key_value) == 2 460 ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" 461 real_seq_length += ( 462 past_key_value[0].shape[2] if query_length is None else query_length 463 ) 464 465 def shape(states): 466 """projection""" 467 return states.view(batch_size, -1, 8, 64).transpose(1, 2) 468 469 def project(hidden_states, proj_layer, key_value_states, past_key_value): 470 """projects hidden states correctly to key/query states""" 471 if key_value_states is None: 472 # self-attn 473 # (batch_size, n_heads, seq_length, dim_per_head) 474 hidden_states = shape(proj_layer(hidden_states)) 475 elif past_key_value is None: 476 # cross-attn 477 # (batch_size, n_heads, seq_length, dim_per_head) 478 hidden_states = shape(proj_layer(key_value_states)) 479 480 if past_key_value is not None: 481 if key_value_states is None: 482 # self-attn 483 # (batch_size, n_heads, key_length, dim_per_head) 484 hidden_states = torch.cat([past_key_value, hidden_states], dim=2) 485 else: 486 # cross-attn 487 hidden_states = past_key_value 488 return hidden_states 489 490 # get query states 491 query_states = shape( 492 self.q(hidden_states) 493 ) # (batch_size, n_heads, seq_length, dim_per_head) 494 495 # get key/value states 496 key_states = project( 497 hidden_states, 498 self.k, 499 key_value_states, 500 past_key_value[0] if past_key_value is not None else None, 501 ) 502 value_states = project( 503 hidden_states, 504 self.v, 505 key_value_states, 506 past_key_value[1] if past_key_value is not None else None, 507 ) 508 509 # compute scores 510 scores = torch.matmul(query_states, key_states.transpose(3, 2)) 511 512 # (truncated here ) 513 return scores, value_states 514 515 516class ChunkReformerFeedForward(torch.nn.Module): 517 # simplified from HF modeling_reformer.py 518 def __init__(self) -> None: 519 super().__init__() 520 self.layer_norm = torch.nn.LayerNorm(256, eps=1e-12) 521 self.dense = torch.nn.Linear(256, 256) 522 self.output = torch.nn.Linear(256, 256) 523 524 def forward(self, attention_output): 525 return apply_chunking_to_forward( 526 self.forward_chunk, 527 attention_output + 1, 528 ) 529 530 def forward_chunk(self, hidden_states): 531 hidden_states = self.layer_norm(hidden_states) 532 hidden_states = self.dense(hidden_states) 533 return self.output(hidden_states) 534 535 536def apply_chunking_to_forward(forward_fn, *input_tensors): 537 # simplified from HF model_utils.py 538 assert len(input_tensors) > 0 539 tensor_shape = input_tensors[0].shape[1] 540 assert all(input_tensor.shape[1] == tensor_shape for input_tensor in input_tensors) 541 num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters) 542 if num_args_in_forward_chunk_fn != len(input_tensors): 543 raise ValueError 544 545 return forward_fn(*input_tensors) 546 547 548def _validate_model_kwargs(fn, model_kwargs): 549 # simplified from transformers.generation.utils._validate_model_kwargs 550 unused_model_args = [] 551 model_args = set(inspect.signature(fn).parameters) 552 for key, value in model_kwargs.items(): 553 if value is not None and key not in model_args: 554 unused_model_args.append(key) 555 if unused_model_args: 556 raise ValueError( 557 f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the" 558 " generate arguments will also show up in this list)" 559 ) 560 561 562class FakeMamlInner(torch.nn.Module): 563 def __init__(self) -> None: 564 super().__init__() 565 self.linear = torch.nn.Linear(784, 5) 566 567 def forward(self, x, ignored=None, bn_training=False): 568 return self.linear(x.view(x.shape[0], -1)) 569 570 571class PartialMaml(torch.nn.Module): 572 # Highly simplified version of maml.meta.Meta.finetuning 573 def __init__(self) -> None: 574 super().__init__() 575 self.net = FakeMamlInner() 576 self.update_step_test = 10 577 self.update_lr = 0.4 578 579 def forward(self, x_spt, y_spt, x_qry, y_qry): 580 querysz = x_qry.size(0) 581 582 corrects = [0 for _ in range(self.update_step_test + 1)] 583 584 # in order to not ruin the state of running_mean/variance and bn_weight/bias 585 # we finetuning on the copied model instead of self.net 586 net = deepcopy(self.net) 587 588 # 1. run the i-th task and compute loss for k=0 589 logits = net(x_spt) 590 loss = F.cross_entropy(logits, y_spt) 591 grad = torch.autograd.grad(loss, net.parameters()) 592 fast_weights = [ 593 p[1] - self.update_lr * p[0] for p in zip(grad, net.parameters()) 594 ] 595 596 # this is the loss and accuracy before first update 597 with torch.no_grad(): 598 # [setsz, nway] 599 logits_q = net(x_qry, net.parameters(), bn_training=True) 600 # [setsz] 601 pred_q = F.softmax(logits_q, dim=1).argmax(dim=1) 602 # scalar 603 correct = torch.eq(pred_q, y_qry).sum().item() 604 corrects[0] = corrects[0] + correct 605 606 # this is the loss and accuracy after the first update 607 with torch.no_grad(): 608 # [setsz, nway] 609 logits_q = net(x_qry, fast_weights, bn_training=True) 610 # [setsz] 611 pred_q = F.softmax(logits_q, dim=1).argmax(dim=1) 612 # scalar 613 correct = torch.eq(pred_q, y_qry).sum().item() 614 corrects[1] = corrects[1] + correct 615 616 del net 617 618 accs = torch.tensor(corrects) / querysz 619 620 return accs 621 622 623def softmax_backward_data(parent, grad_output, output, dim, self): 624 from torch import _softmax_backward_data 625 626 return _softmax_backward_data(grad_output, output, parent.dim, self.dtype) 627 628 629class XSoftmax(torch.autograd.Function): 630 # transformers.models.deberta.modeling_deberta.XSoftmax 631 @staticmethod 632 def forward(self, input, mask, dim): 633 self.dim = dim 634 rmask = ~(mask.to(torch.bool)) 635 output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min)) 636 output = torch.softmax(output, self.dim) 637 output.masked_fill_(rmask, 0) 638 self.save_for_backward(output, rmask) 639 return output 640 641 @staticmethod 642 def backward(self, grad_output): 643 (output, rmask) = self.saved_tensors 644 inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output) 645 return inputGrad, None, None 646 647 648class ModelOutput(collections.OrderedDict): 649 """based on file_utils.py in HuggingFace""" 650 651 def __getitem__(self, k): 652 if isinstance(k, str): 653 inner_dict = dict(self.items()) 654 return inner_dict[k] 655 else: 656 return self.to_tuple()[k] 657 658 def __setattr__(self, name, value): 659 if name in self.keys() and value is not None: 660 # Don't call self.__setitem__ to avoid recursion errors 661 super().__setitem__(name, value) 662 super().__setattr__(name, value) 663 664 def __setitem__(self, key, value): 665 # Will raise a KeyException if needed 666 super().__setitem__(key, value) 667 # Don't call self.__setattr__ to avoid recursion errors 668 super().__setattr__(key, value) 669 670 def to_tuple(self): 671 return tuple(self[k] for k in self.keys()) 672 673 674def create_rand_mask_from_inputs( 675 from_blocked_mask, 676 to_blocked_mask, 677 rand_attn, 678 num_attention_heads, 679 num_rand_blocks, 680 batch_size, 681 from_seq_length, 682 from_block_size, 683): 684 """taken from HF modeling_big_bird.py""" 685 num_windows = from_seq_length // from_block_size - 2 686 rand_mask = torch.stack( 687 [p1[i1.flatten()] for p1, i1 in zip(to_blocked_mask, rand_attn)] 688 ) 689 rand_mask = rand_mask.view( 690 batch_size, num_attention_heads, num_windows, num_rand_blocks * from_block_size 691 ) 692 rand_mask = torch.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask) 693 return rand_mask 694 695 696class SequentialAppendList(torch.nn.Sequential): 697 """from timm/models/vovnet.py""" 698 699 def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor: 700 for i, module in enumerate(self): 701 if i == 0: 702 concat_list.append(module(x)) 703 else: 704 concat_list.append(module(concat_list[-1])) 705 x = torch.cat(concat_list, dim=1) 706 return x, concat_list 707 708 709class BatchNormAct2d(torch.nn.BatchNorm2d): 710 """Taken from timm""" 711 712 def __init__( 713 self, 714 num_features, 715 eps=1e-5, 716 momentum=0.1, 717 affine=True, 718 track_running_stats=True, 719 act_layer=torch.nn.ReLU, 720 inplace=True, 721 ): 722 super().__init__( 723 num_features, 724 eps=eps, 725 momentum=momentum, 726 affine=affine, 727 track_running_stats=track_running_stats, 728 ) 729 self.act = act_layer(inplace=inplace) 730 731 @torch.jit.ignore 732 def _forward_python(self, x): 733 return super().forward(x) 734 735 def forward(self, x): 736 if torch.jit.is_scripting(): 737 x = self._forward_jit(x) 738 else: 739 x = self._forward_python(x) 740 x = self.act(x) 741 return x 742 743 744def get_parameter_dtype(parameter): 745 """from huggingface model_utils.py""" 746 try: 747 return next(parameter.parameters()).dtype 748 except StopIteration: 749 # For nn.DataParallel compatibility in PyTorch 1.5 750 751 def find_tensor_attributes(module): 752 tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] 753 return tuples 754 755 gen = parameter._named_members(get_members_fn=find_tensor_attributes) 756 first_tuple = next(gen) 757 return first_tuple[1].dtype 758 759 760class DummyConfig: 761 attn_layers = ["local", "lsh", "local", "lsh", "local", "lsh"] 762 lsh_attn_chunk_length = 64 763 local_attn_chunk_length = 64 764 765 766def _get_min_chunk_len(config): 767 """from hf_Reformer""" 768 attn_types = config.attn_layers 769 attn_types_set = set(attn_types) 770 if len(attn_types_set) == 1 and attn_types[0] == "lsh": 771 return config.lsh_attn_chunk_length 772 elif len(attn_types_set) == 1 and attn_types[0] == "local": 773 return config.local_attn_chunk_length 774 elif len(attn_types_set) == 2 and attn_types_set == {"lsh", "local"}: 775 return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length) 776 else: 777 raise NotImplementedError( 778 f"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {config.attn_layers}. Select " 779 "attn layer types from ['lsh', 'local'] only." 780 ) 781 782 783def _stable_argsort(vector, dim): 784 """from hf_Reformer""" 785 # this function scales the vector so that torch.argsort is stable. 786 # torch.argsort is not stable on its own 787 scale_offset = torch.arange(vector.shape[dim], device=vector.device).view(1, 1, -1) 788 scale_offset = scale_offset.expand(vector.shape) 789 scaled_vector = vector.shape[dim] * vector + (scale_offset % vector.shape[dim]) 790 return torch.argsort(scaled_vector, dim=dim) 791 792 793def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(buckets): 794 """from hf_Reformer""" 795 # no gradients are needed 796 with torch.no_grad(): 797 # hash-based sort 798 sorted_bucket_idx = _stable_argsort(buckets, dim=-1) 799 800 # create simple indices to scatter to, to have undo sort 801 indices = ( 802 torch.arange(sorted_bucket_idx.shape[-1], device=buckets.device) 803 .view(1, 1, -1) 804 .expand(sorted_bucket_idx.shape) 805 ) 806 807 # get undo sort 808 undo_sorted_bucket_idx = sorted_bucket_idx.new(*sorted_bucket_idx.size()) 809 undo_sorted_bucket_idx.scatter_(-1, sorted_bucket_idx, indices) 810 811 return sorted_bucket_idx, undo_sorted_bucket_idx 812 813 814class CustomList1(list): 815 def __call__(self, x): 816 for processor in self: 817 x = processor(x) 818 return x 819 820 def clear(self): 821 pass # this prevents RestrictedListSubclassVariable from kicking in 822 823 824class CustomList2(list): 825 def __call__(self, x): 826 for processor in self: 827 x = processor(x) 828 return x 829 830 def length_times_10(self): 831 return len(self) * 10 832 833 def append_twice(self, x): 834 self.extend([x, x]) 835 836 837def _merge_criteria_processor_list(default_list, custom_list): 838 # simplified transformers/generation/utils.py 839 if len(custom_list) == 0: 840 return default_list 841 for default in default_list: 842 for custom in custom_list: 843 if type(custom) is type(default): 844 raise ValueError 845 default_list.extend(custom_list) 846 return default_list 847 848 849class FeedForwardLayer(nn.Module): 850 def __init__(self, d_model, dim_feedforward, activation, dropout) -> None: 851 super().__init__() 852 self.linear1 = nn.Linear(d_model, dim_feedforward) 853 self.activation = activation 854 self.dropout1 = nn.Dropout(dropout) 855 self.linear2 = nn.Linear(dim_feedforward, d_model) 856 self.dropout2 = nn.Dropout(dropout) 857 858 def forward(self, x): 859 return self.dropout2( 860 self.linear2(self.dropout1(self.activation(self.linear1(x)))) 861 ) 862 863 864class TransformerEncoderLayer(nn.Module): 865 def __init__( 866 self, 867 d_model, 868 nhead, 869 dim_feedforward=2048, 870 dropout=0.1, 871 activation=nn.ReLU(), 872 layer_norm_eps=1e-5, 873 ): 874 super().__init__() 875 self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 876 self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) 877 self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) 878 self.dropout = nn.Dropout(dropout) 879 self.ff_block = FeedForwardLayer(d_model, dim_feedforward, activation, dropout) 880 881 def forward(self, src, src_mask=None, src_key_padding_mask=None): 882 x = src 883 x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask)) 884 x = self.norm2(x + self._ff_block(x)) 885 return x 886 887 # self-attention block 888 def _sa_block(self, x, attn_mask, key_padding_mask): 889 x = self.self_attn( 890 x, 891 x, 892 x, 893 attn_mask=attn_mask, 894 key_padding_mask=key_padding_mask, 895 need_weights=False, 896 )[0] 897 return self.dropout(x) 898 899 # feed forward block 900 def _ff_block(self, x): 901 return self.ff_block(x) 902 903 904class MockModule(torch.nn.Module): 905 def inner_fn(self, left, right): 906 return tuple(left) == tuple(right) 907 908 def fn(self, tensor): 909 if type(tensor) is int: 910 return False 911 912 torch.add(tensor, tensor) 913 return self.inner_fn(tensor.shape, (1, 2, 3)) 914 915 916class IncByOne: 917 def __init__(self, x): 918 self.x = x + 1 919 920 921class IncByTwo: 922 def __init__(self, x): 923 self.x = x + 2 924 925 926class ReproTests(torch._dynamo.test_case.TestCase): 927 def test_do_paste_mask(self): 928 torch._dynamo.utils.counters.clear() 929 cnt = torch._dynamo.testing.CompileCounter() 930 opt__do_paste_mask = torch.compile(_do_paste_mask, backend=cnt) 931 opt__do_paste_mask( 932 torch.randn(1, 1, 28, 28), 933 torch.tensor([[0.0, 1, 2, 4]]) * 1, 934 427, 935 640, 936 True, 937 ) 938 opt__do_paste_mask( 939 torch.randn(1, 1, 28, 28), 940 torch.tensor([[0.0, 1, 2, 4]]) * 2, 941 427, 942 640, 943 True, 944 ) 945 opt__do_paste_mask( 946 torch.randn(1, 1, 28, 28), 947 torch.tensor([[0.0, 1, 2, 4]]) * 3, 948 612, 949 612, 950 True, 951 ) 952 opt__do_paste_mask( 953 torch.randn(1, 1, 28, 28), 954 torch.tensor([[0.0, 1, 2, 4]]) * 4, 955 612, 956 612, 957 True, 958 ) 959 opt__do_paste_mask( 960 torch.randn(1, 1, 28, 28), 961 torch.tensor([[0.0, 1, 2, 4]]) * 2, 962 427, 963 640, 964 False, 965 ) 966 # (dynamic shapes, static shapes) 967 self.assertIn(cnt.frame_count, (5, 7)) 968 self.assertIn(cnt.op_count, (92, 106, 119)) 969 970 def test_convert_boxes_to_pooler_format(self): 971 boxes1 = [ 972 Boxes(torch.arange(0, 8).reshape((2, 4))), 973 Boxes(torch.arange(8, 16).reshape((2, 4))), 974 ] 975 boxes2 = [ 976 Boxes(torch.arange(16, 20).reshape((1, 4))), 977 Boxes(torch.arange(20, 24).reshape((1, 4))), 978 ] 979 correct1 = convert_boxes_to_pooler_format(boxes1) 980 correct2 = convert_boxes_to_pooler_format(boxes2) 981 fn = convert_boxes_to_pooler_format 982 cnt = torch._dynamo.testing.CompileCounter() 983 opt_fn = torch._dynamo.optimize(cnt)(fn) 984 self.assertTrue(same(opt_fn(boxes1), correct1)) 985 self.assertTrue(same(opt_fn(boxes2), correct2)) 986 987 # repeat_interleave is a dynamic shape operator we do not execute/ 988 # In the future, we could reduce the frame_count down to 1 989 # by guarding on the exact values of `Tensor repeats` arg 990 if torch._dynamo.config.assume_static_by_default: 991 self.assertExpectedInline(cnt.frame_count, """4""") 992 self.assertExpectedInline(cnt.op_count, """10""") 993 else: 994 self.assertExpectedInline(cnt.frame_count, """4""") 995 self.assertExpectedInline(cnt.op_count, """14""") 996 997 def test_boxes_len(self): 998 def fn(boxes): 999 return len(boxes) + boxes.__len__() + boxes.tensor 1000 1001 boxes1 = Boxes(torch.arange(0, 8).reshape((2, 4))) 1002 cnt = torch._dynamo.testing.CompileCounter() 1003 opt_fn = torch._dynamo.optimize_assert(cnt)(fn) 1004 self.assertTrue(same(opt_fn(boxes1), boxes1.tensor + 4.0)) 1005 1006 if torch._dynamo.config.assume_static_by_default: 1007 self.assertExpectedInline(cnt.frame_count, """1""") 1008 self.assertExpectedInline(cnt.op_count, """1""") 1009 else: 1010 self.assertExpectedInline(cnt.frame_count, """1""") 1011 self.assertExpectedInline(cnt.op_count, """2""") 1012 1013 def _reformer(self, nopython): 1014 input = torch.randn([1, 64, 256]) 1015 model = ReformerEncoder() 1016 torch.manual_seed(1337) 1017 correct = copy.deepcopy(model)(input) 1018 cnt = torch._dynamo.testing.CompileCounter() 1019 torch.manual_seed(1337) 1020 opt_model = torch._dynamo.optimize(cnt, nopython=nopython)(model) 1021 self.assertTrue(same(opt_model(input), correct)) 1022 return cnt 1023 1024 @requires_cuda 1025 def test_sub_alpha_scalar_repro(self): 1026 @torch.compile(backend="aot_eager") 1027 def f(x): 1028 return x.sub(1, alpha=2) 1029 1030 f(torch.ones(2, device="cuda", dtype=torch.float64)) 1031 1032 # https://github.com/pytorch/pytorch/issues/113010 1033 def test_out_overload_non_contiguous(self): 1034 def f(x, y): 1035 return torch.abs(x, out=y.T) 1036 1037 f_compiled = torch.compile(f, backend="aot_eager") 1038 1039 x_ref = torch.arange(4, dtype=torch.float32).reshape(2, 2) 1040 y_ref = torch.arange(4, dtype=torch.float32).reshape(2, 2) 1041 x_test = torch.arange(4, dtype=torch.float32).reshape(2, 2) 1042 y_test = torch.arange(4, dtype=torch.float32).reshape(2, 2) 1043 1044 out_ref = f(x_ref, y_ref) 1045 out_test = f_compiled(x_test, y_test) 1046 self.assertEqual(out_ref, out_test) 1047 self.assertEqual(y_ref, y_test) 1048 1049 # https://github.com/pytorch/pytorch/issues/109053 1050 def test_view_dtype_overload(self): 1051 def f(x): 1052 return x.view(torch.int32) 1053 1054 f_compiled = torch.compile(f, backend="aot_eager") 1055 1056 x1 = torch.ones(4, requires_grad=True) 1057 out_ref = f(x1) 1058 out_test = f_compiled(x1) 1059 self.assertEqual(out_ref, out_test) 1060 1061 x2 = torch.ones(4, requires_grad=False) 1062 out_ref = f(x2) 1063 out_test = f_compiled(x2) 1064 self.assertEqual(out_ref, out_test) 1065 1066 # https://github.com/pytorch/pytorch/issues/90552 1067 def test_intermediate_leaf_requires_grad(self): 1068 def f(x): 1069 leaf = torch.ones(2, requires_grad=True) 1070 return leaf, leaf * 2 1071 1072 f_compiled = torch.compile(f, backend="aot_eager") 1073 x = torch.arange(4, dtype=torch.float32).reshape(2, 2) 1074 1075 leaf, out = f(x) 1076 leaf_test, out_test = f_compiled(x) 1077 out.sum().backward() 1078 out_test.sum().backward() 1079 self.assertEqual(leaf.grad, leaf_test.grad) 1080 1081 # https://github.com/pytorch/pytorch/issues/113263 1082 def test_unpack_hooks_dont_run_during_tracing(self): 1083 def f(x, y): 1084 return x * y 1085 1086 f_compiled = torch.compile(f, backend="aot_eager") 1087 1088 pack_count = 0 1089 unpack_count = 0 1090 1091 def pack_hook(x): 1092 nonlocal pack_count 1093 pack_count += 1 1094 return x 1095 1096 # unpack hook shouldn't run during compilation, while we trace the forward 1097 def unpack_hook(x): 1098 nonlocal unpack_count 1099 unpack_count += 1 1100 return x 1101 1102 x = torch.ones(4, requires_grad=True) 1103 y = torch.ones(4, requires_grad=False) 1104 with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): 1105 out_test = f_compiled(x, y) 1106 self.assertEqual(pack_count, 1) 1107 self.assertEqual(unpack_count, 0) 1108 out_test.sum().backward() 1109 self.assertEqual(pack_count, 1) 1110 self.assertEqual(unpack_count, 1) 1111 1112 # https://github.com/pytorch/pytorch/issues/113263 1113 def test_unpack_hooks_can_be_disabled(self): 1114 def f(x, y): 1115 return x * y 1116 1117 f_compiled = torch.compile(f, backend="aot_eager") 1118 1119 x = torch.ones(4, requires_grad=True) 1120 y = torch.ones(4, requires_grad=False) 1121 with torch.autograd.graph.disable_saved_tensors_hooks("hooks are disabled"): 1122 out_test = f_compiled(x, y) 1123 out_test.sum().backward() 1124 1125 # https://github.com/pytorch/pytorch/issues/113263 1126 def test_disabling_unpack_hooks_within_compiled_region(self): 1127 def g(z): 1128 with torch.autograd.graph.disable_saved_tensors_hooks("hooks are disabled"): 1129 return z + 5 1130 1131 def f(x, y): 1132 z = x * y 1133 return g(z) 1134 1135 f_compiled = torch.compile(f, backend="aot_eager") 1136 1137 x = torch.ones(4, requires_grad=True) 1138 y = torch.ones(4, requires_grad=False) 1139 out_test = f_compiled(x, y) 1140 out_test.sum().backward() 1141 1142 # See https://github.com/pytorch/pytorch/issues/97745 1143 def test_gan_repro_trying_to_backward_through_the_graph_a_second_time(self): 1144 def f(a, b): 1145 c = torch.ones(2, 2) 1146 d = torch.ones(2, 2) 1147 e = torch.matmul(a, c) 1148 g_loss = torch.abs(e - d).mean() 1149 g_loss.backward() 1150 fake_d_pred = torch.matmul(b, e.detach()) 1151 d_loss = fake_d_pred.mean() 1152 d_loss.backward() 1153 1154 a_ref = torch.randn(2, 2, requires_grad=True) 1155 b_ref = torch.randn(2, 2, requires_grad=True) 1156 out_ref = f(a_ref, b_ref) 1157 1158 a_test = a_ref.clone().detach().requires_grad_(True) 1159 b_test = b_ref.clone().detach().requires_grad_(True) 1160 out_test = torch.compile(f, backend="aot_eager")(a_test, b_test) 1161 1162 self.assertEqual(out_ref, out_test) 1163 self.assertEqual(a_ref.grad, a_test.grad) 1164 self.assertEqual(b_ref.grad, b_test.grad) 1165 1166 # https://github.com/pytorch/pytorch/issues/111603 1167 def test_tuple_enum_as_key_dict(self): 1168 class MyEnum(Enum): 1169 A = "a" 1170 1171 class SomeModel(torch.nn.Module): 1172 def __init__(self) -> None: 1173 super().__init__() 1174 self.linear = torch.nn.Linear(1, 1) 1175 1176 def forward(self, x) -> torch.Tensor: 1177 return self.linear(x[MyEnum.A]) 1178 1179 x = {MyEnum.A: torch.rand(8, 1)} 1180 model_pytorch = SomeModel() 1181 model = torch.compile(model_pytorch) 1182 # Executing twice works 1183 model(x) 1184 y = model(x) 1185 self.assertEqual(y, model_pytorch(x)) 1186 1187 def test_embedding_backward_broadcasting_decomp(self): 1188 def f(grad_output, indices): 1189 num_weights = 10 1190 padding_idx = 1 1191 scale_grad_by_freq = True 1192 return torch.ops.aten.embedding_dense_backward( 1193 grad_output, indices, num_weights, padding_idx, scale_grad_by_freq 1194 ) 1195 1196 f_compiled = torch.compile(f, backend="aot_eager") 1197 1198 grad_output = torch.ones(2, 4, 3, dtype=torch.float16) 1199 indices = torch.ones(2, 4, dtype=torch.int64) 1200 1201 out_ref = f(grad_output, indices) 1202 out_test = f_compiled(grad_output, indices) 1203 1204 self.assertEqual(out_ref, out_test) 1205 1206 def test_reformer_eval(self): 1207 with torch.no_grad(): 1208 cnt = self._reformer(nopython=True) 1209 self.assertEqual(cnt.frame_count, 1) 1210 self.assertEqual(cnt.op_count, 11) 1211 1212 def test_reformer_train(self): 1213 with torch.enable_grad(): 1214 cnt = self._reformer(nopython=False) 1215 expected_op_count = ( 1216 """11""" if torch._dynamo.config.inline_inbuilt_nn_modules else """5""" 1217 ) 1218 1219 self.assertExpectedInline(cnt.frame_count, """1""") 1220 self.assertExpectedInline(cnt.op_count, expected_op_count) 1221 1222 @disable_translation_validation_if_dynamic_shapes 1223 def test_longformer_chunk(self): 1224 input1 = torch.randn([1, 4096, 1]) 1225 input2 = torch.randn([12, 4096, 64]) 1226 correct1 = longformer_chunk(input1) 1227 correct2 = longformer_chunk(input2) 1228 fn = longformer_chunk 1229 cnt = torch._dynamo.testing.CompileCounter() 1230 opt_fn = torch._dynamo.optimize_assert(cnt)(fn) 1231 self.assertTrue(same(opt_fn(input1), correct1)) 1232 self.assertTrue(same(opt_fn(input2), correct2)) 1233 self.assertTrue(same(opt_fn(input1), correct1)) 1234 self.assertTrue(same(opt_fn(input2), correct2)) 1235 1236 if torch._dynamo.config.assume_static_by_default: 1237 if torch._dynamo.config.automatic_dynamic_shapes: 1238 self.assertExpectedInline(cnt.frame_count, """2""") 1239 self.assertExpectedInline(cnt.op_count, """8""") 1240 else: 1241 self.assertExpectedInline(cnt.frame_count, """2""") 1242 self.assertExpectedInline(cnt.op_count, """4""") 1243 else: 1244 self.assertExpectedInline(cnt.frame_count, """2""") 1245 self.assertExpectedInline(cnt.op_count, """19""") 1246 1247 def test_hf_t5_forward(self): 1248 input = torch.randn([1, 2048, 512]) 1249 model = PartialT5() 1250 correct = model(input) 1251 cnt = torch._dynamo.testing.CompileCounter() 1252 opt_model = torch._dynamo.optimize_assert(cnt)(model) 1253 self.assertTrue(same(opt_model(input), correct)) 1254 1255 if torch._dynamo.config.assume_static_by_default: 1256 self.assertExpectedInline(cnt.frame_count, """1""") 1257 self.assertExpectedInline(cnt.op_count, """11""") 1258 else: 1259 self.assertExpectedInline(cnt.frame_count, """1""") 1260 self.assertExpectedInline(cnt.op_count, """11""") 1261 1262 def test_module_in_skipfiles(self): 1263 model = nn.Linear(10, 10) 1264 cnt = torch._dynamo.testing.CompileCounter() 1265 torch.compile(model, backend=cnt, fullgraph=True)(torch.randn([5, 10])) 1266 self.assertEqual(cnt.frame_count, 1) 1267 self.assertEqual(cnt.op_count, 1) 1268 1269 def test_function_in_skipfiles(self): 1270 cnt = torch._dynamo.testing.CompileCounter() 1271 torch.compile(torch.sin, backend=cnt, fullgraph=True)(torch.randn([5, 10])) 1272 self.assertEqual(cnt.frame_count, 1) 1273 self.assertEqual(cnt.op_count, 1) 1274 1275 def test_slicing_dynamic_shape(self): 1276 def fn(y): 1277 x = torch.ones(8) 1278 idx = y[0] 1279 out = x[idx:] 1280 return (out + 3) * 5 1281 1282 counter = torch._dynamo.testing.CompileCounter() 1283 opt_fn = torch._dynamo.optimize(counter)(fn) 1284 out = opt_fn(torch.ones(10, dtype=torch.long)) 1285 # idx should be 1 -> slicing off [1:] of 8 elem tensor 1286 self.assertEqual(list(out.shape), [7]) 1287 1288 self.assertEqual(counter.op_count, 2) 1289 self.assertEqual(counter.frame_count, 1) 1290 1291 self.assertEqual(list(opt_fn(torch.tensor([4])).shape), [4]) 1292 1293 def test_slicing_dynamic_shape_setitem(self): 1294 def fn(input_lengths: torch.Tensor, new_ones_1): 1295 getitem_13 = input_lengths[3] 1296 new_ones_1[(3, slice(getitem_13, None, None))] = 0 1297 setitem_13 = new_ones_1 1298 return (setitem_13,) 1299 1300 x = torch.randn(10).to(dtype=torch.int64) 1301 y = torch.randn(10, 204) 1302 ref = fn(x, y) 1303 opt_fn = torch._dynamo.optimize("aot_eager")(fn) 1304 res = opt_fn(x, y) 1305 self.assertTrue(same(ref, res)) 1306 1307 @torch._dynamo.config.patch(error_on_recompile=True) 1308 @torch.fx.experimental._config.patch(use_duck_shape=False) 1309 def test_dynamic_shape_disable_duck_size(self): 1310 class TestModel(nn.Module): 1311 def __init__( 1312 self, 1313 ): 1314 super().__init__() 1315 1316 def forward(self, x: torch.Tensor, val: int) -> torch.Tensor: 1317 return x + val 1318 1319 main_model = TestModel().to(memory_format=torch.channels_last) 1320 opt_model = torch.compile(main_model, backend="eager", dynamic=True) 1321 1322 x1 = torch.rand(2, 5, 10, 10).to(memory_format=torch.channels_last) 1323 x2 = torch.rand(2, 5, 4, 8).to(memory_format=torch.channels_last) 1324 1325 o1_ref = main_model(x1, 4) 1326 o1 = opt_model(x1, 4) 1327 1328 o2_ref = main_model(x2, 20) 1329 o2 = opt_model(x2, 20) 1330 1331 def test_chunk_reformer_ff(self): 1332 input = torch.randn([1, 4096, 256]) 1333 model = ChunkReformerFeedForward() 1334 correct = model(input) 1335 cnt = torch._dynamo.testing.CompileCounter() 1336 opt_model = torch._dynamo.optimize_assert(cnt)(model) 1337 self.assertTrue(same(opt_model(input), correct)) 1338 1339 self.assertEqual(cnt.frame_count, 1) 1340 self.assertLessEqual(cnt.op_count, 10) 1341 1342 # see: https://github.com/pytorch/pytorch/issues/80067 1343 # NB: When you remove the expectedFailure, don't forget to 1344 # uncomment/adjust the assertEqual below 1345 @unittest.expectedFailure 1346 @torch._dynamo.config.patch( 1347 fake_tensor_propagation=True, capture_scalar_outputs=True 1348 ) 1349 def test_maml_item_capture(self): 1350 a = torch.randn(5, 1, 28, 28) 1351 b = torch.zeros(5, dtype=torch.int64) 1352 c = torch.randn(75, 1, 28, 28) 1353 d = torch.zeros(75, dtype=torch.int64) 1354 model = PartialMaml() 1355 correct = model(a, b, c, d) 1356 cnt = torch._dynamo.testing.CompileCounter() 1357 opt_model = torch._dynamo.optimize(cnt)(model) 1358 for _ in range(10): 1359 self.assertTrue(same(opt_model(a, b, c, d), correct)) 1360 1361 # if torch._dynamo.config.assume_static_by_default: 1362 # self.assertExpectedInline(cnt.frame_count, """2""") 1363 # else: 1364 # self.assertExpectedInline(cnt.frame_count, """3""") 1365 # TODO(jansel): figure out why op count depends on imports 1366 self.assertIn(cnt.op_count, (36, 35, 34, 29, 28, 27)) 1367 1368 # see: https://github.com/pytorch/pytorch/issues/80067 1369 @torch._dynamo.config.patch(capture_scalar_outputs=False) 1370 def test_maml_no_item_capture(self): 1371 a = torch.randn(5, 1, 28, 28) 1372 b = torch.zeros(5, dtype=torch.int64) 1373 c = torch.randn(75, 1, 28, 28) 1374 d = torch.zeros(75, dtype=torch.int64) 1375 model = PartialMaml() 1376 correct = model(a, b, c, d) 1377 cnt = torch._dynamo.testing.CompileCounter() 1378 opt_model = torch._dynamo.optimize(cnt)(model) 1379 for _ in range(10): 1380 self.assertTrue(same(opt_model(a, b, c, d), correct)) 1381 1382 if torch._dynamo.config.assume_static_by_default: 1383 self.assertExpectedInline(cnt.frame_count, """4""") 1384 else: 1385 self.assertExpectedInline(cnt.frame_count, """5""") 1386 1387 def test_hf_model_output(self): 1388 ex = ModelOutput(a=torch.randn(10), b=torch.randn(10), c=torch.randn(10)) 1389 1390 def fn1(x): 1391 return x["a"] + 1 1392 1393 def fn2(x): 1394 return x.a + 1 1395 1396 def fn3(x): 1397 return x.to_tuple()[0] + 1 1398 1399 def fn4(x): 1400 return x[0] + 1 1401 1402 cnt = torch._dynamo.testing.CompileCounter() 1403 for fn in (fn1, fn2, fn3, fn4): 1404 cnt.clear() 1405 opt_fn = torch._dynamo.optimize_assert(cnt)(fn) 1406 self.assertTrue(same(opt_fn(ex), ex.a + 1)) 1407 self.assertEqual(cnt.frame_count, 1) 1408 self.assertEqual(cnt.op_count, 1) 1409 1410 @disable_translation_validation_if_dynamic_shapes 1411 def test_create_rand_mask_from_inputs(self): 1412 args = [ 1413 torch.randn([1, 64, 64]), 1414 torch.randn([1, 64, 64]), 1415 torch.zeros([1, 12, 62, 3], dtype=torch.int64), 1416 12, 1417 3, 1418 1, 1419 4096, 1420 64, 1421 ] 1422 correct = create_rand_mask_from_inputs(*args) 1423 fn = create_rand_mask_from_inputs 1424 1425 cnt = torch._dynamo.testing.CompileCounter() 1426 opt_fn = torch._dynamo.optimize_assert(cnt)(fn) 1427 self.assertTrue(same(opt_fn(*args), correct)) 1428 if torch._dynamo.config.assume_static_by_default: 1429 self.assertExpectedInline(cnt.frame_count, """1""") 1430 self.assertExpectedInline(cnt.op_count, """8""") 1431 else: 1432 self.assertExpectedInline(cnt.frame_count, """1""") 1433 self.assertExpectedInline(cnt.op_count, """11""") 1434 1435 def test_rng_state(self): 1436 def fn(): 1437 state = torch.get_rng_state() 1438 before = torch.rand(1000) 1439 torch.set_rng_state(state) 1440 after = torch.rand(1000) 1441 return before, after 1442 1443 cnt = torch._dynamo.testing.CompileCounter() 1444 opt_fn = torch._dynamo.optimize(cnt)(fn) 1445 1446 before, after = opt_fn() 1447 self.assertTrue(same(before, after)) 1448 self.assertEqual(cnt.frame_count, 2) 1449 self.assertEqual(cnt.op_count, 2) # rand, rand 1450 try: 1451 graph, _ = torch._dynamo.export(fn)() 1452 # See https://github.com/pytorch/pytorch/pull/87490 1453 self.fail("unexpected export success") 1454 except torch._dynamo.exc.Unsupported: 1455 pass 1456 1457 def test_threading_local(self): 1458 import threading 1459 1460 foo = threading.local() 1461 foo.x = torch.rand(1) 1462 1463 def f(x): 1464 return torch.cat([x, foo.x]) 1465 1466 cnt = torch._dynamo.testing.CompileCounter() 1467 opt_f = torch._dynamo.optimize(cnt, nopython=True)(f) 1468 1469 inp = torch.ones(1) 1470 out = f(inp) 1471 opt_out = opt_f(inp) 1472 self.assertEqual(opt_out, out) 1473 self.assertEqual(cnt.frame_count, 1) 1474 1475 def test_seq_append_list(self): 1476 x = torch.randn(4, 10) 1477 model = SequentialAppendList( 1478 torch.nn.Linear(10, 10), 1479 torch.nn.ReLU(), 1480 torch.nn.Linear(10, 10), 1481 torch.nn.ReLU(), 1482 ) 1483 # this one is tricky because it mutates the list provided as an input 1484 l1 = [x] 1485 l2 = [x] 1486 correct, _ = model(x, l1) 1487 cnt = torch._dynamo.testing.CompileCounter() 1488 opt_model = torch._dynamo.optimize_assert(cnt)(model) 1489 result, l3 = opt_model(x, l2) 1490 self.assertTrue(same(result, correct)) 1491 self.assertTrue(same(l1, l2)) 1492 self.assertIs(l2, l3) 1493 self.assertEqual(cnt.frame_count, 1) 1494 self.assertEqual(cnt.op_count, 5) 1495 1496 def test_batch_norm_act(self): 1497 a = torch.randn(5, 1, 28, 28) 1498 model = BatchNormAct2d(1).eval() 1499 correct = model(a) 1500 cnt = torch._dynamo.testing.CompileCounter() 1501 if not torch._dynamo.config.specialize_int: 1502 # _local_scalar_dense causes graph break w 0-dim tensor 1503 opt_model = torch._dynamo.optimize(cnt)(model) 1504 self.assertTrue(same(opt_model(a), correct)) 1505 return 1506 1507 opt_model = torch._dynamo.optimize_assert(cnt)(model) 1508 self.assertTrue(same(opt_model(a), correct)) 1509 self.assertEqual(cnt.frame_count, 1) 1510 self.assertEqual(cnt.op_count, 2) 1511 1512 def test_get_parameter_dtype(self): 1513 model = SequentialAppendList( 1514 torch.nn.Linear(10, 10), 1515 torch.nn.ReLU(), 1516 ) 1517 1518 def fn(model, x): 1519 return x + torch.randn(10, dtype=get_parameter_dtype(model)) 1520 1521 cnt = torch._dynamo.testing.CompileCounter() 1522 opt_fn = torch._dynamo.optimize_assert(cnt)(fn) 1523 self.assertEqual(opt_fn(model, torch.randn(10)).dtype, torch.float32) 1524 self.assertEqual(cnt.frame_count, 1) 1525 self.assertEqual(cnt.op_count, 2) 1526 1527 def test_nn_parameter(self): 1528 def test_fn(): 1529 a = torch.nn.Parameter(torch.randn(5, 5)) 1530 # Checks that TensorVariable stores the type information correctly 1531 self.assertTrue(isinstance(a, torch.nn.Parameter)) 1532 return a 1533 1534 cnt = torch._dynamo.testing.CompileCounter() 1535 opt_test_fn = torch._dynamo.optimize(cnt)(test_fn) 1536 out = opt_test_fn() 1537 self.assertTrue(isinstance(out, torch.nn.Parameter)) 1538 1539 def test_Size(self): 1540 def test_fn(): 1541 a = torch.randn(4) 1542 x = torch.Size([1, 2, 3]) 1543 # Checks that SizeVariable return torch.Size object 1544 assert isinstance(x, torch.Size) 1545 # Causes graph breaks and checks reconstruction of SizeVariable 1546 # object 1547 self.assertIsInstance(x, torch.Size) 1548 return a 1549 1550 cnt = torch._dynamo.testing.CompileCounter() 1551 opt_test_fn = torch._dynamo.optimize(cnt)(test_fn) 1552 opt_test_fn() 1553 1554 # See https://github.com/pytorch/pytorch/issues/100067 1555 def test_copy_weird_strides(self): 1556 # This test requires inductor's copy() decomp to preserve strides properly. 1557 def test_fn(a): 1558 b = torch.zeros(48, 4, 256, 513) 1559 b[:, 0, 1:256, 1:256] = a 1560 c = b.view(4, 12, 1024, 513) 1561 d = c.transpose(2, 1) 1562 d.add_(1) 1563 return d 1564 1565 sh, st, dt, dev, rg = ( 1566 (48, 255, 255), 1567 (787968, 513, 1), 1568 torch.float16, 1569 "cpu", 1570 True, 1571 ) 1572 a = rand_strided(sh, st, dt, dev).requires_grad_(rg) 1573 compiled_f = torch.compile(test_fn, backend="aot_eager_decomp_partition") 1574 out1 = test_fn(a) 1575 out2 = compiled_f(a) 1576 self.assertEqual(out1, out2) 1577 1578 def test_indexing_with_list(self): 1579 def test_fn(): 1580 def run_test(tensor, *idx): 1581 npt = tensor.numpy() 1582 assert npt[idx].shape == tensor[idx].shape 1583 1584 x = torch.arange(0, 10) 1585 cases = [ 1586 [None, None], 1587 [1, None], 1588 ] 1589 1590 for case in cases: 1591 run_test(x, *case) 1592 1593 return torch.randn(4) 1594 1595 cnt = torch._dynamo.testing.CompileCounter() 1596 opt_test_fn = torch._dynamo.optimize(cnt)(test_fn) 1597 opt_test_fn() 1598 1599 def test_reformer_min_chunk_len(self): 1600 def fn(cfg): 1601 t = torch.empty(10) 1602 t.fill_(_get_min_chunk_len(cfg)) 1603 return t[0] 1604 1605 cfg = DummyConfig() 1606 cnt = torch._dynamo.testing.CompileCounter() 1607 opt_fn = torch._dynamo.optimize_assert(cnt)(fn) 1608 self.assertEqual(opt_fn(cfg), 64) 1609 # With unspec int, maximum computation is preserved 1610 self.assertExpectedInline(cnt.frame_count, """1""") 1611 self.assertExpectedInline(cnt.op_count, """3""") 1612 1613 def test_reformer_sorting(self): 1614 x = torch.zeros([1, 12, 4096], dtype=torch.int64) 1615 correct = _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(x) 1616 fn = _get_sorted_bucket_idx_and_undo_sorted_bucket_idx 1617 1618 cnt = torch._dynamo.testing.CompileCounter() 1619 opt_fn = torch._dynamo.optimize_assert(cnt)(fn) 1620 self.assertTrue(same(opt_fn(x), correct)) 1621 if torch._dynamo.config.assume_static_by_default: 1622 self.assertExpectedInline(cnt.frame_count, """1""") 1623 self.assertExpectedInline(cnt.op_count, """14""") 1624 else: 1625 self.assertExpectedInline(cnt.frame_count, """1""") 1626 self.assertExpectedInline(cnt.op_count, """16""") 1627 1628 def test_recursive_map(self): 1629 # https://github.com/pytorch/torchdynamo/issues/132 1630 def _recursive_map(struct, batch_dim=0): 1631 for k, v in struct.items(): 1632 if v is not None: 1633 if isinstance(v, dict): 1634 _recursive_map(v) 1635 else: 1636 struct[k] = v 1637 1638 def toy_example(a, b, v): 1639 x = a / (torch.abs(a) + 1) 1640 if v is not None: 1641 _recursive_map(v) 1642 return x * b 1643 1644 cnt = torch._dynamo.testing.CompileCounter() 1645 opt_toy_example = torch._dynamo.optimize(cnt)(toy_example) 1646 opt_toy_example( 1647 torch.randn(10), 1648 torch.randn(10), 1649 {"layer0": {"memory_keys": torch.randn(10)}}, 1650 ) 1651 self.assertEqual(cnt.frame_count, 1) 1652 self.assertEqual(cnt.op_count, 4) 1653 1654 def test_issue114171(self): 1655 device = torch.device("cpu") 1656 1657 def fcnn(in_dim, out_dim, hidden_dim, activation=torch.nn.GELU): 1658 layers = [ 1659 torch.nn.Linear(in_dim, hidden_dim, device=device), 1660 activation(), 1661 torch.nn.Linear(hidden_dim, out_dim, device=device), 1662 ] 1663 return torch.nn.Sequential(*layers) 1664 1665 class testmodel(torch.nn.Module): 1666 def __init__(self) -> None: 1667 super().__init__() 1668 self.interaction_networks = torch.nn.ModuleList( 1669 [fcnn(262, 1174, 400) for _ in range(4)] 1670 ) 1671 1672 def interact(self, x, cycle): 1673 return self.interaction_networks[cycle](x) 1674 1675 model = testmodel() 1676 forward_aot = torch.compile( 1677 model.interact, fullgraph=True, dynamic=True, backend="eager" 1678 ) 1679 1680 x = torch.rand([111, 262], device=device) 1681 y2 = forward_aot(x, 2) # previously failed 1682 1683 def test_issue175(self): 1684 n_heads = 2 1685 d_model = 64 1686 model = TransformerEncoderLayer(d_model, n_heads) 1687 inp = torch.randn(1, d_model) 1688 cnt = torch._dynamo.testing.CompileCounter() 1689 opt_model = torch._dynamo.optimize(cnt, nopython=True)(model) 1690 opt_model(inp) 1691 opt_model(inp) 1692 self.assertEqual(cnt.frame_count, 1) 1693 1694 self.assertEqual( 1695 15 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count 1696 ) 1697 1698 def test_exec_import(self): 1699 def fn1(): 1700 exec("import math") 1701 1702 def fn2(): 1703 try: 1704 math.sqrt(4) 1705 return False 1706 except NameError: 1707 return True 1708 1709 def fn3(): 1710 fn1() 1711 return fn2() 1712 1713 self.assertTrue(fn3()) 1714 opt_fn3 = torch._dynamo.optimize("eager")(fn3) 1715 self.assertTrue(opt_fn3()) 1716 1717 def test_exec_wildcard_import(self): 1718 # Test that globals are not carried over from frame to frame 1719 def fn1(): 1720 exec("from torch import *") 1721 1722 def fn2(): 1723 x = torch.zeros(4) 1724 for i in range(5): 1725 x = x + i 1726 return x 1727 1728 def fn3(): 1729 fn1() 1730 return fn2() 1731 1732 ref = fn3() 1733 opt_fn3 = torch._dynamo.optimize("eager")(fn3) 1734 res = opt_fn3() 1735 self.assertTrue(same(ref, res)) 1736 1737 def test_with_on_graph_break_inst(self): 1738 def reversible(x): 1739 print("Hello world") # Cause graph break so inline fails 1740 return torch.sin(torch.cos(x)) 1741 1742 def fn(x): 1743 with torch.enable_grad(): 1744 a = torch.sin(x) 1745 b = reversible(a) 1746 c = torch.sigmoid(b) 1747 c.sum().backward() 1748 return x.grad 1749 1750 x = torch.randn(3, requires_grad=True) 1751 x.grad = None 1752 with torch.no_grad(): 1753 ref = fn(x) 1754 1755 x.grad = None 1756 opt_fn = torch._dynamo.optimize("eager")(fn) 1757 with torch.no_grad(): 1758 res = opt_fn(x) 1759 self.assertTrue(same(ref, res)) 1760 1761 def test_with_on_graph_break_nested(self): 1762 def reversible(x): 1763 torch._dynamo.graph_break() # Cause graph break so inline fails 1764 return torch.sin(torch.cos(x)) 1765 1766 def fn(x): 1767 # nested context manager failed previously 1768 with torch.no_grad(): 1769 with torch.enable_grad(): 1770 a = torch.sin(x) 1771 b = reversible(a) 1772 c = torch.sigmoid(b) 1773 c.sum().backward() 1774 return x.grad 1775 1776 x = torch.randn(3, requires_grad=True) 1777 x.grad = None 1778 with torch.no_grad(): 1779 ref = fn(x) 1780 1781 x.grad = None 1782 opt_fn = torch._dynamo.optimize("eager")(fn) 1783 with torch.no_grad(): 1784 res = opt_fn(x) 1785 self.assertTrue(same(ref, res)) 1786 1787 # https://github.com/pytorch/torchdynamo/issues/1446 1788 def test_grad_mode_carrying_correct_state_after_graph_break(self): 1789 def fn(x): 1790 with torch.no_grad(): 1791 y = x * 3 1792 print("Break") 1793 z = x + 2 1794 return y, z 1795 1796 x = torch.randn(3, requires_grad=True) 1797 opt_fn = torch._dynamo.optimize("eager")(fn) 1798 y, z = opt_fn(x) 1799 self.assertFalse(y.requires_grad) 1800 self.assertFalse(z.requires_grad) 1801 1802 def test_abc_setattr(self): 1803 # tests that we correctly bail out of __setattr__ calls 1804 1805 # TODO: does not ensure ABC classes are correctly inferred as ClassVariables 1806 # (doesn't test the fix for 'super()') 1807 1808 class BaseModule(torch.nn.Module, ABC): 1809 def blah(self, x): 1810 return x + 1 1811 1812 class Derived(BaseModule): 1813 def __setattr__(self, name, value) -> None: 1814 super().__setattr__(name, value) 1815 1816 def forward(self, x): 1817 # expect a graph break on __setattr__ 1818 self.foo = 0 1819 return self.blah(x) 1820 1821 def blah(self, x): 1822 return super().blah(x) 1823 1824 x = torch.randn(3, requires_grad=True) 1825 mod = Derived() 1826 opt_mod = torch._dynamo.optimize("eager")(mod) 1827 opt_mod(x) 1828 1829 # Not sure what this test is testing. It was earlier graph breaking on 1830 # __dict__, so the counter >= 2. With __dict__ support, there is no 1831 # graph break. 1832 self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) 1833 self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["total"], 1) 1834 1835 @torch._dynamo.config.patch("suppress_errors", True) 1836 def test_guard_fail_tensor_bool(self): 1837 @torch._dynamo.disable(recursive=False) 1838 def fn(): 1839 condition_shape = (5, 5) 1840 dtypes = (torch.bool,) 1841 shapes = ( 1842 (), 1843 (5,), 1844 (1, 5), 1845 ) 1846 1847 tensors = [ 1848 torch.empty(shape, dtype=dtype).fill_(17) 1849 for shape, dtype in itertools.product(shapes, dtypes) 1850 ] 1851 1852 x_vals = (5.0, *tensors) 1853 y_vals = (6.0, *tensors) 1854 1855 @torch._dynamo.disable 1856 def get_expected(condition, x, y): 1857 x_np = x.cpu().numpy() if isinstance(x, torch.Tensor) else x 1858 y_np = y.cpu().numpy() if isinstance(y, torch.Tensor) else y 1859 return torch.from_numpy( 1860 np.where(condition.cpu().numpy(), x_np, y_np) 1861 ).to(common_dtype) 1862 1863 for x, y in zip(x_vals, y_vals): 1864 condition = torch.empty(*condition_shape, dtype=torch.bool).bernoulli_() 1865 common_dtype = torch.result_type(x, y) 1866 1867 def check_equal(condition, x, y): 1868 # NumPy aggressively promotes to double, hence cast to output to correct dtype 1869 expected = get_expected(condition, x, y) 1870 result = torch.where(condition, x, y) 1871 assert torch.allclose(expected, result) 1872 1873 check_equal(condition, x, y) 1874 check_equal(condition, y, x) 1875 1876 fn() 1877 opt_fn = torch._dynamo.optimize("eager")(fn) 1878 opt_fn() 1879 1880 def test_guard_fail_nested_tuple(self): 1881 def fn(args): 1882 return torch.ones(()), args[0] * 2 1883 1884 # This adds a tensor check on args[1][0] and args[1][1] 1885 args1 = (torch.ones(1), (torch.ones(1), torch.ones(1))) 1886 args2 = (torch.ones(1), torch.ones(1)) 1887 opt_fn = torch._dynamo.optimize("eager")(fn) 1888 ref = opt_fn(args1) 1889 res = opt_fn(args2) 1890 1891 self.assertTrue(same(ref, res)) 1892 1893 def test_nullcontext1(self): 1894 @torch.compile(fullgraph=True, backend="eager") 1895 def fn(x, ctx): 1896 x = x.sin() 1897 with ctx: 1898 x = x.cos() 1899 x = x.sin() 1900 return x 1901 1902 y = torch.randn(10) 1903 self.assertTrue(same(fn(y, contextlib.nullcontext()), y.sin().cos().sin())) 1904 1905 def test_nullcontext2(self): 1906 @torch.compile(fullgraph=True, backend="eager") 1907 def fn(x, ctx): 1908 x = x.sin() 1909 with ctx(): 1910 x = x.cos() 1911 x = x.sin() 1912 return x 1913 1914 y = torch.randn(10) 1915 self.assertTrue(same(fn(y, contextlib.nullcontext), y.sin().cos().sin())) 1916 1917 def test_no_grad_inline(self): 1918 @torch.no_grad() 1919 def a(x): 1920 return x.sin() 1921 1922 @torch.compile(backend="eager", fullgraph=True) 1923 def b(x): 1924 return a(x).cos() 1925 1926 y = torch.randn(10) 1927 self.assertTrue(same(b(y), y.sin().cos())) 1928 1929 @skipIfWindows( 1930 msg="torch._dynamo.exc.TorchRuntimeError: Failed running call_function <class 'torch.LongTensor'>(*(FakeTensor(..., size=(10,), dtype=torch.int32),), **{}):" # noqa: B950 1931 ) 1932 def test_longtensor_list(self): 1933 for partition in [0, 5, 10]: 1934 1935 @torch._dynamo.disable 1936 def rand_gen(): 1937 rand_vals = [random.randint(5, 10) for _ in range(10)] 1938 # List of tensors mixed with np.arrays 1939 return list(np.array(rand_vals[:partition])) + [ 1940 torch.tensor(val) for val in rand_vals[partition:] 1941 ] 1942 1943 def fn(x): 1944 random_list = rand_gen() 1945 z = torch.LongTensor(random_list) 1946 return x * z 1947 1948 x = torch.ones(10) * 2 1949 1950 random.seed(0) 1951 ref0 = fn(x) 1952 ref1 = fn(x) 1953 1954 random.seed(0) 1955 opt_fn = torch._dynamo.optimize("eager")(fn) 1956 res0 = opt_fn(x) 1957 res1 = opt_fn(x) 1958 1959 self.assertTrue(same(ref0, res0)) 1960 self.assertTrue(same(ref1, res1)) 1961 1962 def test_primtorch(self): 1963 @torch._dynamo.optimize("eager") 1964 def fn(x): 1965 torch._refs.abs(x) 1966 1967 fn(torch.randn(3)) 1968 1969 @unittest.expectedFailure 1970 # inline_call [('inline in skipfiles: bind ...python3.10/inspect.py', 1)] 1971 def test_primtorch_no_graph_break(self): 1972 @torch._dynamo.optimize("eager", nopython=True) 1973 def fn(x): 1974 torch._refs.abs(x) 1975 1976 fn(torch.randn(3)) 1977 1978 def test_torch_tensor_ops_no_graph_break(self): 1979 @torch._dynamo.optimize("eager", nopython=True) 1980 def fn(x): 1981 torch.Tensor.abs_(x) 1982 1983 fn(torch.randn(3)) 1984 1985 @unittest.skipIf( 1986 not isinstance(torch.ops.aten.abs, torch._ops.OpOverloadPacket), 1987 "old pt doesn't work", 1988 ) 1989 def test_torch_ops_aten(self): 1990 # Picked an op that doesn't show up in the default list 1991 @torch._dynamo.optimize("eager", nopython=True) 1992 def fn(x): 1993 return torch.ops.aten.absolute(x) 1994 1995 fn(torch.randn(3)) 1996 1997 def test_hf_gelu_inline(self): 1998 class GELUActivation(nn.Module): 1999 def __init__(self) -> None: 2000 super().__init__() 2001 self.act = nn.functional.gelu 2002 2003 def forward(self, input): 2004 return self.act(input) 2005 2006 @torch._dynamo.optimize("eager", nopython=True) 2007 def fn(x): 2008 return GELUActivation()(x) 2009 2010 y = torch.randn(10) 2011 self.assertTrue(same(fn(y), nn.functional.gelu(y))) 2012 2013 @torch._dynamo.optimize("eager", nopython=True) 2014 def fn_returns(x): 2015 return GELUActivation(), x + 1 2016 2017 act, _ = fn_returns(y) 2018 self.assertIsInstance(act, GELUActivation) 2019 self.assertIs(act.act, nn.functional.gelu) 2020 self.assertTrue(hasattr(act, "_buffers")) # check that __init__ got called 2021 2022 def test_dropout_inline(self): 2023 @torch._dynamo.optimize("eager") 2024 def fn(x): 2025 return torch.nn.Dropout(0.1)(x) 2026 2027 y = torch.randn(10) 2028 torch.manual_seed(1337) 2029 ref = nn.functional.dropout(y, 0.1) 2030 torch.manual_seed(1337) 2031 res = fn(y) 2032 self.assertTrue(same(ref, res)) 2033 2034 def test_setitem_boolean_mask_diff(self): 2035 def fn(x, b, y): 2036 x = x.clone() 2037 x[b] = y 2038 return x 2039 2040 opt_fn = torch._dynamo.optimize("aot_eager")(fn) 2041 x = torch.randn(4, requires_grad=True) 2042 b = torch.tensor([True, False, True, False]) 2043 y = torch.randn(2, requires_grad=True) 2044 opt_fn(x, b, y) 2045 2046 def test_setitem_tuple_boolean_mask_diff(self): 2047 def fn(x, b, y): 2048 x = x.clone() 2049 x[:, b] = y 2050 return x 2051 2052 opt_fn = torch._dynamo.optimize("aot_eager")(fn) 2053 x = torch.randn(8, 4, requires_grad=True) 2054 b = torch.tensor([True, False, True, False]) 2055 y = torch.randn(2, requires_grad=True) 2056 opt_fn(x, b, y) 2057 2058 def test_torch_tensor_ops(self): 2059 def fn(x): 2060 return torch.Tensor.abs_(x) 2061 2062 x = torch.randn(3) 2063 opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 2064 y = fn(x) 2065 y_ = opt_fn(x) 2066 self.assertTrue(same(y, y_)) 2067 2068 def test_guard_ordering_shape_fail(self): 2069 # If a function which takes a tensor has an inner function which 2070 # is compiled and generates a guard on its shape, 2071 # they are evaluated in the wrong order. So if on a subsequent call 2072 # an int is passed instead of a tensor, guard evaluation will crash 2073 # with a "no attribute: shape" error 2074 m = MockModule() 2075 opt_m = torch._dynamo.optimize("eager")(m) 2076 opt_m.fn(torch.ones((5, 5))) 2077 opt_m.fn(-3) 2078 2079 def test_tensor_isinstance_tuple(self): 2080 @torch._dynamo.optimize("eager") 2081 def fn(): 2082 t = torch.ones(5, 5) 2083 if not isinstance(t, (int, torch.Tensor)): 2084 msg = str.format( 2085 "{0} is not an instance of {1}", 2086 type(t), 2087 (int, torch.Tensor), 2088 ) 2089 raise ValueError(msg) 2090 return True 2091 2092 fn() 2093 2094 def test_isinstance_dtype(self): 2095 @torch._dynamo.optimize("eager", nopython=True) 2096 def fn(x): 2097 isinstance(torch.bfloat16, torch.dtype) 2098 return x 2099 2100 fn(torch.randn(3)) 2101 2102 def test_isinstance_storage(self): 2103 @torch._dynamo.optimize("eager") 2104 def fn(x): 2105 f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40]) 2106 bools = torch.BoolStorage.from_buffer(f, "big") 2107 assert isinstance(bools, torch.BoolStorage) 2108 return x 2109 2110 fn(torch.randn(3)) 2111 2112 def test_issue111522(self): 2113 @torch.compile(backend="eager", fullgraph=True) 2114 def f(x, y): 2115 return x + y.a 2116 2117 class A: 2118 a = 2 2119 2120 self.assertEqual(f(torch.zeros(2), A()), torch.full([2], 2.0)) 2121 2122 del A.a 2123 2124 # graph break on missing attr 2125 with self.assertRaises(torch._dynamo.exc.Unsupported): 2126 f(torch.zeros(2), A()) 2127 2128 def test_dict_list_values(self): 2129 def inner_fn(args): 2130 return [x[1].shape for x in args] 2131 2132 @torch._dynamo.optimize("eager") 2133 def fn(tensors): 2134 return inner_fn(zip(itertools.count(), tensors["args"])) 2135 2136 fn({"args": [torch.ones(5, 5), torch.ones(5, 6), torch.ones(5, 7)]}) 2137 fn({"args": [torch.ones(5, 5)]}) 2138 2139 def test_dict_iter(self): 2140 class MyMod(torch.nn.Module): 2141 def forward(self, x): 2142 z = {"my": 1, "const": 2, "dict": 3, "variable": 4} 2143 tot = 0 2144 for key in z: 2145 tot += z[key] 2146 2147 return tot 2148 2149 x = torch.tensor([0]) 2150 model = MyMod() 2151 opt_model = torch._dynamo.optimize("eager", nopython=True)(model) 2152 y = opt_model(x) 2153 2154 self.assertEqual(y, 10) 2155 2156 def test_sort_out(self): 2157 dtype = torch.float32 2158 device = "cpu" 2159 2160 def fn(): 2161 tensor = torch.randn((3, 5), dtype=dtype, device=device)[:, 0] 2162 values1 = torch.tensor(0, dtype=dtype, device=device) 2163 indices1 = torch.tensor(0, dtype=torch.long, device=device) 2164 torch.sort(tensor, out=(values1, indices1)) 2165 self.assertEqual(values1.stride(), (1,)) 2166 self.assertEqual(indices1.stride(), (1,)) 2167 2168 fn() 2169 opt_fn = torch._dynamo.optimize("eager")(fn) 2170 opt_fn() 2171 2172 def test_sort_out2(self): 2173 class MyModule(torch.nn.Module): 2174 def __init__(self) -> None: 2175 super().__init__() 2176 self.sorted = torch.nn.Buffer(torch.ones(4, 4)) 2177 self.indices = torch.nn.Buffer(torch.ones(4, 4, dtype=torch.long)) 2178 2179 def forward(self, x): 2180 torch.sort(x, out=(self.sorted, self.indices)) 2181 return (x + 1, self.sorted, self.indices) 2182 2183 x = torch.randn(4, 4) 2184 m = MyModule() 2185 ref = m(x) 2186 opt_m = torch._dynamo.optimize("eager")(m) 2187 res = opt_m(x) 2188 self.assertTrue(same(ref, res)) 2189 2190 def test_sigmoid_out(self): 2191 dtype = torch.float32 2192 device = "cpu" 2193 2194 def fn(): 2195 inp = torch.randn((3, 5), dtype=dtype, device=device) 2196 out1 = torch.tensor(0, dtype=dtype, device=device) 2197 torch.sigmoid(inp, out=out1) 2198 self.assertEqual(out1.numel(), 15) 2199 2200 fn() 2201 opt_fn = torch._dynamo.optimize("eager")(fn) 2202 opt_fn() 2203 2204 def test_sigmoid_out2(self): 2205 class MyModule(torch.nn.Module): 2206 def __init__(self) -> None: 2207 super().__init__() 2208 self.base = torch.nn.Buffer(torch.ones(4, 4)) 2209 2210 def forward(self, x): 2211 torch.sigmoid(x, out=self.base) 2212 return x + self.base 2213 2214 x = torch.randn(4, 4) 2215 m = MyModule() 2216 ref = m(x) 2217 opt_m = torch._dynamo.optimize("eager")(m) 2218 res = opt_m(x) 2219 self.assertTrue(same(ref, res)) 2220 2221 def test_slice_into_list_mutable(self): 2222 class Mod(torch.nn.Module): 2223 def forward(self, listy): 2224 x = listy[3:5] 2225 for i in range(10): 2226 z = torch.abs(torch.randn(10)) + 1 2227 x[0] = z 2228 return x 2229 2230 m = Mod() 2231 listy = [torch.randn(10)] * 10 2232 2233 cnt = torch._dynamo.testing.CompileCounter() 2234 opt_m = torch._dynamo.optimize(cnt, nopython=True)(m) 2235 opt_m.forward(listy) 2236 2237 self.assertEqual(cnt.frame_count, 1) 2238 2239 @torch._dynamo.config.patch(capture_scalar_outputs=True) 2240 def test_issue111918(self): 2241 cnt = CompileCounter() 2242 2243 @torch.compile(backend=cnt, dynamic=True) 2244 def fn(x): 2245 x = x + 1 2246 y = x.item() 2247 if y > 2: 2248 return x * 2 2249 else: 2250 return x * 3 2251 2252 x = torch.tensor([3.0]) 2253 fn(x) 2254 self.assertEqual(cnt.frame_count, 2) 2255 self.assertEqual(cnt.op_count, 4) 2256 2257 torch._dynamo.reset() 2258 fn = torch.compile(fn, fullgraph=True, backend="eager") 2259 with self.assertRaises(torch._dynamo.exc.UserError): 2260 fn(x) 2261 2262 def test_vdd_duplicate_error(self): 2263 def fn(a, dt): 2264 keys = list(dt._jt_dict.keys()) 2265 p = torch.cos(dt._jt_dict[keys[0]]._value) 2266 q = torch.sin(a) 2267 r = torch.sigmoid(dt._jt_dict[keys[0]]._value) 2268 return p + q + r 2269 2270 class Value: 2271 def __init__(self) -> None: 2272 self._value = torch.randn(4) 2273 2274 class Sample: 2275 def __init__(self) -> None: 2276 self._jt_dict = {} 2277 self._jt_dict["POSITION_ID"] = Value() 2278 2279 a = torch.randn(4) 2280 sample = Sample() 2281 2282 ref = fn(a, sample) 2283 2284 optimized_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 2285 res = optimized_fn(a, sample) 2286 2287 self.assertTrue(same(ref, res)) 2288 2289 def test_specialized_stride(self): 2290 def f(): 2291 e = torch.empty(4) 2292 x = e[::2] 2293 return x.stride() 2294 2295 self.assertEqual(f(), torch._dynamo.optimize("eager")(f)()) 2296 2297 def test_out_none(self): 2298 # https://github.com/pytorch/pytorch/issues/92814 2299 def fn(input): 2300 return torch.nn.functional.normalize(input, dim=0, out=None) 2301 2302 x = torch.rand([1]) 2303 self.assertEqual(fn(x), torch._dynamo.optimize("eager")(fn)(x)) 2304 2305 def test_multi_import(self): 2306 if not has_detectron2(): 2307 raise unittest.SkipTest("requires detectron2") 2308 2309 @torch._dynamo.optimize("eager", nopython=True) 2310 def to_bitmasks(boxes): 2311 from detectron2.layers.mask_ops import ( 2312 _paste_masks_tensor_shape, 2313 paste_masks_in_image, 2314 ) 2315 2316 if ( 2317 paste_masks_in_image is not None 2318 and _paste_masks_tensor_shape is not None 2319 ): 2320 return boxes + 1 2321 2322 self.assertTrue((to_bitmasks(torch.zeros(10)) == torch.ones(10)).all()) 2323 2324 def test_multi_dot_import(self): 2325 def fn1(x): 2326 return torch.sin(x) 2327 2328 def fn(x): 2329 import torch.fx 2330 2331 _ = torch.fx.symbolic_trace(fn1) 2332 return x * 2 2333 2334 x = torch.randn(10) 2335 fn(x) 2336 cnt = torch._dynamo.testing.CompileCounter() 2337 opt_fn = torch._dynamo.optimize(cnt)(fn) 2338 opt_fn(x) 2339 self.assertEqual(cnt.frame_count, 1) 2340 2341 def test_relative_import(self): 2342 try: 2343 from . import utils as _ # noqa: F401 2344 2345 def fn(x): 2346 from .utils import tensor_for_import_testing 2347 2348 return x * 2 * tensor_for_import_testing 2349 2350 except ImportError: 2351 2352 def fn(x): 2353 from utils import tensor_for_import_testing 2354 2355 return x * 2 * tensor_for_import_testing 2356 2357 x = torch.randn(10) 2358 fn(x) 2359 cnt = torch._dynamo.testing.CompileCounter() 2360 opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn) 2361 opt_fn(x) 2362 self.assertEqual(cnt.frame_count, 1) 2363 2364 def test_relative_import_no_modulename(self): 2365 try: 2366 from . import utils as _ # noqa: F401 2367 2368 def fn(x): 2369 from . import utils 2370 2371 return x * 2 * utils.tensor_for_import_testing 2372 2373 except ImportError: 2374 2375 def fn(x): 2376 import utils 2377 2378 return x * 2 * utils.tensor_for_import_testing 2379 2380 x = torch.randn(10) 2381 fn(x) 2382 cnt = torch._dynamo.testing.CompileCounter() 2383 opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn) 2384 opt_fn(x) 2385 self.assertEqual(cnt.frame_count, 1) 2386 2387 def test_bigbird_unsqueeze_inplace(self): 2388 def fn(reshape_2): 2389 view_2 = reshape_2.clone() 2390 view_2.unsqueeze_(2) 2391 cat_11 = torch.cat([view_2], dim=2) 2392 view_13 = cat_11.view((2, 12, 64, -1)) 2393 return (view_13,) 2394 2395 x = torch.randn(2, 12, 64, 64, requires_grad=True) 2396 ref = fn(x) 2397 opt_fn = torch._dynamo.optimize("aot_eager")(fn) 2398 res = opt_fn(x) 2399 self.assertTrue(same(ref, res)) 2400 2401 def test_issue1466_size_aot_autograd(self): 2402 def fn(x): 2403 # do a tensor op and a size compute 2404 y = x * 2 2405 x_size = x.size() 2406 # trigger a graph break 2407 print("arf") 2408 # use the tensor op and size compute 2409 z = y.view(x_size) + 1 2410 return z 2411 2412 x = torch.randn(2, 3, requires_grad=True) 2413 ref = fn(x) 2414 opt_fn = torch._dynamo.optimize("aot_eager")(fn) 2415 res = opt_fn(x) 2416 self.assertTrue(same(ref, res)) 2417 2418 def test_ellipsis(self): 2419 class Repro(torch.nn.Module): 2420 def __init__(self) -> None: 2421 super().__init__() 2422 self.lnorm = torch.nn.LayerNorm( 2423 (256,), eps=1e-06, elementwise_affine=True 2424 ) 2425 self.linear = torch.nn.Linear( 2426 in_features=256, out_features=256, bias=True 2427 ) 2428 2429 def forward(self, cat_10): 2430 lnorm = self.lnorm(cat_10) 2431 getitem_64 = lnorm[ 2432 (slice(None, None, None), slice(0, 1, None), Ellipsis) 2433 ] 2434 linear = self.linear(getitem_64) 2435 return (linear,) 2436 2437 args = [torch.randn(2, 197, 256)] 2438 2439 mod = Repro() 2440 opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod) 2441 2442 self.assertTrue(same(mod(*args), opt_mod(*args))) 2443 2444 def test_reinplacing(self): 2445 class MockModule(torch.nn.Module): 2446 def __init__(self) -> None: 2447 super().__init__() 2448 self.self_layoutlm_embeddings_x_position_embeddings = ( 2449 torch.nn.Embedding(1024, 768) 2450 ) 2451 self.self_layoutlm_embeddings_y_position_embeddings = ( 2452 torch.nn.Embedding(1024, 768) 2453 ) 2454 2455 def forward(self, getitem_1, getitem_2, add): 2456 self_layoutlm_embeddings_x_position_embeddings = ( 2457 self.self_layoutlm_embeddings_x_position_embeddings(getitem_1) 2458 ) 2459 self_layoutlm_embeddings_y_position_embeddings = ( 2460 self.self_layoutlm_embeddings_y_position_embeddings(getitem_2) 2461 ) 2462 add_1 = add + self_layoutlm_embeddings_x_position_embeddings 2463 add_2 = add_1 + self_layoutlm_embeddings_y_position_embeddings 2464 return (add_2,) 2465 2466 mod = MockModule() 2467 opt_mod = torch._dynamo.optimize("aot_eager_decomp_partition")(mod) 2468 2469 args = [ 2470 ((2, 512), (2048, 4), torch.int64, "cpu", False), 2471 ((2, 512), (2048, 4), torch.int64, "cpu", False), 2472 ((2, 512, 768), (393216, 768, 1), torch.float32, "cpu", True), 2473 ] 2474 args = [ 2475 rand_strided(sh, st, dt, dev).requires_grad_(rg) 2476 for (sh, st, dt, dev, rg) in args 2477 ] 2478 self.assertTrue(same_two_models(mod, opt_mod, args)) 2479 2480 def test_optimized_deepcopy(self): 2481 # See https://github.com/pytorch/pytorch/pull/88629 2482 class Foo(torch.nn.Module): 2483 def __init__(self) -> None: 2484 super().__init__() 2485 self.fc = torch.nn.Linear(in_features=2, out_features=3, bias=True) 2486 2487 def forward(self, x): 2488 return self.fc(x) 2489 2490 mod = Foo() 2491 opt_mod = torch._dynamo.optimize("eager")(mod) 2492 args = [torch.randn(1, 2)] 2493 self.assertTrue(same_two_models(mod, opt_mod, args)) 2494 2495 def test_class_member(self): 2496 class Foo(torch.nn.Module): 2497 a = 4 2498 b = torch.ones(3, 4) 2499 2500 def __init__(self) -> None: 2501 super().__init__() 2502 self.c = 4 2503 2504 def forward(self, x): 2505 return x.cos() + self.a + self.b + self.c 2506 2507 mod = Foo() 2508 opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod) 2509 args = (torch.randn(3, 4),) 2510 self.assertTrue(same(mod(*args), opt_mod(*args))) 2511 2512 def test_named_buffers(self): 2513 class Foo(torch.nn.Module): 2514 def __init__(self) -> None: 2515 super().__init__() 2516 self.x = torch.nn.Buffer(torch.ones(3)) 2517 self.y = torch.nn.Buffer(torch.ones(3)) 2518 2519 def forward(self, inp): 2520 res = 0 2521 for name, buffer in self.named_buffers(): 2522 res += buffer.sum() 2523 2524 return inp.cos() + res 2525 2526 mod = Foo() 2527 opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod) 2528 args = (torch.randn(3, 4),) 2529 self.assertTrue(same(mod(*args), opt_mod(*args))) 2530 2531 def test_requires_grad_guards_with_grad_mode1(self): 2532 def f(x): 2533 if x.requires_grad: 2534 return x + 1 2535 else: 2536 return x + 2 2537 2538 x = torch.ones(2, requires_grad=True) 2539 2540 f_compiled = torch.compile(f) 2541 with torch.no_grad(): 2542 # compile an inference graph 2543 f_compiled(x) 2544 2545 # Test: we should fail guards and recompile (even though it's still an inference graph) 2546 out_ref = f(x.detach()) 2547 out = f_compiled(x.detach()) 2548 2549 self.assertEqual(out_ref, out) 2550 self.assertEqual(out_ref.requires_grad, out.requires_grad) 2551 2552 def test_requires_grad_guards_with_grad_mode2(self): 2553 x = torch.ones(2, requires_grad=True) 2554 x_ref = x.clone().detach().requires_grad_(True) 2555 2556 m = torch.nn.Linear(2, 2) 2557 m_compiled = torch.compile(m) 2558 2559 with torch.no_grad(): 2560 # compile an inference graph 2561 m_compiled(x) 2562 2563 # Test: we should fail guards and recompile a training graph 2564 out_ref = m(x_ref) 2565 out = m_compiled(x) 2566 self.assertEqual(out_ref, out) 2567 self.assertEqual(out_ref.requires_grad, out.requires_grad) 2568 2569 def test_is_symbolic_tracing(self): 2570 # Ensure no graph break here 2571 def fn(x): 2572 if is_fx_tracing_test(): 2573 return x * 2 2574 return x * 4 2575 2576 a = torch.randn(4) 2577 ref = fn(a) 2578 opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 2579 res = opt_fn(a) 2580 self.assertTrue(same(ref, res)) 2581 2582 def test_tokenization(self): 2583 from collections import UserDict 2584 2585 class BatchEncoding(UserDict): 2586 """ 2587 Copied from tokenization 2588 """ 2589 2590 def __init__( 2591 self, 2592 data, 2593 ): 2594 super().__init__(data) 2595 2596 def __getattr__(self, item: str): 2597 try: 2598 return self.data[item] 2599 except KeyError as e: 2600 raise AttributeError from e 2601 2602 def tokenization(x): 2603 encoding = BatchEncoding({"key": x}) 2604 return encoding["key"] 2605 2606 opt_fn = torch._dynamo.optimize("eager")(tokenization) 2607 x = torch.rand((1, 4)) 2608 ref = tokenization(x) 2609 res = opt_fn(x) 2610 self.assertTrue(same(ref, res)) 2611 2612 def test_modules(self): 2613 class Foo(torch.nn.Module): 2614 def __init__(self) -> None: 2615 super().__init__() 2616 self.fc = torch.nn.Linear(4, 3) 2617 2618 def forward(self, inp): 2619 res = torch.zeros(3, 3) 2620 for mod in self.modules(): 2621 res += self.fc(inp) 2622 return res 2623 2624 mod = Foo() 2625 args = (torch.ones(3, 4),) 2626 cnt = torch._dynamo.testing.CompileCounter() 2627 opt_mod = torch._dynamo.optimize(cnt, nopython=True)(mod) 2628 self.assertTrue(same(mod(*args), opt_mod(*args))) 2629 self.assertEqual(cnt.op_count, 5) 2630 self.assertEqual(cnt.frame_count, 1) 2631 2632 def test_omegaconf_listconfig_iter(self): 2633 obj = ListConfig() 2634 x = torch.zeros(2) 2635 2636 def fn(): 2637 y = x 2638 for i in obj: 2639 y += i 2640 return y 2641 2642 expected = fn() 2643 actual = torch.compile(fn, fullgraph=True, backend="eager")() 2644 self.assertEqual(actual, expected) 2645 2646 def test_user_defined_iter(self): 2647 class MyIter: 2648 def __init__(self) -> None: 2649 self.i = 0 2650 2651 def __iter__(self): 2652 return self 2653 2654 def __next__(self): 2655 if self.i < 3: 2656 self.i += 1 2657 return self.i 2658 raise StopIteration 2659 2660 @torch.compile(backend="eager", fullgraph=True) 2661 def fn(x): 2662 for i in MyIter(): 2663 x += i 2664 return x 2665 2666 self.assertEqual(fn(torch.zeros(1)), torch.full([1], 6.0)) 2667 2668 def test_stop_iteration_reconstruct(self): 2669 @torch.compile(backend="eager", fullgraph=True) 2670 def fn(x): 2671 return x.sin(), StopIteration(1, 2, 3) 2672 2673 _, res = fn(torch.ones(1)) 2674 self.assertEqual(str(res), str(StopIteration(1, 2, 3))) 2675 2676 def test_tensor_data_kwarg(self): 2677 # https://github.com/pytorch/pytorch/issues/96278 2678 def f(): 2679 return torch.tensor(data=[[1.0, -1.0]]) 2680 2681 cnt = torch._dynamo.testing.CompileCounter() 2682 opt_fn = torch._dynamo.optimize(cnt, nopython=True)(f) 2683 self.assertTrue(same(f(), opt_fn())) 2684 self.assertEqual(cnt.frame_count, 1) 2685 2686 @requires_cuda 2687 def test_norm_dtype(self): 2688 def foo(_stack0): 2689 getitem = _stack0[(slice(None, None, None), -1)] 2690 _stack0 = None 2691 normalize = torch.nn.functional.normalize(getitem, p=2, dim=1) 2692 getitem = None 2693 return (normalize,) 2694 2695 args = [((2, 50, 256), (1, 256, 1), torch.float16, "cuda", False)] 2696 args = [ 2697 rand_strided(sh, st, dt, dev).requires_grad_(rg) 2698 for (sh, st, dt, dev, rg) in args 2699 ] 2700 2701 opt_foo = torch._dynamo.optimize("aot_eager_decomp_partition")(foo) 2702 with torch.cuda.amp.autocast(enabled=True): 2703 ref = foo(*args)[0] 2704 res = foo(*args)[0] 2705 self.assertEqual(ref.dtype, res.dtype) 2706 2707 self.assertTrue(same(res, ref)) 2708 2709 def test_for_loop_graph_break(self): 2710 def inner(x): 2711 return torch.sin(x) 2712 2713 def fn(x): 2714 for _ in range(100): 2715 inner(x) 2716 torch._dynamo.graph_break() 2717 return x 2718 2719 cnt = torch._dynamo.testing.CompileCounter() 2720 opt_fn = torch._dynamo.optimize(cnt)(fn) 2721 x = torch.randn(4) 2722 opt_fn(x) 2723 self.assertEqual(cnt.frame_count, 1) 2724 self.assertEqual(cnt.op_count, 1) 2725 2726 def test_for_loop_graph_break_before(self): 2727 # Checks that the backedge is calculated correctly 2728 def inner(x): 2729 return torch.sin(x) 2730 2731 def fn(x): 2732 torch._dynamo.graph_break() 2733 for _ in range(100): 2734 inner(x) 2735 return x 2736 2737 cnt = torch._dynamo.testing.CompileCounter() 2738 opt_fn = torch._dynamo.optimize(cnt)(fn) 2739 x = torch.randn(4) 2740 opt_fn(x) 2741 self.assertEqual(cnt.frame_count, 1) 2742 self.assertEqual(cnt.op_count, 100) 2743 2744 def test_avoid_dupe_specialization(self): 2745 def f(x, y): 2746 return (x + y) * 1 2747 2748 opt_f = torch._dynamo.optimize("aot_eager")(f) 2749 2750 for b in [True, False]: 2751 x = torch.randn(4, requires_grad=b) 2752 y = torch.randn(4, requires_grad=b) 2753 self.assertEqual(f(x, x), opt_f(x, x)) 2754 self.assertEqual(f(x, y), opt_f(x, y)) 2755 2756 def test_validate_model_kwargs(self): 2757 cnt = CompileCounter() 2758 2759 def f1(a, b): 2760 return torch.sin(a) + torch.cos(b) 2761 2762 @torch.compile(backend=cnt, fullgraph=True) 2763 def f2(**kwargs): 2764 _validate_model_kwargs(f1, kwargs) 2765 return f1(**kwargs) 2766 2767 x = torch.randn(10) 2768 y = torch.randn(10) 2769 2770 self.assertEqual(f2(a=x, b=y), f1(x, y)) 2771 self.assertEqual(cnt.frame_count, 1) 2772 self.assertEqual(cnt.op_count, 3) 2773 2774 def test_swin_base_tensor_attr(self): 2775 class Foo(torch.nn.Module): 2776 def __init__(self) -> None: 2777 super().__init__() 2778 # NB: not a parameter or buffer 2779 self.t = torch.randn(3) 2780 2781 def forward(self, x): 2782 return x + torch.cat((self.t, self.t)) 2783 2784 mod = Foo() 2785 opt_mod = torch._dynamo.optimize("eager")(mod) 2786 args = [torch.randn(6)] 2787 self.assertTrue(same_two_models(mod, opt_mod, args)) 2788 opt_mod(*args) 2789 2790 def test_pointless_graph_removal(self): 2791 cnt = torch._dynamo.testing.CompileCounter() 2792 2793 @torch.compile(backend=cnt) 2794 def fn(x): 2795 with torch.no_grad(): 2796 torch._dynamo.graph_break() 2797 return x + 1 2798 2799 fn(torch.randn(4)) 2800 self.assertEqual(cnt.frame_count, 1) 2801 self.assertEqual(cnt.op_count, 3) 2802 2803 def test_output_aliases_intermediate(self): 2804 def f(x): 2805 intermediate = x.mul(2) 2806 return intermediate.view(-1), intermediate 2807 2808 opt_f = torch._dynamo.optimize("aot_eager")(f) 2809 2810 for b in [True, False]: 2811 x = torch.randn(4, requires_grad=b) 2812 out = f(x) 2813 out_test = opt_f(x) 2814 self.assertEqual(out[0], out_test[0]) 2815 self.assertEqual(out[1], out_test[1]) 2816 self.assertEqual(out[0].requires_grad, out_test[0].requires_grad) 2817 self.assertEqual(out[1].requires_grad, out_test[1].requires_grad) 2818 # test that the aliasing relationship of outputs is preserved 2819 out[0].mul_(2) 2820 out_test[0].mul_(2) 2821 self.assertEqual(out[0], out_test[0]) 2822 self.assertEqual(out[1], out_test[1]) 2823 2824 def test_while_loop_graph_break(self): 2825 # Repro of tacotron2 cache_size_recompilation 2826 def inner(x): 2827 return torch.sin(x) 2828 2829 def fn(x): 2830 i = 20 2831 while i > 10: 2832 x = inner(x) 2833 i -= 1 2834 torch._dynamo.graph_break() 2835 return x 2836 2837 cnt = torch._dynamo.testing.CompileCounter() 2838 opt_fn = torch._dynamo.optimize(cnt)(fn) 2839 x = torch.randn(4) 2840 opt_fn(x) 2841 self.assertEqual(cnt.frame_count, 1) 2842 self.assertEqual(cnt.op_count, 1) 2843 2844 def test_nested_while_loop_graph_break(self): 2845 def inner_loop(x): 2846 i = 3 2847 while i > 0: 2848 i -= 1 2849 x += 1 2850 torch._dynamo.graph_break() 2851 return x 2852 2853 def inner(x): 2854 inner_loop(x) 2855 return torch.sin(x) 2856 2857 def fn(x): 2858 i = 20 2859 while i > 10: 2860 x = inner(x) 2861 i -= 1 2862 torch._dynamo.graph_break() 2863 return x 2864 2865 cnt = torch._dynamo.testing.CompileCounter() 2866 opt_fn = torch._dynamo.optimize(cnt)(fn) 2867 x = torch.randn(4) 2868 opt_fn(x) 2869 self.assertEqual(cnt.frame_count, 1) 2870 self.assertEqual(cnt.op_count, 1) 2871 2872 def test_while_loop_graph_break_inside_call_function(self): 2873 # Repro of huggingface graph break inside loop in `get_parameter_dtype`. 2874 # Skip only the inner frame that has loop that contains graph break. 2875 def inner(x): 2876 for i in range(3): 2877 x += 1 2878 torch._dynamo.graph_break() 2879 return x 2880 2881 def fn(x): 2882 x += 2 2883 inner(x) 2884 x += 3 2885 return x 2886 2887 cnt = torch._dynamo.testing.CompileCounter() 2888 opt_fn = torch._dynamo.optimize(cnt)(fn) 2889 x = torch.randn(4) 2890 opt_fn(x) 2891 self.assertEqual(cnt.frame_count, 2) 2892 self.assertEqual(cnt.op_count, 2) 2893 2894 def test_exception_in_dynamo_handling(self): 2895 hit_handler = False 2896 2897 # See https://github.com/pytorch/pytorch/pull/96488 2898 @contextlib.contextmanager 2899 def ctx(): 2900 try: 2901 yield 2902 except RuntimeError: 2903 nonlocal hit_handler 2904 hit_handler = True 2905 2906 @torch._dynamo.optimize("eager") 2907 def f(): 2908 with ctx(): 2909 h() 2910 2911 def h(): 2912 raise RuntimeError("boof") 2913 2914 # Should not error 2915 f() 2916 self.assertTrue(hit_handler) 2917 2918 def test_generator_dealloc(self): 2919 # See https://github.com/pytorch/pytorch/pull/96488 2920 # 2921 # NB: yes, [(...)] is intentional, this is a list containing a 2922 # generator 2923 generator_box = [(x for x in [1, 2, 3])] 2924 2925 counter = torch._dynamo.testing.CompileCounter() 2926 2927 def g(x): 2928 return x + 2 2929 2930 # TODO: This test is pretty delicate. To test if it's actually doing 2931 # anything, rebuild eval_frame.c with '#define TORCHDYNAMO_DEBUG 1' 2932 # and then look at the logs for: 2933 # 2934 # TRACE[_custom_eval_frame:650] begin <genexpr> test_repros.py 2276 -1 0 0 2935 # TRACE[_custom_eval_frame:664] throw <genexpr> 2936 # 2937 # This means we're actually hitting the relevant codepath 2938 2939 # NB: Make sure we don't actually Dynamo this frame; if we do Dynamo 2940 # this frame, Dynamo actually DOES understand list.clear and will 2941 # arrange for the generator deallocation to happen when the eval frame 2942 # handler is disabled, which will prevent the bug from happening (we 2943 # specifically want to trigger the generator deallocation WHILE the 2944 # dynamo eval frame handler is active), as that will cause the 2945 # generator to become exhausted and trigger the throw_flag == TRUE 2946 # case. 2947 @torch._dynamo.disable(recursive=False) 2948 def f(x): 2949 generator_box.clear() 2950 return g(x) 2951 2952 self.assertNoUnraisable( 2953 lambda: torch._dynamo.optimize(counter)(f)(torch.randn(3)) 2954 ) 2955 2956 # Make sure the x + 2 is captured (a previous incorrect implementation 2957 # of this fix would have disabled the eval frame callback, which means 2958 # g wouldn't get traced 2959 self.assertEqual(counter.op_count, 1) 2960 2961 def test_error_return_without_exception_set(self): 2962 # https://github.com/pytorch/pytorch/issues/93781 2963 @torch.compile 2964 def f(): 2965 _generator_type = type(_ for _ in ()) 2966 2967 self.assertNoUnraisable(f) 2968 2969 def common_merge_criteria_processor_list(self, list_cls, fullgraph): 2970 cnt = CompileCounter() 2971 2972 @torch.compile(backend=cnt, fullgraph=fullgraph) 2973 def f(x, left, right): 2974 combined = _merge_criteria_processor_list(left, right) 2975 return combined(x) 2976 2977 l1 = list_cls([torch.nn.ReLU(), torch.nn.Sigmoid()]) 2978 l2 = list_cls([]) 2979 input = torch.randn(16) 2980 result = f(input, l1, l2) 2981 self.assertEqual(result, l1(input)) 2982 self.assertEqual(cnt.frame_count, 1) 2983 self.assertEqual(cnt.op_count, 2) 2984 2985 cnt.clear() 2986 l3 = list_cls([torch.nn.SiLU()]) 2987 expected = l3(l1(input)) 2988 result = f(input, l1, l3) 2989 self.assertEqual(len(l1), 3) 2990 self.assertEqual(result, expected) 2991 self.assertEqual(cnt.frame_count, 1) 2992 self.assertEqual(cnt.op_count, 3) 2993 2994 def test_merge_criteria_processor_list1(self): 2995 self.common_merge_criteria_processor_list(CustomList1, False) 2996 2997 def test_merge_criteria_processor_list2(self): 2998 self.common_merge_criteria_processor_list(CustomList2, True) 2999 3000 def test_restricted_list_subclass1(self): 3001 cnt = CompileCounter() 3002 3003 @torch.compile(backend=cnt, fullgraph=True) 3004 def fn(a, b): 3005 l = CustomList2() 3006 l.extend([True]) 3007 l.append(a) 3008 l.extend([b]) 3009 l.pop(0) 3010 l.append(l.length_times_10()) 3011 return sum(l) 3012 3013 x = torch.randn(10) 3014 y = torch.randn(10) 3015 self.assertEqual(fn(x, y), x + y + 20) 3016 self.assertEqual(cnt.op_count, 3) 3017 3018 def test_restricted_list_subclass2(self): 3019 cnt = CompileCounter() 3020 3021 @torch.compile(backend=cnt, fullgraph=True) 3022 def fn(a, b): 3023 l1 = CustomList2([a + 1]) 3024 l2 = CustomList2([b + 2]) 3025 l1.extend(l2) 3026 return l1 3027 3028 x = torch.randn(10) 3029 y = torch.randn(10) 3030 z = fn(x, y) 3031 self.assertEqual(type(z), CustomList2) 3032 self.assertEqual(len(z), 2) 3033 self.assertEqual(z.length_times_10(), 20) 3034 self.assertEqual(list(z), [x + 1, y + 2]) 3035 3036 def test_restricted_list_subclass3(self): 3037 cnt = CompileCounter() 3038 3039 @torch.compile(backend=cnt, fullgraph=True) 3040 def fn(a: CustomList2, b: CustomList2): 3041 a.extend(b) 3042 a.append_twice(b[2] + 1) 3043 a.append(b[3] + 2) 3044 return b 3045 3046 x = torch.randn(10) 3047 y = torch.randn(10) 3048 l = CustomList2([x, y]) 3049 self.assertIs(fn(l, l), l) 3050 self.assertEqual(len(l), 7) 3051 self.assertIs(l[0], x) 3052 self.assertIs(l[1], y) 3053 self.assertIs(l[2], x) 3054 self.assertIs(l[3], y) 3055 self.assertEqual(l[4], x + 1) 3056 self.assertIs(l[5], l[4]) 3057 self.assertEqual(l[6], y + 2) 3058 3059 def test_rewrite_assert_with_msg(self): 3060 def f(x): 3061 b = x.sin() 3062 assert x[0] == 3, "First dim need to be 3" 3063 return x.cos() + b 3064 3065 args = (torch.Tensor([3, 4, 5]),) 3066 cnt = torch._dynamo.testing.CompileCounter() 3067 3068 opt_f = torch._dynamo.optimize(cnt, nopython=True)(f) 3069 self.assertTrue(same(f(*args), opt_f(*args))) 3070 self.assertEqual(cnt.op_count, 6) 3071 self.assertEqual(cnt.frame_count, 1) 3072 3073 exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5])) 3074 self.assertTrue(same(exported(*args), f(*args))) 3075 3076 def test_list_aliasing(self): 3077 cnt = CompileCounter() 3078 3079 @torch.compile(backend=cnt, fullgraph=True) 3080 def fn(a): 3081 a.append(torch.sin(a[0])) 3082 return a 3083 3084 x = torch.randn(10) 3085 l = [x] 3086 self.assertIs(fn(l), l) 3087 self.assertEqual(len(l), 2) 3088 self.assertIs(l[0], x) 3089 self.assertEqual(l[1], torch.sin(x)) 3090 self.assertEqual(cnt.frame_count, 1) 3091 self.assertEqual(cnt.op_count, 1) 3092 3093 def test_not_rewrite_assert_for_other_errors(self): 3094 def f(x): 3095 b = x.sin() 3096 if not x.sum() <= 3: 3097 raise ValueError("input sum needs to be 3") 3098 return x.cos() + b 3099 3100 args = (torch.Tensor([3, 4, 5]),) 3101 opt_fn = torch._dynamo.optimize("eager")(f) 3102 with self.assertRaisesRegex(ValueError, "input sum needs to be 3"): 3103 opt_fn(*args) 3104 3105 def test_rewrite_assert_dont_change_bytecode(self): 3106 def fn(x): 3107 with torch.no_grad(): 3108 assert x.max() < 5, f"invalid max {x.max()}" 3109 x = torch.sin(x) 3110 return x 3111 3112 x = torch.ones(4) 3113 opt_fn = torch._dynamo.optimize("eager")(fn) 3114 self.assertTrue(same(fn(x), opt_fn(x))) 3115 3116 def test_rewrite_assert_without_msg(self): 3117 def f(x): 3118 b = x.sin() 3119 assert x[0] == 3 3120 return x.cos() + b 3121 3122 args = (torch.Tensor([3, 4, 5]),) 3123 exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5])) 3124 self.assertTrue(same(exported(*args), f(*args))) 3125 3126 with self.assertRaisesRegex(RuntimeError, "assertion error"): 3127 exported(torch.Tensor([5, 6, 7])) 3128 3129 def test_rewrite_assert_with_non_string_msg(self): 3130 def f(x): 3131 b = x.sin() 3132 assert x[0] == 2, x.size() 3133 return x.cos() + b 3134 3135 torch._dynamo.utils.counters.clear() 3136 args = torch.Tensor([3, 4, 5]) 3137 opt_f = torch._dynamo.optimize("eager")(f) 3138 with self.assertRaisesRegex(AssertionError, "torch.Size"): 3139 opt_f(args) 3140 self.assertEqual( 3141 torch._dynamo.utils.counters["graph_break"][ 3142 "assert with non-string message" 3143 ], 3144 1, 3145 ) 3146 3147 def test_rewrite_assert_noop(self): 3148 def f(x): 3149 b = x.sin() 3150 assert True 3151 assert x.dtype == torch.float32 3152 return x.cos() + b 3153 3154 args = (torch.Tensor([3, 4, 5]),) 3155 exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5])) 3156 self.assertTrue(same(exported(*args), f(*args))) 3157 3158 cnt = torch._dynamo.testing.CompileCounter() 3159 opt_f = torch._dynamo.optimize(cnt, nopython=True)(f) 3160 self.assertTrue(same(f(*args), opt_f(*args))) 3161 # torch._assert shouldn't be in the graph 3162 self.assertEqual(cnt.op_count, 3) 3163 self.assertEqual(cnt.frame_count, 1) 3164 3165 exported, _ = torch._dynamo.export(f)(torch.Tensor([4, 4, 5])) 3166 self.assertTrue(same(exported(*args), f(*args))) 3167 3168 def test_size_typematch(self): 3169 def f(x, y): 3170 if isinstance(x, torch.Size): 3171 return y + 1 3172 else: 3173 return y + 2 3174 3175 y = torch.zeros(1) 3176 x1 = torch.Size((3,)) 3177 x2 = (3,) 3178 3179 cnt = torch._dynamo.testing.CompileCounter() 3180 opt_f = torch._dynamo.optimize(cnt, nopython=True)(f) 3181 self.assertTrue(same(f(x1, y), opt_f(x1, y))) 3182 self.assertTrue(same(f(x2, y), opt_f(x2, y))) 3183 self.assertEqual(cnt.frame_count, 2) 3184 3185 def test_dict_subclass_contains(self): 3186 # pattern from huggingface 3187 class ClassInstantier(collections.OrderedDict): 3188 pass 3189 3190 @torch.compile(fullgraph=True, backend="eager") 3191 def f(x, d): 3192 if "key1" in d: 3193 x = x + 2 3194 if "key2" in d: 3195 x = x + 4 3196 x = x + 8 3197 return x 3198 3199 result = f(torch.ones(8), ClassInstantier({"key1": torch.ones(8)})) 3200 self.assertTrue(same(result, torch.full([8], 11.0))) 3201 3202 result = f(torch.ones(8), ClassInstantier({"key2": torch.ones(8)})) 3203 self.assertTrue(same(result, torch.full([8], 13.0))) 3204 3205 def test_hf_classinstantier(self): 3206 # hf activations.py 3207 class ClassInstantier(collections.OrderedDict): 3208 def __getitem__(self, key): 3209 content = super().__getitem__(key) 3210 cls, kwargs = content if isinstance(content, tuple) else (content, {}) 3211 return cls(**kwargs) 3212 3213 ACT2CLS = ClassInstantier( 3214 { 3215 "relu": (nn.ReLU, {"inplace": False}), 3216 "tanh": nn.Tanh, 3217 } 3218 ) 3219 3220 @torch.compile(fullgraph=True, backend="eager") 3221 def f(x, act): 3222 return ACT2CLS[act](x) 3223 3224 y = torch.randn(10) 3225 self.assertTrue(same(f(y, "tanh"), torch.tanh(y))) 3226 self.assertTrue(same(f(y, "relu"), torch.relu(y))) 3227 3228 def test_ephemeral_module(self): 3229 # hf activations.py 3230 class ReLUSquaredActivation(nn.Module): 3231 def forward(self, input): 3232 relu_applied = torch.nn.functional.relu(input) 3233 squared = torch.square(relu_applied) 3234 return squared 3235 3236 @torch.compile(fullgraph=True, backend="eager") 3237 def f(x): 3238 x = x + 0.2 3239 x = ReLUSquaredActivation()(x) 3240 x = x + 1 3241 return x 3242 3243 y = torch.randn(10) 3244 self.assertTrue(same(f(y), ReLUSquaredActivation()(y + 0.2) + 1)) 3245 3246 def test_inplace_unsqueeze_input(self): 3247 def backend(gm, example_inputs): 3248 self.assertEqual(example_inputs[-1].size(), torch.Size([1, 3, 4])) 3249 return gm 3250 3251 @torch.compile(backend=backend) 3252 def fn(x): 3253 x.unsqueeze_(0) 3254 return x + 1 3255 3256 inputs = [torch.randn(3, 4)] 3257 self.assertEqual(fn(*inputs).size(), torch.Size([1, 3, 4])) 3258 self.assertEqual(inputs[0].size(), torch.Size([1, 3, 4])) 3259 3260 def test_batchnorm_e2e(self): 3261 class Repro(torch.nn.Module): 3262 def __init__(self) -> None: 3263 super().__init__() 3264 self.bn = torch.nn.BatchNorm2d( 3265 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True 3266 ) 3267 self.conv1 = torch.nn.Conv2d( 3268 64, 3269 64, 3270 kernel_size=(3, 3), 3271 stride=(1, 1), 3272 padding=(1, 1), 3273 bias=False, 3274 ) 3275 3276 def forward(self, x): 3277 x1 = self.bn(x) 3278 x2 = self.conv1(x1) 3279 out = torch.nn.functional.relu(x2) 3280 return (out,) 3281 3282 torch.manual_seed(1337) 3283 3284 m_ref = Repro() 3285 m_test = deepcopy(m_ref) 3286 3287 @torch._dynamo.optimize("aot_eager_decomp_partition") 3288 def compiled_fn(x): 3289 return m_test(x) 3290 3291 x_ref = torch.randn(2, 64, 32, 32, requires_grad=True) 3292 x_test = x_ref.clone() 3293 3294 # Loop multiple times: each iteration the running_mean/var on batchnorm will update, 3295 # which changes the output of the next iteration 3296 for _ in range(3): 3297 ref = m_ref(x_ref) 3298 res = compiled_fn(x_test) 3299 3300 self.assertTrue(same(ref, res)) 3301 3302 for r in ref: 3303 if r.requires_grad: 3304 r.sum().backward() 3305 for r in res: 3306 if r.requires_grad: 3307 r.sum().backward() 3308 3309 for param_ref, param_test in zip(m_ref.parameters(), m_test.parameters()): 3310 self.assertTrue(same(param_ref, param_test)) 3311 # Assert running_mean/var 3312 for buffer_ref, buffer_test in zip(m_ref.buffers(), m_test.buffers()): 3313 self.assertTrue(same(buffer_ref, buffer_test)) 3314 3315 @torch._dynamo.config.patch("assume_static_by_default", False) 3316 def test_dynamic_shapes_right_side(self): 3317 def f(x): 3318 return torch.ones(5 * x.shape[0]) 3319 3320 inp = torch.randn(6, 5) 3321 3322 gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5)) 3323 self.assertEqual(gm(inp).shape, f(inp).shape) 3324 3325 @torch._dynamo.config.patch("specialize_int", False) 3326 def test_maybe_multiply_symint(self): 3327 # https://github.com/pytorch/pytorch/issues/97346 3328 from torch._functorch.aot_autograd import aot_module_simplified 3329 3330 def my_aot_compiler(gm, example_inputs): 3331 def my_compiler(gm, example_inputs): 3332 return gm.forward 3333 3334 # Invoke AOTAutograd 3335 return aot_module_simplified(gm, example_inputs, fw_compiler=my_compiler) 3336 3337 def my_example(t1, t2, d): 3338 out = torch.add(t1, t2, alpha=d) 3339 return out 3340 3341 compiled_fn = torch.compile(backend=my_aot_compiler, dynamic=True)(my_example) 3342 3343 t1 = torch.arange(3, dtype=torch.float32).requires_grad_(True) 3344 t2 = torch.arange(3, dtype=torch.float32).requires_grad_(True) 3345 3346 ra = compiled_fn(t1, t2, 5) 3347 self.assertEqual(ra, torch.tensor([0.0, 6.0, 12.0])) 3348 3349 ra = compiled_fn(t1, t2, 6) 3350 self.assertEqual(ra, torch.tensor([0.0, 7.0, 14.0])) 3351 3352 def test_build_map_unpack_with_call(self): 3353 def forward_with_cond_scale(x, t, cond_scale, self_cond, other1, other2): 3354 return x.sin() + t + cond_scale + self_cond + other1 + other2 3355 3356 @torch.compile(backend="eager", fullgraph=True) 3357 def fn(x): 3358 d1 = dict(other1=5) 3359 d2 = dict(other2=4) 3360 text_cond = {**d1, **d2} 3361 return forward_with_cond_scale(x, 1, cond_scale=2, self_cond=3, **text_cond) 3362 3363 self.assertTrue(same(fn(torch.ones(4)), torch.ones(4).sin() + 15)) 3364 3365 @torch._dynamo.config.patch(verbose=True) 3366 def test_graph_break_unsupported_fake(self): 3367 counter = torch._dynamo.testing.CompileCounter() 3368 3369 @torch._dynamo.optimize(counter) 3370 def f(x): 3371 return torch.ops.test_sample.foo(x + 1) + 1 3372 3373 f(torch.randn(3)) 3374 3375 self.assertEqual(counter.op_count, 2) 3376 self.assertEqual(counter.frame_count, 2) 3377 3378 def test_delattr(self): 3379 class MyObj: 3380 def __init__(self, a, b): 3381 self.a = a 3382 self.b = b 3383 3384 @torch.compile(backend="eager", fullgraph=True) 3385 def fn(x, obj): 3386 del obj.a 3387 obj.c = x + 1 3388 del obj.c 3389 tmp = MyObj(x + 2, x + 3) 3390 del tmp.b 3391 if hasattr(obj, "a"): 3392 return x + 1 3393 return tmp 3394 3395 x = torch.zeros([]) 3396 obj1 = MyObj(x, x) 3397 obj2 = fn(x, obj1) 3398 self.assertFalse(hasattr(obj1, "a")) 3399 self.assertFalse(hasattr(obj1, "c")) 3400 self.assertFalse(hasattr(obj2, "b")) 3401 self.assertEqual(obj1.b.item(), 0) 3402 self.assertEqual(obj2.a.item(), 2) 3403 3404 def test_delattr_raises(self): 3405 class MyObj: 3406 def __init__(self, a, b): 3407 self.a = a 3408 self.b = b 3409 3410 @torch.compile(backend="eager") 3411 def fn(x, obj): 3412 del obj.a 3413 x = x + 1 3414 obj.a # will raise 3415 return x 3416 3417 x = torch.zeros([]) 3418 obj1 = MyObj(x, x) 3419 self.assertRaises(AttributeError, lambda: fn(x, obj1)) 3420 3421 def test_delsubscr(self): 3422 @torch.compile(backend="eager") 3423 def fn(x): 3424 del x["a"] 3425 y = x["b"] + 1 3426 return y 3427 3428 x = {"a": torch.tensor([1]), "b": torch.tensor([1])} 3429 result = fn(x) 3430 self.assertFalse(hasattr(x, "a")) 3431 self.assertEqual(result.item(), 2) 3432 3433 def test_delsubscr_raises(self): 3434 @torch.compile(backend="eager") 3435 def fn(x): 3436 del x["a"] 3437 y = x["a"] + 1 # should raise KeyError 3438 return y 3439 3440 x = {"a": torch.tensor([1]), "b": torch.tensor([1])} 3441 self.assertRaises(KeyError, lambda: fn(x)) 3442 3443 def test_attached_attribute_in_dir(self): 3444 class MyModule(torch.nn.Module): 3445 def __init__(self) -> None: 3446 super().__init__() 3447 self.linear = torch.nn.Linear(16, 16) 3448 self.relu = torch.nn.ReLU() 3449 3450 def forward(self, x): 3451 return self.relu(self.linear(x)) 3452 3453 mod = torch.compile(MyModule(), backend="eager") 3454 mod.is_compiled = True 3455 self.assertTrue("is_compiled" in dir(mod)) 3456 3457 @torch._dynamo.config.patch("automatic_dynamic_shapes", False) 3458 def test_dynamic_shapes_implicit_guard(self): 3459 def f(x): 3460 y = x * x.size(x.shape[0]) 3461 torch.sum(y, [y.shape[0]]) 3462 return y 3463 3464 cnt = torch._dynamo.testing.CompileCounter() 3465 opt_fn = torch._dynamo.optimize(cnt, nopython=True)(f) 3466 opt_fn(torch.randn(3, 1, 1, 1, 1)) 3467 self.assertEqual(cnt.frame_count, 1) 3468 3469 def test_dalle2_maybe(self): 3470 def normalize(x): 3471 return x.cos() 3472 3473 @torch.compile(backend="eager", fullgraph=True) 3474 def fn(x, normalize_img): 3475 lowres_cond_img = x.sin() 3476 lowres_cond_img = maybe(normalize_img)(lowres_cond_img) 3477 return lowres_cond_img 3478 3479 self.assertEqual(fn(torch.ones([]), normalize), torch.ones([]).sin().cos()) 3480 3481 def test_functools_wraps(self): 3482 def cool_name(x): 3483 return x.sin() 3484 3485 @torch.compile(backend="eager", fullgraph=True) 3486 def fn(x): 3487 y = x.cos() 3488 3489 @functools.wraps(cool_name) 3490 def uncool_name(): 3491 return cool_name(y) 3492 3493 return uncool_name 3494 3495 result = fn(torch.ones([])) 3496 self.assertEqual(result.__name__, "cool_name") 3497 self.assertEqual(result(), torch.ones([]).cos().sin()) 3498 3499 def test_dynamic_shapes_float_guard(self): 3500 def f(x): 3501 return torch.nn.functional.dropout(x, x.shape[0] / 6) 3502 3503 cnt = torch._dynamo.testing.CompileCounter() 3504 opt_fn = torch._dynamo.optimize(cnt, nopython=True)(f) 3505 opt_fn(torch.randn(3)) 3506 self.assertEqual(cnt.frame_count, 1) 3507 3508 @torch._dynamo.config.patch(capture_scalar_outputs=True) 3509 def test_tensor_item(self): 3510 def f(x, y): 3511 val = y.item() 3512 return x.sum() + val 3513 3514 gm, _ = torch._dynamo.export( 3515 f, 3516 aten_graph=True, 3517 )( 3518 torch.zeros(6, 4), 3519 torch.tensor(1), 3520 ) 3521 self.assertEqual( 3522 f(torch.zeros(6, 4), torch.tensor(1)), 3523 gm(torch.zeros(6, 4), torch.tensor(1)), 3524 ) 3525 self.assertEqual( 3526 f(torch.zeros(6, 4), torch.tensor(2)), 3527 gm(torch.zeros(6, 4), torch.tensor(2)), 3528 ) 3529 3530 def test_dataclass_init_with_default_factory_with_inputs(self): 3531 @dataclasses.dataclass 3532 class DClass: 3533 sharding_contexts: Any = dataclasses.field(default_factory=list) 3534 a: int = 1 3535 3536 def fn(x, inp_list): 3537 d = DClass(inp_list) 3538 d.sharding_contexts.append(x.sin() + d.a) 3539 return d 3540 3541 x = torch.randn(4) 3542 inp_list1 = [1, 2, 3] 3543 inp_list2 = [2, 3, 4] 3544 inp_list3 = [1, 2] 3545 ref1 = fn(x, inp_list1) 3546 ref2 = fn(x, inp_list2) 3547 ref3 = fn(x, inp_list3) 3548 3549 cnt = torch._dynamo.testing.CompileCounter() 3550 opt_fn = torch.compile(fn, fullgraph=True) 3551 3552 opt_ret1 = opt_fn(x, inp_list1) 3553 opt_ret2 = opt_fn(x, inp_list2) 3554 opt_ret3 = opt_fn(x, inp_list3) 3555 self.assertEqual(ref1.sharding_contexts, opt_ret1.sharding_contexts) 3556 self.assertEqual(ref2.sharding_contexts, opt_ret2.sharding_contexts) 3557 self.assertEqual(ref3.sharding_contexts, opt_ret3.sharding_contexts) 3558 3559 def test_list_index(self): 3560 for i, list_type in enumerate( 3561 ( 3562 list, 3563 tuple, 3564 torch.Size, 3565 collections.deque, 3566 namedtuple("FourElems", "one two three four", defaults=[0, 0, 0, 0]), 3567 ) 3568 ): 3569 torch._dynamo.reset() 3570 for index in ([], [2], [0, 3]): 3571 3572 def f(t): 3573 if i == 4: # namedtuple 3574 xs = list_type(1, 2, 3, 4) 3575 else: 3576 xs = list_type([1, 2, 3, 4]) 3577 res = xs.index(3, *index) 3578 return t + res 3579 3580 res = torch._dynamo.optimize(backend="eager", nopython=True)(f)( 3581 torch.zeros(1) 3582 ) 3583 3584 self.assertEqual(res, torch.tensor([2.0])) 3585 3586 def test_list_index_not_found(self): 3587 def f(t): 3588 xs = ["bar", "foo", "baz", "buzz"] 3589 res = xs.index("non-existent") 3590 return t + res 3591 3592 # Raising ValueError from item not found is unsupported 3593 with self.assertRaises( 3594 torch._dynamo.exc.Unsupported, 3595 ): 3596 torch._dynamo.optimize(backend="eager", nopython=True)(f)(torch.zeros(1)) 3597 3598 def test_list_index_tensor_unsupported(self): 3599 for index in ([], [2], [0, 3]): 3600 3601 def f(t): 3602 xs = [torch.tensor([i]) for i in range(4)] 3603 res = xs.index(torch.tensor([2]), *index) 3604 return t + res 3605 3606 with self.assertRaisesRegex( 3607 torch._dynamo.exc.UserError, "Dynamic control flow is not supported" 3608 ): 3609 torch._dynamo.optimize(backend="eager", nopython=True)(f)( 3610 torch.zeros(1) 3611 ) 3612 3613 def test_hf_xsoftmax_inference(self): 3614 def fn(input, mask): 3615 return XSoftmax.apply(input + 1, mask, 1) + 2 3616 3617 fn_opt = torch.compile(fn, backend="eager", fullgraph=True) 3618 3619 inputs = [ 3620 torch.randn(4, 10), 3621 torch.randn(4, 10) < 0, 3622 ] 3623 expected = fn(*inputs) 3624 actual = fn_opt(*inputs) 3625 self.assertTrue(same(actual, expected)) 3626 3627 @mock.patch("torch._dynamo.config.guard_nn_modules", True) 3628 def test_hf_xsoftmax_training(self): 3629 from torch._dynamo.utils import counters 3630 3631 counters.clear() 3632 3633 def fn(input, mask): 3634 return XSoftmax.apply(input, mask, 1) 3635 3636 cnt = torch._dynamo.testing.CompileCounter() 3637 fn_opt = torch.compile(fn, backend=cnt, fullgraph=False) 3638 3639 torch.manual_seed(1234) 3640 inputs1 = [ 3641 torch.randn(4, 10, requires_grad=True), 3642 torch.randn(4, 10) < 0, 3643 ] 3644 torch.manual_seed(1234) 3645 inputs2 = [ 3646 torch.randn(4, 10, requires_grad=True), 3647 torch.randn(4, 10) < 0, 3648 ] 3649 3650 expected = fn(*inputs1) 3651 actual = fn_opt(*inputs2) 3652 self.assertTrue(same(actual, expected)) 3653 self.assertEqual(dict(counters["frames"]), {"total": 1, "ok": 1}) 3654 self.assertEqual(cnt.op_count, 2) 3655 self.assertEqual(cnt.frame_count, 1) 3656 cnt.clear() 3657 counters.clear() 3658 3659 expected.sum().backward() 3660 actual.sum().backward() 3661 self.assertTrue(same(inputs1[0].grad, inputs2[0].grad)) 3662 3663 # currently we don't capture the backwards frame 3664 self.assertEqual(cnt.frame_count, 0) 3665 self.assertEqual(cnt.op_count, 0) 3666 self.assertEqual(dict(counters["frames"]), {}) 3667 self.assertEqual(dict(counters["graph_break"]), {}) 3668 3669 def test_autograd_function_graph_break(self): 3670 class MySin(torch.autograd.Function): 3671 @staticmethod 3672 def forward(ctx, x): 3673 torch._dynamo.graph_break() 3674 ctx.save_for_backward(x) 3675 return x.sin() 3676 3677 @staticmethod 3678 def backward(ctx, gx): 3679 (x,) = ctx.saved_tensors 3680 return gx * x.cos() 3681 3682 x = torch.randn([], requires_grad=True) 3683 3684 @torch.compile(backend="eager") 3685 def fn(x): 3686 return MySin.apply(x) 3687 3688 y = fn(x) 3689 self.assertEqual(y, x.sin()) 3690 3691 (gx,) = torch.autograd.grad(y, x) 3692 self.assertEqual(gx, x.cos()) 3693 3694 def test_jit_trace_errors(self): 3695 @torch.compile(backend="eager", dynamic=True) 3696 def f(x): 3697 return x + 1 3698 3699 with self.assertRaises(RuntimeError): 3700 torch.jit.trace(f, torch.randn(3)) 3701 3702 with torch._dynamo.config.patch(error_on_nested_jit_trace=False): 3703 torch.jit.trace(f, torch.randn(3)) 3704 3705 @torch._dynamo.config.patch("assume_static_by_default", False) 3706 def test_tensor_split(self): 3707 def f(x): 3708 return torch.split(x, x.shape[0] // 2, dim=0)[0] 3709 3710 gm, _ = torch._dynamo.export( 3711 f, 3712 aten_graph=True, 3713 )( 3714 torch.zeros(6, 4), 3715 ) 3716 3717 self.assertEqual(f(torch.ones(8, 4)), gm(torch.ones(8, 4))) 3718 3719 def test_optim_state_references_cleared(self): 3720 model = torch.nn.Linear(2048, 2048, bias=False) 3721 x = torch.ones(2048) 3722 state_ref = 0 3723 3724 optimizer = torch.optim.Adadelta(model.parameters(), lr=0.01) 3725 3726 def opt_step(): 3727 optimizer.step() 3728 3729 compiled_opt_step = torch._dynamo.optimize("eager")(opt_step) 3730 3731 def compiled_model_step(x): 3732 optimizer.zero_grad() 3733 y = model(x) 3734 torch.sum(y).backward() 3735 compiled_opt_step() 3736 3737 compiled_model_step(x) 3738 3739 # Picked "square_avg" arbitrarily to check that 3740 # optimizer state tensors are deallocated 3741 state_ref = weakref.ref( 3742 optimizer.state[optimizer.param_groups[0]["params"][0]]["square_avg"] 3743 ) 3744 optimizer = None 3745 3746 self.assertIsNone(state_ref()) 3747 3748 def test_grad_references_cleared(self): 3749 model = torch.nn.Linear(2048, 2048, bias=False) 3750 x = torch.ones(2048) 3751 optimizer = torch.optim.Adadelta(model.parameters(), lr=0.01) 3752 3753 def opt_step(): 3754 optimizer.step() 3755 3756 compiled_opt_step = torch._dynamo.optimize("eager")(opt_step) 3757 3758 def compiled_model_step(x): 3759 optimizer.zero_grad(True) 3760 y = model(x) 3761 torch.sum(y).backward() 3762 compiled_opt_step() 3763 3764 compiled_model_step(x) 3765 param_grad_ref = weakref.ref(next(iter(model.parameters())).grad) 3766 optimizer.zero_grad(True) 3767 self.assertIsNone(param_grad_ref()) 3768 3769 def test_batch_encoding_clone_inputs(self): 3770 class BatchEncoding(dict): 3771 """ 3772 Copied from test_tokenization 3773 """ 3774 3775 def __init__( 3776 self, 3777 data, 3778 ): 3779 super().__init__(data) 3780 3781 def __getattr__(self, item: str): 3782 try: 3783 return self.data[item] 3784 except KeyError as e: 3785 raise AttributeError from e 3786 3787 encoding = BatchEncoding({"key": torch.rand((1, 4))}) 3788 cloned_encoding = torch._dynamo.utils.clone_inputs(encoding) 3789 self.assertTrue(type(cloned_encoding) is not dict) 3790 3791 def test_iadd_graph_break(self): 3792 def fn(x): 3793 a = () 3794 x = torch.sin(x) 3795 a += (x,) 3796 return a 3797 3798 x = torch.randn(4) 3799 ref = fn(x) 3800 3801 opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) 3802 res = opt_fn(x) 3803 self.assertTrue(same(ref, res)) 3804 3805 def test_odict_get_item_index_name(self): 3806 d = {float: torch.float32, np.float16: torch.float16} 3807 3808 @torch.compile(backend="eager") 3809 def f(x, y1, y2): 3810 return torch.zeros(5, dtype=d[y1]), torch.zeros(5, dtype=d[y2]) 3811 3812 f(torch.zeros(4), float, np.float16) 3813 3814 def test_dedup_global(self): 3815 @torch.compile() 3816 def f(): 3817 return _GLOBAL_CPU_TENSOR + _GLOBAL_CPU_TENSOR 3818 3819 self.assertEqual(f(), _GLOBAL_CPU_TENSOR + _GLOBAL_CPU_TENSOR) 3820 3821 def test_randint_out_dynamic(self): 3822 def randint_fn(high, size, out): 3823 return torch.randint(high, size, out=out) 3824 3825 opt_model = torch.compile(randint_fn) 3826 3827 out1 = torch.empty(10, dtype=torch.int32) 3828 opt_model(17, (10,), out1) 3829 3830 out2 = torch.empty(12, dtype=torch.int32) 3831 opt_model(17, (12,), out2) 3832 3833 @requires_cuda 3834 def test_guard_default_device(self): 3835 try: 3836 torch.set_default_device("cuda") 3837 3838 counter = torch._dynamo.testing.CompileCounter() 3839 3840 @torch._dynamo.optimize(counter) 3841 def f(): 3842 x = torch.randn(3) 3843 return x * 2 3844 3845 self.assertEqual(f().device.type, "cuda") 3846 self.assertEqual(counter.frame_count, 1) 3847 3848 torch.set_default_device("cpu") 3849 3850 self.assertEqual(f().device.type, "cpu") 3851 self.assertEqual(counter.frame_count, 2) 3852 3853 finally: 3854 torch.set_default_device(None) 3855 3856 def test_list_self_reference(self): 3857 # Issue - https://github.com/pytorch/pytorch/issues/100150 3858 root = [] 3859 root[:] = [root, root, None, None] 3860 3861 @torch._dynamo.optimize("eager") 3862 def test_bug(): 3863 return root 3864 3865 test_bug() 3866 3867 def test_hf_bigbird_unsqueeze(self): 3868 def torch_bmm_nd(inp_1, inp_2, ndim=None): 3869 torch._dynamo.graph_break() 3870 return torch.bmm(inp1, inp2) 3871 3872 def fn(inp1, inp2, inp3, inp4, c): 3873 a = torch_bmm_nd(inp1, inp2, 4) 3874 a.unsqueeze_(2) 3875 a = a * 2 3876 3877 b = torch_bmm_nd(inp3, inp4, 4) 3878 b.unsqueeze_(2) 3879 l = a + b 3880 3881 out = torch.cat([a, b, c], dim=2) 3882 return out, l 3883 3884 inp1 = torch.rand(1, 64, 448) 3885 inp2 = torch.rand(1, 448, 64) 3886 inp3 = torch.rand(1, 64, 448) 3887 inp4 = torch.rand(1, 448, 64) 3888 c = torch.rand(1, 64, 1, 64) 3889 3890 cnt = torch._dynamo.testing.CompileCounter() 3891 opt_fn = torch._dynamo.optimize(cnt)(fn) 3892 opt_fn(inp1, inp2, inp3, inp4, c) 3893 self.assertEqual(cnt.frame_count, 3) 3894 3895 def test_torch_variable_type(self): 3896 # from torchvision 3897 def check_type(obj, types_or_checks): 3898 for type_or_check in types_or_checks: 3899 if ( 3900 isinstance(obj, type_or_check) 3901 if isinstance(type_or_check, type) 3902 else type_or_check(obj) 3903 ): 3904 return True 3905 return False 3906 3907 opt_check_type = torch._dynamo.optimize("eager")(check_type) 3908 ref = check_type(torch.randn(4), [torch.Tensor]) 3909 res = opt_check_type(torch.randn(4), [torch.Tensor]) 3910 self.assertEqual(ref, res) 3911 3912 # Test for https://github.com/pytorch/pytorch/issues/103132 3913 @torch._dynamo.config.patch("assume_static_by_default", False) 3914 def test_inference_mode_dynamic_shapes(self): 3915 class Repro(torch.nn.Module): 3916 def __init__(self) -> None: 3917 super().__init__() 3918 3919 def forward(self, param): 3920 z = torch.matmul(param, param) 3921 return z 3922 3923 model = Repro() 3924 # Need a 3d tensor to actually cause the error: 3925 # we go down a path of the C++ matmul decomp that calls sizes(). 3926 inp = torch.randn(4, 4, 4, requires_grad=True) 3927 model = torch.compile(model, backend="aot_eager", dynamic=True) 3928 with torch.inference_mode(): 3929 model(inp) 3930 3931 def test_kwargs_out_list_variable(self): 3932 class Repro(torch.nn.Module): 3933 def __init__(self) -> None: 3934 super().__init__() 3935 3936 def forward(self, param): 3937 z = torch.frexp(**param) 3938 return z 3939 3940 model = Repro() 3941 params = {"input": torch.tensor([[0.0, 1, 2, 4]])} 3942 params["out"] = [ 3943 torch.empty(0, dtype=torch.float32), # mantissa 3944 torch.empty(0, dtype=torch.int32), # exponent 3945 ] 3946 3947 model = torch.compile(model, backend="eager") 3948 mantissa, exponent = model(params) 3949 ref_mantissa = torch.tensor([[0.0000, 0.5000, 0.5000, 0.5000]]) 3950 ref_exponent = torch.tensor([[0, 1, 2, 3]], dtype=torch.int32) 3951 self.assertEqual(ref_mantissa, mantissa) 3952 self.assertEqual(ref_exponent, exponent) 3953 3954 @torch._dynamo.config.patch(capture_scalar_outputs=True) 3955 def test_split_with_sizes_aot_autograd(self): 3956 def fn(result, split_sizes): 3957 rs = torch.ops.aten.split_with_sizes(result, split_sizes.tolist()) 3958 return rs 3959 3960 example_inputs = ( 3961 torch.randn(32, requires_grad=True), 3962 torch.tensor((7, 16, 9)), 3963 ) 3964 actual = torch.compile(fn, fullgraph=True, backend="aot_eager")(*example_inputs) 3965 expected = fn(*example_inputs) 3966 self.assertEqual(actual, expected) 3967 3968 def test_unspecialized_nn_module_with_torch_variable_attribute(self): 3969 """ 3970 In this case self.fn = something that should be a TorchVariable. 3971 When it's not a TorchVariable, dynamo tries to trace through and fails. 3972 This makes sure that the self.fn is handled as a TorchVariable. 3973 """ 3974 3975 class UserModule(torch.nn.Module): 3976 torchdynamo_force_dynamic = True # forced to be a UnspecializedNNModule 3977 3978 def __init__(self, fn): 3979 super().__init__() 3980 self.fn = fn 3981 3982 def forward(self, **inp): 3983 return self.fn(**inp) 3984 3985 inputs = { 3986 "input": torch.randn([2, 9]).uniform_(0, 1), 3987 "target": torch.randn([2, 9]).uniform_(0, 1), 3988 "reduction": "mean", 3989 } 3990 3991 mod = UserModule(torch.nn.functional.binary_cross_entropy) 3992 ref = mod(**inputs) 3993 res = torch._dynamo.optimize("eager", nopython=True)(mod)(**inputs) 3994 self.assertEqual(ref, res) 3995 3996 def test_call_finally_python_3_8(self): 3997 # Issue - https://github.com/pytorch/pytorch/issues/97811 3998 def make_fn(g): 3999 def fn(): 4000 while True: 4001 try: 4002 print(g) 4003 break 4004 except Exception as _: 4005 break 4006 4007 return torch.compile(fn, backend="eager") 4008 4009 make_fn(None)() 4010 4011 def test_call_finally_python_3_8_2(self): 4012 def f(x): 4013 while x: 4014 try: 4015 pass 4016 except Exception as _: 4017 continue 4018 4019 torch.compile(f, backend="eager")(0) 4020 4021 def test_call_finally_opcode_python_3_8(self): 4022 def fn(): 4023 try: 4024 return torch.zeros(4) 4025 finally: 4026 return torch.ones(4) # noqa: SIM107, B012 4027 4028 result = torch.compile(fn, backend="aot_eager")() 4029 self.assertEqual(result, torch.ones(4)) 4030 4031 def test_string_format(self): 4032 s = "temp{i}" 4033 4034 @torch.compile(backend="eager", fullgraph=True) 4035 def fn(x): 4036 if s.format(i=4) == "temp4": 4037 return torch.sin(x) 4038 return torch.cos(x) 4039 4040 x = torch.randn(4) 4041 self.assertEqual(fn(x), torch.sin(x)) 4042 4043 # Repro of torch._dynamo.exc.InternalTorchDynamoError: 'NoneType' object has no attribute 'guards' 4044 # due to bad empty list handling 4045 def test_empty_list_contains_with_jump(self): 4046 def fn(x, l): 4047 if x in l: 4048 return x.cos() 4049 return x.sin() 4050 4051 counter = CompileCounter() 4052 compiled_fn = torch._dynamo.optimize(counter)(fn)(torch.randn([2, 2]), []) 4053 self.assertEqual(counter.frame_count, 1) 4054 4055 def test_graph_break_on_jit_isinstance(self): 4056 @torch.compile(backend="eager") 4057 def fn(x): 4058 if torch.jit.isinstance(x, List[str]): 4059 return x * 2 4060 return x 4061 4062 opt_fn = torch.compile(fn, backend="eager") 4063 x = torch.rand(4) 4064 self.assertTrue(same(fn(x), opt_fn(x))) 4065 4066 def test_add_sub_alpha_out(self): 4067 inp = torch.randn(2, 3, 4) 4068 other = 1 4069 alpha = 2 4070 for op in [torch.add, torch.sub]: 4071 out = torch.zeros(2, 3, 4) 4072 compile_out = torch.zeros(2, 3, 4) 4073 op(inp, other, alpha=alpha, out=out) 4074 compiled_fn = torch.compile(op, dynamic=True) 4075 compiled_fn(inp, other, alpha=alpha, out=compile_out) 4076 self.assertTrue(same(out, compile_out)) 4077 4078 def test_negative_shape_guard(self): 4079 def fn(x): 4080 if x.size() != (5, 1, 2, 3): 4081 return x.cos() 4082 return x.sin() 4083 4084 counter = torch._dynamo.testing.CompileCounter() 4085 opt_fn = torch.compile(fn, backend=counter, dynamic=True) 4086 4087 x = torch.ones(5, 1, 3, 4) 4088 x2 = torch.ones(5, 1, 2, 3) 4089 self.assertEqual(fn(x), opt_fn(x)) 4090 self.assertEqual(fn(x2), opt_fn(x2)) 4091 self.assertEqual(counter.frame_count, 2) 4092 4093 @torch._dynamo.config.patch(capture_scalar_outputs=True) 4094 def test_deferred_runtime_asserts(self): 4095 @torch.compile(fullgraph=True) 4096 def f(x): 4097 y = x.item() 4098 torch._check_is_size(y) 4099 if y >= 0: 4100 return x * 2 4101 else: 4102 return x * 3 4103 4104 f(torch.tensor([3])) 4105 self.assertRaises(RuntimeError, lambda: f(torch.tensor([-2]))) 4106 4107 def test_addr_alpha_beta_out(self): 4108 inp = torch.randn(2, 3) 4109 vec1 = torch.randn(2) 4110 vec2 = torch.randn(3) 4111 alpha = 2 4112 beta = 5 4113 4114 out = torch.zeros(2, 3) 4115 compile_out = torch.zeros(2, 3) 4116 4117 torch.addr(inp, vec1, vec2, alpha=alpha, beta=beta, out=out) 4118 compiled_fn = torch.compile(torch.addr, dynamic=True) 4119 compiled_fn(inp, vec1, vec2, alpha=alpha, beta=beta, out=compile_out) 4120 self.assertTrue(same(out, compile_out)) 4121 4122 def test_setattr_requires_grad_graph_breaks(self): 4123 def fn(x): 4124 z = x + 4 4125 x.requires_grad = True 4126 y = x * z 4127 return y 4128 4129 for backend in ["count", "eager", "aot_eager"]: 4130 if backend == "count": 4131 backend = CompileCounter() 4132 opt_fn = torch.compile(fn, backend=backend) 4133 4134 eager = torch.zeros(5) 4135 compiled = eager.clone() 4136 4137 out_eager = fn(eager) 4138 out_opt = opt_fn(compiled) 4139 4140 self.assertEqual(out_eager, out_opt) 4141 4142 out_eager.sum().backward() 4143 out_opt.sum().backward() 4144 4145 self.assertEqual(eager, compiled) 4146 if isinstance(backend, CompileCounter): 4147 self.assertEqual(backend.frame_count, 2) # graph breaks 4148 4149 def test_dynamic_shapes_double_not_equal(self): 4150 # https://github.com/pytorch/pytorch/issues/113393 4151 def fn(x): 4152 if x.size() != (5, 1, 2, 3): 4153 return x.cos() 4154 return x.sin() 4155 4156 opt_fn = torch.compile(fn, backend="eager") 4157 4158 x = torch.ones(5, 1, 2, 3) 4159 x2 = torch.ones(5, 1, 3, 4) 4160 self.assertEqual(fn(x), opt_fn(x)) 4161 self.assertEqual(fn(x2), opt_fn(x2)) 4162 4163 def test_inductor_no_recursionerror_on_for_loops(self): 4164 def forward(x): 4165 for _ in range(1000): 4166 x = 1.0 * x 4167 return x 4168 4169 self.assertTrue( 4170 same(torch.compile(forward)(torch.tensor([1.0])), torch.tensor([1.0])) 4171 ) 4172 4173 def test_user_defined_object_callable(self): 4174 # https://github.com/pytorch/pytorch/issues/114019 4175 class MyCallable: 4176 def __call__(self, x): 4177 return x + 1 4178 4179 def fn(x): 4180 # Create in graph - will not have source 4181 return MyCallable()(x) 4182 4183 fn_opt = torch.compile(fn, backend="eager", fullgraph=True) 4184 self.assertEqual(fn_opt(torch.zeros(1)), fn(torch.zeros(1))) 4185 4186 @torch._dynamo.config.patch(log_compilation_metrics=True) 4187 def test_many_views_with_mutation(self): 4188 # When symbolic storage offsets were added in #113734, tensors_definitely_do_not_overlap 4189 # began adding shape guards - a quadratic amount relative to the number of inputs. 4190 # Test this configuration, and test that a reasonable number of guards are added. 4191 # Note, when dynamic shapes are turned on, this test fails and we still get quadratic guards. 4192 def fn(x): 4193 x[0].relu_() 4194 return torch.cat(x).sum() 4195 4196 AMT = 32 4197 src = torch.rand(16 * (AMT + 1)) 4198 4199 x = [src.as_strided((4, 4), (4, 1), 3 + 16 * i) for i in range(AMT)] 4200 4201 torch._dynamo.reset() 4202 torch._dynamo.utils.clear_compilation_metrics() 4203 4204 res = torch.compile(fn, backend="aot_eager")(x) 4205 4206 all_metrics = torch._dynamo.utils.get_compilation_metrics() 4207 4208 total_guards = sum(metric.guard_count for metric in all_metrics) 4209 self.assertLess(total_guards, AMT * 8) 4210 4211 total_shape_env_guards = sum( 4212 metric.shape_env_guard_count for metric in all_metrics 4213 ) 4214 self.assertLess(total_shape_env_guards, AMT * 8) 4215 4216 # https://github.com/pytorch/pytorch/issues/118799 4217 def test_subclass_graph_output_repro(self): 4218 @torch._dynamo.allow_in_graph 4219 def to_subclass(x): 4220 return TwoTensor(x.clone(), x.clone()) 4221 4222 def f(x): 4223 tmp_subclass = to_subclass(x) 4224 return tmp_subclass.view(-1) 4225 4226 x = torch.ones(2) 4227 out_ref = f(x) 4228 out_test = torch.compile(f, backend="aot_eager")(x) 4229 self.assertEqual(out_ref, out_test) 4230 4231 def test_numpy_tobytes_no_error(self): 4232 def fn(x): 4233 x += 1 4234 z = x.tobytes() 4235 x += 1 4236 return z 4237 4238 cnt = torch._dynamo.testing.CompileCounter() 4239 opt_fn = torch._dynamo.optimize(cnt)(fn) 4240 opt_arg, arg = np.array([1, 2]), np.array([1, 2]) 4241 self.assertEqual(opt_fn(opt_arg), fn(arg)) 4242 self.assertEqual(cnt.frame_count, 2) 4243 4244 def test_numpy_not_ndarray_recompiles(self): 4245 import torch 4246 4247 def fn(x=None): 4248 if x is None: 4249 x = np.ones(3) 4250 elif isinstance(x, int): 4251 x = np.ones(6) 4252 elif isinstance(x, str): 4253 x = np.ones(9) 4254 return x**2 4255 4256 cnt = torch._dynamo.testing.CompileCounter() 4257 opt_fn = torch._dynamo.optimize(cnt)(fn) 4258 4259 x = np.zeros((2, 2)) 4260 4261 self.assertEqual(opt_fn(x), fn(x)) 4262 self.assertEqual(cnt.frame_count, 1) 4263 self.assertEqual(opt_fn(), fn()) 4264 self.assertEqual(cnt.frame_count, 2) 4265 self.assertEqual(opt_fn(10), fn(10)) 4266 self.assertEqual(cnt.frame_count, 3) 4267 self.assertEqual(opt_fn("10"), fn("10")) 4268 self.assertEqual(cnt.frame_count, 4) 4269 4270 @parametrize( 4271 "backend", 4272 ["eager", "aot_eager", "inductor"], 4273 ) 4274 @parametrize( 4275 "func_name", 4276 ["func1", "func2", "func3"], 4277 ) 4278 def test_tensor_set_data(self, backend, func_name): 4279 # https://github.com/pytorch/pytorch/issues/113030 4280 def func1(x, y): 4281 x.data = y 4282 x.add_(1) 4283 return x 4284 4285 def func2(x, y): 4286 x.data = y 4287 y.data = torch.zeros([0]) 4288 return x 4289 4290 def func3(x, y): 4291 z = x 4292 x.data = y 4293 y.data = torch.zeros([0]) 4294 return torch.tensor(x is z) 4295 4296 funcs = {"func1": func1, "func2": func2, "func3": func3} 4297 func = funcs[func_name] 4298 4299 if backend != "eager" and func is func1: 4300 # add_ not working w/ aot_autograd? 4301 return 4302 4303 torch._dynamo.reset() 4304 cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) 4305 4306 compiled_fn = torch.compile(func, backend=cnt, fullgraph=True) 4307 requires_grad = func is not func1 4308 for i in range(0, 5): 4309 # Inputs 4310 eager_a = torch.ones([6], requires_grad=requires_grad) 4311 compiled_a = torch.ones([6], requires_grad=requires_grad) 4312 4313 eager_b = torch.ones([6], requires_grad=requires_grad) 4314 compiled_b = torch.ones([6], requires_grad=requires_grad) 4315 4316 # Eager 4317 out_eager = func(eager_a, eager_b) 4318 # Compiled 4319 out_compiled = compiled_fn(compiled_a, compiled_b) 4320 self.assertEqual(eager_a, compiled_a) 4321 self.assertEqual(eager_b, compiled_b) 4322 self.assertTrue(torch.equal(out_eager, out_compiled)) 4323 4324 # func1 hits a leaf Variable that requires grad is being used in an in-place operation 4325 if requires_grad: 4326 bwd_inp_eager = torch.randn([6]) 4327 bwd_inp_compiled = torch.clone(bwd_inp_eager) 4328 eager_a.backward(bwd_inp_eager) 4329 compiled_a.backward(bwd_inp_compiled) 4330 self.assertEqual(eager_a.grad, compiled_a.grad) 4331 4332 # Prove guarding works - we run the compiled_fn 5 times 4333 # frame_count should stay at 1. 4334 self.assertEqual(cnt.frame_count, 1) 4335 4336 @unittest.skipIf( 4337 TEST_WITH_ROCM or not PLATFORM_SUPPORTS_FLASH_ATTENTION, 4338 "flash attention not supported", 4339 ) 4340 def test_flash_attn_backward_mixed_strides(self): 4341 # in this repro, "grad_out" and "value" are transposed tensors, 4342 # but "key" and "value" are contiguous 4343 def gen_inputs(device): 4344 return ( 4345 torch.randn( 4346 2, 513, 16, 64, dtype=torch.float16, device=device 4347 ).transpose(1, 2), 4348 torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device), 4349 torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device), 4350 torch.randn( 4351 2, 513, 16, 64, dtype=torch.float16, device=device 4352 ).transpose(1, 2), 4353 torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device), 4354 torch.randn(2, 16, 513, device=device), 4355 None, 4356 None, 4357 513, 4358 513, 4359 0.0, 4360 False, 4361 torch.tensor(1, dtype=torch.int64), 4362 torch.tensor(1, dtype=torch.int64), 4363 ) 4364 4365 inps_cuda = gen_inputs("cuda") 4366 inps_meta = gen_inputs("meta") 4367 ( 4368 out1_ref, 4369 out2_ref, 4370 out3_ref, 4371 ) = torch.ops.aten._scaled_dot_product_flash_attention_backward( 4372 *inps_cuda, scale=0.125 4373 ) 4374 from torch._meta_registrations import meta__scaled_dot_product_flash_backward 4375 4376 out1_test, out2_test, out3_test = meta__scaled_dot_product_flash_backward( 4377 *inps_meta, scale=0.125 4378 ) 4379 4380 self.assertEqual(out1_ref.shape, out1_test.shape) 4381 self.assertEqual(out1_ref.stride(), out1_test.stride()) 4382 self.assertEqual(out2_ref.shape, out2_test.shape) 4383 self.assertEqual(out2_ref.stride(), out2_test.stride()) 4384 self.assertEqual(out3_ref.shape, out3_test.shape) 4385 self.assertEqual(out3_ref.stride(), out3_test.stride()) 4386 4387 def test_user_ctor_ctx_manager(self): 4388 class UserCtxManager: 4389 def __enter__(self): 4390 return 1 4391 4392 def __exit__(self, exc_type, exc_val, exc_tb): 4393 pass 4394 4395 def fn(x, y): 4396 ucm = UserCtxManager() 4397 return x * x 4398 4399 cnt = torch._dynamo.testing.CompileCounter() 4400 opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn) 4401 x = torch.rand([2, 2]) 4402 opt_fn(x, x) 4403 self.assertExpectedInline(cnt.frame_count, """1""") 4404 4405 @torch._dynamo.config.patch(capture_scalar_outputs=True) 4406 def test_unbacked_arange_in_bounds(self): 4407 # see https://github.com/pytorch/pytorch/issues/113002 4408 class PaddingNet(nn.Module): 4409 def __init__(self) -> None: 4410 super().__init__() 4411 4412 def forward(self, lengths): 4413 max_seq_len = lengths.max().item() 4414 row_vector = torch.arange(0, max_seq_len, 1) 4415 matrix = torch.unsqueeze(lengths, dim=-1) 4416 mask = row_vector < matrix 4417 mask = mask.type(torch.float32) 4418 mask_3d_btd = mask[:, :, None] 4419 return mask_3d_btd 4420 4421 model = PaddingNet() 4422 lengths = torch.tensor([5, 4, 4, 4], dtype=torch.int32) 4423 4424 cnt = torch._dynamo.testing.CompileCounter() 4425 opt_fn = torch._dynamo.optimize(cnt, nopython=True)(model) 4426 opt_fn(lengths) 4427 self.assertEqual(cnt.frame_count, 1) 4428 4429 def test_overlapping_inputs_with_dynamic_shapes_error(self): 4430 @torch.compile(backend="aot_eager") 4431 def fn(a, b, c, d, e, f): 4432 a.mul_(2) 4433 b.mul_(2) 4434 c.mul_(2) 4435 d.mul_(2) 4436 e.mul_(2) 4437 f.mul_(2) 4438 4439 base = torch.ones(2, 20) 4440 a = base[:, 0:2] 4441 b = base[:, 2:4] 4442 c = base[:, 4:6] 4443 d = base[:, 6:8] 4444 e = base[:, 8:10] 4445 f = base[:, 10:12] 4446 f2 = base[:, 10:14] 4447 out = fn(a, b, c, d, e, f) 4448 with self.assertRaisesRegex( 4449 AssertionError, "is being compiled with dynamic shapes" 4450 ): 4451 out2 = fn(a, b, c, d, e, f2) 4452 4453 def test_user_ctor_ctx_manager_custom_init(self): 4454 class UserCtxManager: 4455 def __init__(self, x): 4456 x[0] = 10 4457 4458 def __enter__(self): 4459 return 1 4460 4461 def __exit__(self, exc_type, exc_val, exc_tb): 4462 pass 4463 4464 def fn(x, y): 4465 ucm = UserCtxManager(y) 4466 return x * y[0] 4467 4468 cnt = torch._dynamo.testing.CompileCounter() 4469 opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn) 4470 x = torch.rand([2, 2]) 4471 self.assertEqual(opt_fn(x, [5]), fn(x, [5])) 4472 self.assertExpectedInline(cnt.frame_count, """1""") 4473 4474 def test_user_ctor_ctx_manager_custom_init_graph_break(self): 4475 counter = [0] 4476 4477 class UserCtxManager: 4478 def __init__(self, k): 4479 k[0] += 1 4480 4481 def __enter__(self): 4482 return 1 4483 4484 def __exit__(self, exc_type, exc_val, exc_tb): 4485 pass 4486 4487 def fn(x, counter): 4488 x = x * x 4489 ucm = UserCtxManager(counter) 4490 return x * x 4491 4492 cnt = torch._dynamo.testing.CompileCounter() 4493 opt_fn = torch._dynamo.optimize(cnt)(fn) 4494 x = torch.rand([2, 2]) 4495 self.assertEqual(opt_fn(x, counter), fn(x, counter)) 4496 self.assertEqual(counter[0], 2) 4497 for i in range(0, 10): 4498 opt_fn(x, counter) 4499 self.assertEqual(counter[0], 12) 4500 if torch._dynamo.config.assume_static_by_default: 4501 self.assertExpectedInline(cnt.frame_count, """2""") 4502 else: 4503 self.assertExpectedInline(cnt.frame_count, """1""") 4504 4505 @unittest.expectedFailure 4506 def test_many_overlapping_inputs_does_not_explode_guards(self): 4507 from torch._dynamo.backends.common import aot_autograd 4508 4509 # Before, this was (9702, 0) 4510 num_shape_guards = None 4511 num_aot_guards = None 4512 num_compiles = 0 4513 4514 def guard_count_backend(gm, *args): 4515 nonlocal num_shape_guards 4516 nonlocal num_aot_guards 4517 nonlocal num_compiles 4518 num_shape_guards = len( 4519 torch._guards.TracingContext.try_get().fake_mode.shape_env.guards 4520 ) 4521 num_aot_guards = len( 4522 torch._guards.TracingContext.try_get().guards_context.aotautograd_guards 4523 ) 4524 num_compiles += 1 4525 return gm 4526 4527 aot_guard_counter = aot_autograd(fw_compiler=guard_count_backend) 4528 4529 @torch.compile(backend=aot_guard_counter, dynamic=True) 4530 def f(*args): 4531 for a in args: 4532 a.add_(1) 4533 4534 x = torch.ones(1000, requires_grad=True) 4535 args = x.split(10) 4536 4537 with torch.no_grad(): 4538 f(*args) 4539 # In this example, there were 4950 guards (roughly (# tensors) ^ 2 // 2), 4540 # because every pair of aliased inputs needs a guard. 4541 self.assertTrue(num_aot_guards < 5000) 4542 # But there are no dynamic shape guards. 4543 self.assertEqual(num_shape_guards, 0) 4544 # don't recompile 4545 with torch.no_grad(): 4546 f(*args) 4547 self.assertEqual(num_compiles, 1) 4548 4549 def test_invalid_seq_unpack(self): 4550 def myfn(arg): 4551 (a, b) = arg 4552 4553 def fn(): 4554 return myfn((1, 2, 3)) 4555 4556 try: 4557 torch.compile(fn)() 4558 except ValueError: 4559 pass 4560 else: 4561 self.fail("expected exception") 4562 4563 def test_megablocks_moe(self): 4564 try: 4565 from megablocks.layers import moe 4566 from megablocks.layers.arguments import Arguments 4567 except ImportError as e: 4568 raise unittest.SkipTest("requires megablocks") from e 4569 bs, sl, hs, num_experts, top_k = (16, 1024, 512, 1, 1) 4570 args = Arguments( 4571 hidden_size=hs, 4572 ffn_hidden_size=hs * 2, 4573 moe_num_experts=num_experts, 4574 moe_capacity_factor=1, 4575 moe_top_k=top_k, 4576 ) 4577 moe_mlp = moe.MoE(args) 4578 moe_mlp.cuda(torch.cuda.current_device()).half() 4579 x = torch.randn(sl, bs, hs).cuda().half() 4580 out1, _ = moe_mlp(x) 4581 out2, _ = torch.compile(moe_mlp, backend="eager")(x) 4582 self.assertEqual(out1, out2) 4583 4584 def test_udf_classes_reconstruction(self): 4585 def fn(x): 4586 o = T(5) 4587 return o.x + x 4588 4589 opt_fn = torch.compile(fn, backend="eager") 4590 T = IncByOne 4591 4592 x = torch.randn(4) 4593 self.assertEqual(fn(x), opt_fn(x)) 4594 4595 # This should recompile 4596 T = IncByTwo 4597 self.assertEqual(fn(x), opt_fn(x)) 4598 4599 def test_contains_range_constprop(self): 4600 def fn(x): 4601 # dynamo should const prop to False 4602 if 3 in range(0, 10): 4603 return x + 1 4604 else: 4605 return x + 2 4606 4607 opt_fn = torch.compile(fn, backend="eager") 4608 x = torch.zeros(4) 4609 self.assertEqual(fn(x), opt_fn(x)) 4610 4611 # https://github.com/pytorch/pytorch/issues/104505 4612 def test_as_strided_on_base_with_mutation_works(self): 4613 def foo(a): 4614 f = a.as_strided((2,), (1,), 0) 4615 f.add_(1.0) 4616 return a 4617 4618 a = torch.randn(2, 4) 4619 a_ref = a.clone() 4620 out_ref = foo(a_ref) 4621 f_compiled = torch.compile(foo, backend="aot_eager") 4622 out = f_compiled(a) 4623 self.assertEqual(out_ref, out) 4624 self.assertEqual(a_ref, a) 4625 4626 # https://github.com/pytorch/pytorch/issues/104505 4627 def test_as_strided_on_existing_view_banned(self): 4628 def foo(a): 4629 e = a.diagonal() 4630 f = e.as_strided((2,), (1,), 0) 4631 f.add_(1.0) 4632 return a 4633 4634 a = torch.randn(2, 4) 4635 a_ref = a.clone() 4636 out_ref = foo(a_ref) 4637 f_compiled = torch.compile(foo, backend="aot_eager") 4638 with self.assertRaisesRegex( 4639 RuntimeError, 4640 "encountered a mutation on a view chain of length 2, where view 1 was an as_strided", 4641 ): 4642 out = f_compiled(a) 4643 4644 def test_dont_aggressively_write_assert(self): 4645 record_graph = torch._dynamo.testing.EagerAndRecordGraphs() 4646 4647 @torch.compile(dynamic=True, backend=record_graph) 4648 def f(x): 4649 assert x.shape[0] > 3 4650 assert x[0].sum() > 0 4651 assert 1 % (x.shape[0] // 2) != 0 4652 assert 32 * (x.shape[0] // 2) ** 2 - 16 * (x.shape[0] // 2) != 0 4653 return x.cos() 4654 4655 f(torch.ones(6, 4)) 4656 graph = record_graph.graphs[0] 4657 # It is bit annoying that we generate useless statements for 4658 # shape guards, but DCE should be able to remove them since t 4659 # there is no backed assert on them. The reason this is ok is 4660 # because dynamo will only skip the assert statement, but not 4661 # the instructions before it. 4662 self.assertExpectedInline( 4663 str(graph.code).strip(), 4664 """\ 4665def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): 4666 l_x_ = L_x_ 4667 getitem_2 = l_x_[0] 4668 sum_1 = getitem_2.sum(); getitem_2 = None 4669 gt_1 = sum_1 > 0; sum_1 = None 4670 _assert_async = torch._assert_async(gt_1, 'assertion error'); gt_1 = _assert_async = None 4671 cos = l_x_.cos(); l_x_ = None 4672 return (cos,)""", 4673 ) 4674 for node in graph.graph.nodes: 4675 if "example_value" in node.meta and isinstance( 4676 node.meta["example_value"], torch._subclasses.fake_tensor.FakeTensor 4677 ): 4678 shape_env = node.meta["example_value"].fake_mode.shape_env 4679 lower_ranges = [val.lower for val in shape_env.var_to_range.values()] 4680 self.assertTrue(lower_ranges == [4, 2]) 4681 4682 @torch.compile(dynamic=True, backend=record_graph) 4683 def f_fail(x): 4684 assert x.shape[0] < 3 4685 4686 # We graph-break here, so the failure should be eager 4687 with self.assertRaisesRegex(AssertionError, ""): 4688 f_fail(torch.ones(6, 4)) 4689 4690 def test_detectron2_instances_cat(self): 4691 class Instances: 4692 def __init__(self, image_size: Tuple[int, int], **kwargs: Any): 4693 self._image_size = image_size 4694 self._fields: Dict[str, Any] = {} 4695 for k, v in kwargs.items(): 4696 self.set(k, v) 4697 4698 @property 4699 def image_size(self) -> Tuple[int, int]: 4700 return self._image_size 4701 4702 def __setattr__(self, name: str, val: Any) -> None: 4703 if name.startswith("_"): 4704 super().__setattr__(name, val) 4705 else: 4706 self.set(name, val) 4707 4708 def __getattr__(self, name: str) -> Any: 4709 if name == "_fields" or name not in self._fields: 4710 raise AttributeError( 4711 f"Cannot find field '{name}' in the given Instances!" 4712 ) 4713 return self._fields[name] 4714 4715 def __len__(self) -> int: 4716 for v in self._fields.values(): 4717 # use __len__ because len() has to be int and is not friendly to tracing 4718 return v.__len__() 4719 raise NotImplementedError("Empty Instances does not support __len__!") 4720 4721 def set(self, name: str, value: Any) -> None: 4722 with warnings.catch_warnings(record=True): 4723 data_len = len(value) 4724 if len(self._fields): 4725 assert ( 4726 len(self) == data_len 4727 ), f"Adding a field of length {data_len} to a Instances of length {len(self)}" 4728 self._fields[name] = value 4729 4730 def get(self, name: str) -> Any: 4731 return self._fields[name] 4732 4733 @staticmethod 4734 def cat(instance_lists: List["Instances"]) -> "Instances": 4735 assert all(isinstance(i, Instances) for i in instance_lists) 4736 assert len(instance_lists) > 0 4737 if len(instance_lists) == 1: 4738 return instance_lists[0] 4739 4740 image_size = instance_lists[0].image_size 4741 if not isinstance( 4742 image_size, torch.Tensor 4743 ): # could be a tensor in tracing 4744 for i in instance_lists[1:]: 4745 assert i.image_size == image_size 4746 ret = Instances(image_size) 4747 for k in instance_lists[0]._fields.keys(): 4748 values = [i.get(k) for i in instance_lists] 4749 v0 = values[0] 4750 if isinstance(v0, torch.Tensor): 4751 values = torch.cat(values, dim=0) 4752 elif isinstance(v0, list): 4753 values = list(itertools.chain(*values)) 4754 elif hasattr(type(v0), "cat"): 4755 values = type(v0).cat(values) 4756 else: 4757 raise ValueError( 4758 f"Unsupported type {type(v0)} for concatenation" 4759 ) 4760 ret.set(k, values) 4761 return ret 4762 4763 instances = [ 4764 Instances((16, 16), a=torch.randn(16, 16), b=torch.randn(16, 16)) 4765 for _ in range(3) 4766 ] 4767 4768 @torch.compile(backend="eager", fullgraph=True) 4769 def fn(instances): 4770 return instances[0].cat(instances) 4771 4772 actual = fn(instances) 4773 expected = instances[0].cat(instances) 4774 self.assertEqual(type(actual), type(expected)) 4775 self.assertEqual(actual.__dict__, expected.__dict__) 4776 4777 def test_weakref(self): 4778 def fn(x_weak, weight, y): 4779 if x_weak is not None and x_weak() is not weight: 4780 return torch.sin(y) 4781 return torch.cos(y) 4782 4783 weight = torch.randn(4) 4784 y = torch.randn(4) 4785 x_weak = weakref.ref(weight) 4786 4787 ref = fn(x_weak, weight, y) 4788 4789 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 4790 res = opt_fn(x_weak, weight, y) 4791 self.assertEqual(ref, res) 4792 4793 def test_weakref_reconstruct(self): 4794 def fn(x_weak, weight, y): 4795 y = torch.sin(y) 4796 referent = x_weak() 4797 torch._dynamo.graph_break() 4798 if referent is not weight: 4799 return torch.sin(y) 4800 return torch.cos(y) 4801 4802 weight = torch.randn(4) 4803 y = torch.randn(4) 4804 x_weak = weakref.ref(weight) 4805 4806 ref = fn(x_weak, weight, y) 4807 4808 cnt = torch._dynamo.testing.CompileCounter() 4809 opt_fn = torch.compile(fn, backend=cnt) 4810 res = opt_fn(x_weak, weight, y) 4811 self.assertEqual(ref, res) 4812 self.assertEqual(cnt.frame_count, 2) 4813 4814 def test_weakref_del(self): 4815 def fn(x_weak, y): 4816 x = x_weak() 4817 if x is not None: 4818 return torch.sin(y) 4819 return torch.cos(y) 4820 4821 weight = torch.randn(4) 4822 x_weak = weakref.ref(weight) 4823 y = torch.randn(4) 4824 4825 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 4826 4827 ref = fn(x_weak, y) 4828 res = opt_fn(x_weak, y) 4829 self.assertEqual(ref, res) 4830 4831 del weight 4832 gc.collect() 4833 ref = fn(x_weak, y) 4834 res = opt_fn(x_weak, y) 4835 self.assertEqual(ref, res) 4836 4837 # @torch._functorch.config.patch( 4838 # recompute_views=True, 4839 # ) 4840 # def test_storage_resize_forward_full_graph(self): 4841 # class TestModule(torch.nn.Module): 4842 # def __init__(self) -> None: 4843 # super().__init__() 4844 # self.param = torch.nn.Parameter(torch.randn(4, 4)) 4845 4846 # def forward(self, x): 4847 # self.param.untyped_storage().resize_( 4848 # self.param.numel() * self.param.itemsize 4849 # ) 4850 # with torch.no_grad(): 4851 # torch._foreach_copy_([self.param], [x]) 4852 # out = torch.matmul(self.param, self.param) 4853 # self.param.untyped_storage().resize_(0) 4854 # return out 4855 4856 # def post_accumulate_grad_hook(param): 4857 # param.untyped_storage().resize_(0) 4858 4859 # # Beginning of backward, resize and put data into the param 4860 # def pre_backward_hook(module, grad) -> None: 4861 # module.param.untyped_storage().resize_( 4862 # self.param.numel() * self.param.itemsize 4863 # ) 4864 # with torch.no_grad(): 4865 # # simulates loading data into param from allgather 4866 # module.param.fill_(2) 4867 4868 # def post_forward_hook(module, args, output): 4869 # output.register_hook(functools.partial(pre_backward_hook, module)) 4870 4871 # x = torch.randn(4, 4) 4872 4873 # mod_ref = TestModule() 4874 # mod_test = deepcopy(mod_ref) 4875 4876 # # Start the param off with zero storage size to mimic fsdp 4877 # mod_ref.param.untyped_storage().resize_(0) 4878 # mod_test.param.untyped_storage().resize_(0) 4879 4880 # # Resize storage at beginning of backward 4881 # # Free storage at end of backward 4882 # mod_ref.register_forward_hook(post_forward_hook, prepend=False) 4883 # mod_ref.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook) 4884 # mod_test.register_forward_hook(post_forward_hook, prepend=False) 4885 # mod_test.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook) 4886 4887 # mod_test = torch.compile(mod_test, backend=aot_graph_capture_backend) 4888 4889 # out_ref = mod_ref(x) 4890 # out_test = mod_test(x) 4891 # self.assertExpectedInline( 4892 # str(fw_graph[0].code.strip()), 4893 # """\ 4894 # def forward(self, primals_1, primals_2): 4895 # _foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]); primals_1 = primals_2 = None 4896 # getitem = _foreach_copy[0]; _foreach_copy = None 4897 # mm = torch.ops.aten.mm.default(getitem, getitem) 4898 # return [mm, getitem]""", 4899 # ) 4900 # self.assertEqual(out_ref, out_test) 4901 4902 def test_super_in_staticmethod(self): 4903 class A: 4904 @staticmethod 4905 def foo(): 4906 return super().__init__() 4907 4908 def fn(obj): 4909 return obj.foo() 4910 4911 obj = A() 4912 4913 try: 4914 fn(obj) 4915 except Exception as e: 4916 orig_str = str(e) 4917 self.assertIn("no arguments", orig_str) 4918 4919 try: 4920 torch.compile(backend="eager")(fn)(obj) 4921 except Exception as e: 4922 compiled_str = str(e) 4923 self.assertEqual(orig_str, compiled_str) 4924 4925 def test_super_staticmethod(self): 4926 class Parent: 4927 @staticmethod 4928 def greet(): 4929 return 5 4930 4931 class Child(Parent): 4932 @staticmethod 4933 def greet(x): 4934 return x * super(Child, Child).greet() 4935 4936 child = Child() 4937 4938 def fn(x): 4939 return child.greet(x) 4940 4941 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 4942 x = torch.ones(4) 4943 ref = fn(x) 4944 res = opt_fn(x) 4945 self.assertEqual(ref, res) 4946 4947 def test_super_diamond(self): 4948 class A: 4949 def __init__(self): 4950 super().__init__() 4951 self.a = 5 4952 4953 class Nothing: 4954 pass 4955 4956 class B(Nothing, A): 4957 def __init__(self): 4958 super().__init__() 4959 self.b = 10 4960 4961 def run(self, x): 4962 return self.a * self.b * x 4963 4964 def fn(x): 4965 b = B() 4966 return b.run(x) 4967 4968 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 4969 x = torch.randn(4) 4970 ref = fn(x) 4971 res = opt_fn(x) 4972 self.assertEqual(ref, res) 4973 4974 def test_vc_bumped_in_inference_graph(self): 4975 @torch.compile 4976 def f(x): 4977 return x.mul_(2) 4978 4979 x = torch.randn(4) 4980 vc_before = x._version 4981 f(x) 4982 vc_after = x._version 4983 self.assertTrue(vc_after > vc_before) 4984 4985 def test_nn_module_callable(self): 4986 class M(nn.Module): 4987 def forward(self, x): 4988 return x.sin() 4989 4990 def f(m): 4991 return callable(m) 4992 4993 res = torch.compile(f, fullgraph=True)(M()) 4994 self.assertTrue(res) 4995 4996 def test_stk_sdd_is_transposed(self): 4997 trigger_graph_break = False 4998 4999 def _is_transposed(x): 5000 return ( 5001 not x.is_contiguous() 5002 and x.stride()[0] == 1 5003 and x.stride()[1] == x.size()[0] 5004 ) 5005 5006 class SDD(torch.autograd.Function): 5007 @staticmethod 5008 def forward(ctx, lhs, rhs): 5009 ctx.save_for_backward(lhs, rhs) 5010 out = torch.full_like(lhs, 1.0, dtype=lhs.dtype, device=lhs.device) 5011 return out 5012 5013 @staticmethod 5014 def backward(ctx, dy): 5015 saved_tensors = ctx.saved_tensors 5016 lhs, rhs = saved_tensors[:2] 5017 trans_a = _is_transposed(lhs) 5018 trans_b = _is_transposed(rhs) 5019 dlhs = None 5020 if ctx.needs_input_grad[0]: 5021 dlhs = torch.full_like(lhs, 1.0 if trans_a else 2.0) 5022 drhs = None 5023 if ctx.needs_input_grad[1]: 5024 drhs = torch.full_like(rhs, 1.0 if trans_b else 2.0) 5025 if trigger_graph_break: 5026 if _is_transposed(dy): 5027 return dlhs + 1, drhs + 1, None, None 5028 return dlhs, drhs, None, None 5029 5030 x1 = torch.randn((8, 8), requires_grad=True) 5031 y1 = torch.randn((8, 8)).transpose(0, 1).requires_grad_(True) 5032 x2 = torch.randn((8, 8), requires_grad=True) 5033 y2 = torch.randn((8, 8)).transpose(0, 1).requires_grad_(True) 5034 5035 SDD.apply(x1, y1).sum().backward() 5036 5037 @torch.compile(backend="eager", fullgraph=True) 5038 def fn(): 5039 return SDD.apply(x2, y2) 5040 5041 fn().sum().backward() 5042 5043 self.assertEqual(x1.grad, x2.grad) 5044 self.assertEqual(y1.grad, y2.grad) 5045 5046 trigger_graph_break = True 5047 with self.assertRaises(torch._dynamo.exc.Unsupported): 5048 fn().sum().backward() 5049 5050 def test_partially_initialized_module_property(self): 5051 class Matrix(torch.nn.Module): 5052 def __init__(self, data): 5053 super().__init__() 5054 self._data = data 5055 self.foo = 10 * self.blocking 5056 5057 @property 5058 def data(self): 5059 return self._data 5060 5061 @property 5062 def blocking(self): 5063 return self.data.shape[1] 5064 5065 @torch.compile(backend="eager", fullgraph=True) 5066 def fn(): 5067 return Matrix(torch.randn(10, 20)) 5068 5069 v = fn() 5070 self.assertEqual(v.foo, 200) 5071 self.assertEqual(v.data.shape, (10, 20)) 5072 self.assertEqual(type(v), Matrix) 5073 5074 def test_classmethod_with_slots(self): 5075 class Mock: 5076 __slots__ = ("_a",) 5077 5078 def __init__(self): 5079 self._a = 2 5080 5081 @classmethod 5082 def _m(cls): 5083 return 3 5084 5085 def run(self, x): 5086 return torch.sin(x) * self._a * self._m() 5087 5088 def fn(x): 5089 mock = Mock() 5090 return mock.run(x) 5091 5092 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 5093 x = torch.randn(4) 5094 self.assertEqual(fn(x), opt_fn(x)) 5095 5096 def test_nn_parametrize(self): 5097 class Module(nn.Module): 5098 def __init__(self) -> None: 5099 super().__init__() 5100 self.param = torch.nn.Parameter(torch.randn(10, 10)) 5101 5102 def forward(self, x): 5103 return self.param @ x 5104 5105 class Parametrization(torch.nn.Module): 5106 def forward(self, x): 5107 return torch.sin(x) 5108 5109 m = Module() 5110 torch.nn.utils.parametrize.register_parametrization( 5111 m, "param", Parametrization() 5112 ) 5113 5114 sin_found = False 5115 5116 def backend(gm, _): 5117 nonlocal sin_found 5118 for node in gm.graph.nodes: 5119 if node.target is torch.sin: 5120 sin_found = True 5121 return gm 5122 5123 opt_m = torch.compile(m, backend=backend, fullgraph=True) 5124 inp = torch.randn(10, 10) 5125 self.assertEqual(m(inp), opt_m(inp)) 5126 self.assertTrue(sin_found) 5127 5128 torch.nn.utils.parametrize.remove_parametrizations(m, "param") 5129 sin_found = False 5130 self.assertEqual(m(inp), opt_m(inp)) 5131 self.assertFalse(sin_found) 5132 5133 def test_nn_module_property_closure(self): 5134 x = torch.randn(10, 10) 5135 5136 class Mod(torch.nn.Module): 5137 @property 5138 def y(self): 5139 return torch.ones(10, 10) + x 5140 5141 def forward(self, x): 5142 return x @ self.y 5143 5144 mod = Mod() 5145 5146 def fn(x): 5147 return mod(x) 5148 5149 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 5150 5151 inp = torch.randn(10, 10) 5152 self.assertEqual(fn(inp), opt_fn(inp)) 5153 5154 def test_global_fn_mutation(self): 5155 def foo(x, y): 5156 return global_fn(x) + y 5157 5158 x = torch.ones(1) 5159 y = torch.ones(1) 5160 5161 opt = torch.compile(foo, fullgraph=True, backend="eager") 5162 self.assertEqual(opt(x, y), foo(x, y)) 5163 5164 # Change global_fn 5165 global global_fn 5166 5167 def new_fn(x): 5168 return torch.cos(x) 5169 5170 global_fn = new_fn 5171 self.assertEqual(opt(x, y), foo(x, y)) 5172 5173 # ref https://github.com/pytorch/pytorch/issues/123974 5174 def test_list_reverse(self): 5175 def ladder(x): 5176 trail = x.size(-1) 5177 assert trail > 2 5178 weights = [] 5179 for s in [trail, trail - 1, trail - 2]: 5180 weights.append(torch.ones(s, s - 1)) 5181 5182 for w in weights: 5183 x = x @ w 5184 5185 weights.reverse() 5186 5187 for w in weights: 5188 x = x @ w.t() 5189 5190 return x 5191 5192 data = torch.randn(3, 4) 5193 opt_ladder = torch.compile(ladder, fullgraph=True, backend="eager") 5194 self.assertEqual(opt_ladder(data), ladder(data)) 5195 5196 def test_trace_functional_tensor_with(self): 5197 from torch._subclasses.fake_tensor import FakeTensorMode 5198 from torch._subclasses.functional_tensor import ( 5199 FunctionalTensor, 5200 FunctionalTensorMode, 5201 ) 5202 5203 def f(a, tmp): 5204 a_view = a.view(-1) 5205 with torch.no_grad(): 5206 a.set_(tmp) 5207 a_view.mul_(2) 5208 return a + tmp 5209 5210 fake_mode = FakeTensorMode() 5211 with FunctionalTensorMode(): 5212 inp = torch.ones(3, 3, requires_grad=True) 5213 inp = fake_mode.from_tensor(inp, static_shapes=True) 5214 inp = FunctionalTensor.to_functional(inp) 5215 5216 tmp = torch.ones(3, 3, requires_grad=True) 5217 tmp = fake_mode.from_tensor(tmp, static_shapes=True) 5218 tmp = FunctionalTensor.to_functional(tmp) 5219 5220 opt_f = torch.compile(f, backend="eager") 5221 with self.assertRaisesRegex( 5222 RuntimeError, "cannot mutate tensors with frozen storage" 5223 ): 5224 opt_f(inp, tmp) 5225 5226 def test_const_dict_keyerror(self): 5227 d = {} 5228 5229 def fn(x): 5230 try: 5231 y = d[0] 5232 except KeyError: 5233 y = 1 5234 return x + y 5235 5236 opt_fn = torch.compile(fn, backend="eager") 5237 inp = torch.randn(3, 3) 5238 self.assertEqual(fn(inp), opt_fn(inp)) 5239 5240 def test_dict_tag_guard(self): 5241 class Foo: 5242 def __init__(self) -> None: 5243 self.scalar = 10 5244 5245 def fn(d, x): 5246 return d["a"] * d["b"] * d["c"].scalar * x 5247 5248 foo = Foo() 5249 5250 d = {"a": 2, "b": 3, "c": foo} 5251 5252 opt_fn = torch.compile(fn, backend="eager") 5253 inp = torch.randn(3, 3) 5254 self.assertEqual(fn(d, inp), opt_fn(d, inp)) 5255 5256 d["a"] = 4 5257 self.assertEqual(fn(d, inp), opt_fn(d, inp)) 5258 5259 # Check that recompilation happens 5260 foo.scalar = 12 5261 self.assertEqual(fn(d, inp), opt_fn(d, inp)) 5262 5263 def test_nonconst_issubclass(self): 5264 def fn(x): 5265 if issubclass(x.__class__, np.ndarray): 5266 return 1 5267 return 0 5268 5269 opt_fn = torch.compile(fn, backend="eager") 5270 opt_fn(np.ones([3, 3])) 5271 5272 def test_issue126128(self): 5273 def fn(): 5274 x = torch.randn(1, 10) 5275 y = torch.randn(10, 1) 5276 return torch.mm(x, y).sum() 5277 5278 def fn2(): 5279 x = torch.randn(10, 100) 5280 y = torch.randn(100, 10) 5281 return torch.mm(x, y).sum() 5282 5283 with fresh_inductor_cache(): 5284 torch.compile(fn)() 5285 5286 torch.compile(fn2)() 5287 5288 def test_jit_script_defaults(self): 5289 @torch.jit.script 5290 def fast_cos(x, c: float = 2.0): 5291 return torch.cos(x) * c 5292 5293 class Mod(torch.nn.Module): 5294 def __init__(self) -> None: 5295 super().__init__() 5296 self.fast_cos = fast_cos 5297 5298 def forward(self, x): 5299 return self.fast_cos(x) 5300 5301 mod = Mod() 5302 opt_mod = torch.compile(mod, backend="eager", fullgraph=True) 5303 x = torch.randn(4) 5304 self.assertEqual(mod(x), opt_mod(x)) 5305 5306 def test_enum(self): 5307 class ExplicitEnum(str, Enum): 5308 @classmethod 5309 def _missing_(cls, value): 5310 raise ValueError( 5311 f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}" 5312 ) 5313 5314 class PaddingStrategy(ExplicitEnum): 5315 LONGEST = "longest" 5316 MAX_LENGTH = "max_length" 5317 DO_NOT_PAD = "do_not_pad" 5318 5319 def fn(x): 5320 a = PaddingStrategy("longest") 5321 if a == PaddingStrategy.LONGEST: 5322 return torch.sin(x) 5323 return torch.cos(x) 5324 5325 x = torch.randn(3, 3) 5326 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) 5327 self.assertEqual(fn(x), opt_fn(x)) 5328 5329 def test_hasattr_builtin(self): 5330 class MyClass: 5331 foo: int = 1 5332 5333 def func(x, m): 5334 if getattr(type(m), "foo", 0): 5335 return x + MyClass.foo 5336 return x 5337 5338 opt_func = torch.compile(func, backend="eager", fullgraph=True) 5339 m = MyClass() 5340 x = torch.zeros(()) 5341 self.assertEqual(func(x, m), opt_func(x, m)) 5342 self.assertEqual(func(x, 0), opt_func(x, 0)) 5343 5344 def test_grad(self): 5345 def fn(x, y): 5346 x._grad = y 5347 return x.grad.data 5348 5349 x = torch.randn(4, requires_grad=True) 5350 y = torch.randn(4) 5351 opt_fn = torch.compile(fn, backend="eager") 5352 self.assertEqual(fn(x, y), opt_fn(x, y)) 5353 5354 def test_nn_module_stack_bc(self): 5355 from torch._dynamo.mutation_guard import GenerationTracker 5356 5357 def compiler(gm, *args): 5358 module_stacks = [ 5359 node.meta.get("nn_module_stack", None) for node in gm.graph.nodes 5360 ] 5361 module_stacks, _ = pytree.tree_flatten(module_stacks) 5362 module_stacks = [x for x in module_stacks if isinstance(x, str)] 5363 for stack in module_stacks: 5364 self.assertTrue("_module" not in stack) 5365 return gm.forward 5366 5367 class SubMod(torch.nn.Module): 5368 def __init__(self) -> None: 5369 super().__init__() 5370 self.linear = torch.nn.Linear(2, 2) 5371 5372 def forward(self, x): 5373 return self.linear(x) 5374 5375 class Mod(torch.nn.Module): 5376 def __init__(self) -> None: 5377 super().__init__() 5378 self.submod1 = SubMod() 5379 self.submod2 = SubMod() 5380 5381 def forward(self, x): 5382 return self.submod1(x) + self.submod2(x) 5383 5384 mod = Mod() 5385 opt_mod = torch.compile(mod, backend=compiler) 5386 opt_mod(torch.randn(2, 2)) 5387 5388 with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True): 5389 mod = Mod() 5390 opt_mod = torch.compile(mod, backend=compiler) 5391 opt_mod(torch.randn(2, 2)) 5392 5393 # an example similar to Pippy usecase 5394 mod = Mod() 5395 GenerationTracker.tag(mod.submod1) 5396 GenerationTracker.mark_class_dynamic(type(mod.submod1)) 5397 mod = Mod() 5398 opt_mod = torch.compile(mod, backend=compiler) 5399 opt_mod(torch.randn(2, 2)) 5400 5401 def test_is_make_fx_tracing(self): 5402 @torch.compile(backend="eager", fullgraph=True) 5403 def fn(x): 5404 torch.nn.modules.activation._is_make_fx_tracing() 5405 return torch.sin(x) 5406 5407 fn(torch.rand(4)) 5408 5409 def test_negative_floor_div_solve(self): 5410 class CompiledClass(nn.Module): 5411 def __init__(self) -> None: 5412 super().__init__() 5413 self.nums = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) 5414 self.t = 5 5415 5416 def forward(self): 5417 self.num = self.nums[self.t // 12] 5418 self.t += 1 5419 return self.num 5420 5421 m = CompiledClass() 5422 m = torch.compile(m, backend="eager") 5423 5424 # the first call works 5425 m() 5426 # the second call causes a failure 5427 m() 5428 5429 # https://github.com/pytorch/pytorch/issues/121621 5430 def test_tensor_random(self): 5431 def random_op(tensor, params): 5432 res = tensor.random_(**params) 5433 return res 5434 5435 random_op = torch.compile(random_op) 5436 params = {"from": -10, "to": 10} 5437 tensor = torch.randn([2, 3]) 5438 res = random_op(tensor, params) 5439 5440 # https://github.com/pytorch/pytorch/issues/131019 5441 def test_tensor_uniform(self): 5442 def uniform_op(tensor, params): 5443 res = tensor.uniform_(**params) 5444 return res 5445 5446 uniform_op = torch.compile(uniform_op) 5447 params = {"from": -10, "to": 10} 5448 tensor = torch.randn([2, 3]) 5449 res = uniform_op(tensor, params) 5450 5451 def test_data_attr_mutation_after_saved_for_bw(self): 5452 def f(x): 5453 out = x.sin() 5454 x.data.mul_(2) 5455 return out 5456 5457 x = torch.randn(4, requires_grad=True) 5458 x_test = x.clone().detach().requires_grad_(True) 5459 5460 out = f(x) 5461 out_test = torch.compile(f, backend="aot_eager")(x_test) 5462 self.assertEqual(out, out_test) 5463 5464 out.sum().backward() 5465 out_test.sum().backward() 5466 self.assertEqual(x.grad, x_test.grad) 5467 5468 # https://github.com/pytorch/pytorch/issues/128072 5469 def test_map_with_multiple_args(self): 5470 def f(a, b): 5471 return a[0] * b[0] + a[1] * b[1] 5472 5473 def gen_inps(len_x, len_y): 5474 x = [torch.randn(5) for _ in range(len_x)] 5475 y = [torch.randn(5) for _ in range(len_y)] 5476 return x, y 5477 5478 def g(x, y): 5479 return map(f, x, y) 5480 5481 opt_g = torch.compile(g, fullgraph=True, backend="eager") 5482 5483 inps = gen_inps(3, 3) 5484 self.assertEqual(type(g(*inps)), type(opt_g(*inps))) 5485 self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps))) 5486 5487 inps = gen_inps(3, 5) 5488 self.assertEqual(type(g(*inps)), type(opt_g(*inps))) 5489 self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps))) 5490 5491 def test_staticmethod_allow_in_graph(self): 5492 class MyClass: 5493 i = 3 5494 5495 @staticmethod 5496 def foo_inner(x): 5497 return torch.mul(x, MyClass.i) 5498 5499 # if dynamo inlines with fullgraph, will error 5500 # verify that dynamo doesn't inline 5501 @staticmethod 5502 @torch._dynamo.allow_in_graph 5503 def foo1(x): 5504 torch._dynamo.graph_break() 5505 return MyClass.foo_inner(x) 5506 5507 @torch.compile(backend="eager", fullgraph=True) 5508 def f_bad(x): 5509 return MyClass.foo1(x) 5510 5511 f_bad(torch.ones(2, 2)) 5512 5513 def test_guard_with_tuple_mutation(self): 5514 class Foo: 5515 def __init__(self) -> None: 5516 self.x = 10 5517 5518 foo = Foo() 5519 d = { 5520 "a": 2, 5521 "b": (foo,), 5522 } 5523 5524 def fn(x, d): 5525 return x * d["a"] * d["b"][0].x 5526 5527 opt_fn = torch.compile(fn, backend="eager") 5528 inp = torch.randn(3, 3) 5529 self.assertEqual(fn(inp, d), opt_fn(inp, d)) 5530 d["b"][0].x = 12 5531 self.assertEqual(fn(inp, d), opt_fn(inp, d)) 5532 5533 def test_compile_complex_conj(self): 5534 def f(x): 5535 return torch.mul(x, 2j) 5536 5537 x_ref = torch.randn(4, 2, requires_grad=True) 5538 x_test = x_ref.clone().detach().requires_grad_(True) 5539 5540 out_ref = f(torch.view_as_complex(x_ref)) 5541 out_test = torch.compile(f, backend="aot_eager")(torch.view_as_complex(x_test)) 5542 self.assertEqual(out_ref, out_test) 5543 5544 torch.view_as_real(out_ref).sum().backward() 5545 torch.view_as_real(out_test).sum().backward() 5546 self.assertEqual(x_ref.grad, x_test.grad) 5547 5548 # https://github.com/pytorch/pytorch/issues/132200 5549 def test_partitioner_cse_respects_mutation_boundaries(self): 5550 set_available = hasattr(torch.ops, "fsdp") and hasattr(torch.ops.fsdp, "set_") 5551 if not set_available: 5552 return 5553 5554 @torch.compile(backend="aot_eager_decomp_partition") 5555 def f(x, l): 5556 # z0 and z1 can be CSEd 5557 z0 = x.sin() 5558 z1 = x.sin() 5559 y = x + 1 5560 torch.ops.fsdp.set_.default(x, y) 5561 # z3 and z3 can be CSEd with each other, 5562 # but *not* with z0/z1 (they cross a mutation boundary) 5563 z2 = x.sin() 5564 z3 = x.sin() 5565 return z0, z1, z2, z3, l**2 5566 5567 x = torch.randn(3) 5568 x_clone = x.clone() 5569 l = torch.randn(3, requires_grad=True) 5570 z0, z1, z2, z3, _ = f(x, l) 5571 5572 # the partitioner runs CSE. We expect that of the 4 sin() ops above: 5573 # - the first 2 are CSE'd 5574 # - the last 2 are CSE'd 5575 # - the set_() op in the middle is a mutation barrier, preventing CSE 5576 self.assertEqual(z0, (x_clone).sin()) 5577 self.assertEqual(z1, (x_clone).sin()) 5578 self.assertEqual(z2, (x_clone + 1).sin()) 5579 self.assertEqual(z3, (x_clone + 1).sin()) 5580 5581 # https://github.com/pytorch/pytorch/issues/132197 5582 def test_fsdp_set_input_mutation_applied_when_input_gets_no_gradients(self): 5583 set_available = hasattr(torch.ops, "fsdp") and hasattr(torch.ops.fsdp, "set_") 5584 if not set_available: 5585 return 5586 5587 @torch.compile(backend="aot_eager_decomp_partition") 5588 def f(x, l): 5589 z = x.sin() 5590 y = x + 1 5591 # graph input has its storage mutated 5592 torch.ops.fsdp.set_.default(x, y) 5593 z2 = x.sin() 5594 return z2, l**2 5595 5596 x = torch.randn(3) 5597 x_test = x.clone() 5598 l = torch.randn(3, requires_grad=True) 5599 result, _ = f(x, l) 5600 result_test, _ = torch.compile(f, backend="aot_eager_decomp_partition")( 5601 x_test, l 5602 ) 5603 5604 self.assertEqual(result, result_test) 5605 self.assertEqual(x, x_test) 5606 5607 def test_changing_stride(self): 5608 cnt = torch._dynamo.testing.CompileCounter() 5609 5610 @torch.compile(backend=cnt) 5611 def fn(x, y): 5612 return x * y 5613 5614 for i in range(1, 4): 5615 x = torch.randn(4, i) 5616 5617 # create a view for i > 1 5618 if i == 1: 5619 x1 = x 5620 else: 5621 x1 = x[:, 0:1] 5622 5623 y = torch.randn(4, 1) 5624 print(x1.shape, y.shape) 5625 fn(x1, y) 5626 5627 self.assertTrue(cnt.frame_count <= 2) 5628 5629 @torch._dynamo.config.patch(guard_nn_modules=False) 5630 @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False) 5631 def test_inlining_cornercase(self): 5632 """ 5633 nn.Modules can be mapped to either NNModuleVariable or UnspecializedNNModuleVariable. For NNModuleVariable, the 5634 tensor attributes become part of the Dynamo graph. For unspecialized, they are lifted as inputs. 5635 5636 But there is a cornercase. Suppose you have NNModuleVariable with a submodule that is 5637 UnspecializedNNModuleVariable. Today, Dynamo will still consider the submodule as specialized (courtesy of 5638 guard.source().is_nn_module()). In retrospect, this is a mistake but there are dependencies of export and also 5639 cudagraphs which make it harder to fix the corner case right away. The long term solution is 5640 inline_inbuilt_nn_modules anyways, so we might have to live with this cornercase in the short term. 5641 5642 We are starting to annotate the source of each nn module more precisely - NNModuleVariable attribute is marked 5643 as NNModuleSource, UnspecilaizedNNModuleVariable attribute is marked as UnspecializedNNModuleSource. But this 5644 changes the behavior for the cornercase. And fails some tests which have unfortunately relied on this behavior. 5645 5646 5647 To solve this, we tag the source only when inline_inbuilt_nn_module flag is turned on. 5648 5649 In this test, we purposely turn the flag off, testing that the tagging is disabled. 5650 """ 5651 5652 class SubMod(torch.nn.Module): 5653 def __init__(self): 5654 super().__init__() 5655 self.linear = torch.nn.Linear(1, 1) 5656 self.a = torch.randn(1, 1) 5657 self.counter = 0 5658 self.multipliers = [2.2, 3.3] 5659 5660 def forward(self, x): 5661 self.counter += 1 5662 return ( 5663 self.linear(x) * self.a * self.multipliers[0] * self.multipliers[1] 5664 ) 5665 5666 class Mod(torch.nn.Module): 5667 def __init__(self): 5668 super().__init__() 5669 self.submod = SubMod() 5670 5671 def forward(self, x): 5672 return self.submod(x) 5673 5674 mod = Mod() 5675 opt_mod = torch.compile(mod, backend="eager") 5676 5677 x = torch.randn(1, 1) 5678 ref = mod(x) 5679 res = opt_mod(x) 5680 5681 mod.submod.multipliers = [3.3, 4.4] 5682 # Since guard_nn_modules is False, this will not recompile 5683 with torch._dynamo.config.patch(error_on_recompile=True): 5684 ref = mod(x) 5685 res = opt_mod(x) 5686 5687 def test_optimized_module_training(self): 5688 mod = torch.nn.Linear(3, 3) 5689 mod.eval() 5690 5691 opt_mod = torch.compile(mod, backend="eager") 5692 self.assertFalse(opt_mod.training) 5693 5694 opt_mod.train() 5695 self.assertTrue(opt_mod.training) 5696 self.assertTrue(mod.training) 5697 5698 mod.eval() 5699 self.assertFalse(opt_mod.training) 5700 5701 @requires_cuda 5702 def test_memleak_when_graph_input_has_tensor_attr(self): 5703 @torch.compile(backend="eager") 5704 def f(x): 5705 x.add_(1) 5706 5707 mem_before = torch.cuda.memory_allocated() 5708 5709 x = torch.ones(2, device="cuda") 5710 x.foo = torch.zeros(2, device="cuda") 5711 f(x) 5712 del x.foo 5713 del x 5714 mem_after = torch.cuda.memory_allocated() 5715 self.assertEqual(mem_before, mem_after) 5716 5717 # check when non-tensor data structure attribute contains a tensor 5718 @torch.compile(backend="eager") 5719 def f(x): 5720 x.add_(1) 5721 5722 mem_before = torch.cuda.memory_allocated() 5723 x = torch.ones(2, device="cuda") 5724 x.foo = [torch.zeros(2, device="cuda") for _ in range(5)] 5725 f(x) 5726 del x.foo 5727 del x 5728 mem_after = torch.cuda.memory_allocated() 5729 self.assertEqual(mem_before, mem_after) 5730 5731 # check with tensor refcycle 5732 @torch.compile(backend="eager") 5733 def g(x, y): 5734 return x + y 5735 5736 mem_before = torch.cuda.memory_allocated() 5737 x = torch.ones(2, device="cuda") 5738 y = torch.zeros(2, device="cuda") 5739 x.foo = [y] 5740 y.foo = [x] 5741 g(x, y) 5742 del x.foo 5743 del y.foo 5744 del x 5745 del y 5746 mem_after = torch.cuda.memory_allocated() 5747 self.assertEqual(mem_before, mem_after) 5748 5749 def test_os_fspath(self): 5750 @torch.compile(backend="eager", fullgraph=True) 5751 def fn(x): 5752 os.fspath(".") 5753 return torch.sin(x) 5754 5755 fn(torch.randn(4)) 5756 5757 @requires_cuda 5758 # This test will fail as flip in combination with particular input lenghts 5759 # produces weird results. 5760 # This is under investigations in 5761 # https://github.com/pytorch/pytorch/issues/131805 5762 @unittest.skip("Skip this flip test for the moment. It is under investigation") 5763 def test_flip_bad_accuracy(self): 5764 import torch 5765 import torch._dynamo.config 5766 import torch._functorch.config 5767 import torch._inductor.config 5768 import torch._inductor.inductor_prims 5769 import torch.fx.experimental._config 5770 5771 class Repro(torch.nn.Module): 5772 def __init__(self): 5773 super().__init__() 5774 5775 def forward(self, arg0_1): 5776 rev = torch.ops.prims.rev.default(arg0_1, [0]) 5777 arg0_1 = None 5778 slice_1 = torch.ops.aten.slice.Tensor(rev, 0, 0, -1, 2) 5779 slice_2 = torch.ops.aten.slice.Tensor(rev, 0, 1, 9223372036854775807, 2) 5780 add_1 = torch.ops.aten.add.Tensor(slice_1, slice_2) 5781 slice_1 = slice_2 = None 5782 slice_3 = torch.ops.aten.slice.Tensor(add_1, 0, 0, -1, 2) 5783 slice_4 = torch.ops.aten.slice.Tensor( 5784 add_1, 0, 1, 9223372036854775807, 2 5785 ) 5786 add_2 = torch.ops.aten.add.Tensor(slice_3, slice_4) 5787 slice_3 = slice_4 = None 5788 slice_5 = torch.ops.aten.slice.Tensor(add_2, 0, 0, -1, 2) 5789 slice_6 = torch.ops.aten.slice.Tensor( 5790 add_2, 0, 1, 9223372036854775807, 2 5791 ) 5792 add_3 = torch.ops.aten.add.Tensor(slice_5, slice_6) 5793 slice_5 = slice_6 = None 5794 slice_9 = torch.ops.aten.slice.Tensor(add_2, 0, 0, 1) 5795 add_2 = None 5796 unsqueeze = torch.ops.aten.unsqueeze.default(slice_9, 1) 5797 slice_9 = None 5798 unsqueeze_1 = torch.ops.aten.unsqueeze.default(add_3, 1) 5799 add_3 = None 5800 cat = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 1) 5801 unsqueeze = unsqueeze_1 = None 5802 view = torch.ops.aten.view.default(cat, [2]) 5803 cat = None 5804 slice_10 = torch.ops.aten.slice.Tensor(view, 0, 0, -1) 5805 slice_11 = torch.ops.aten.slice.Tensor( 5806 add_1, 0, 2, 9223372036854775807, 2 5807 ) 5808 add_5 = torch.ops.aten.add.Tensor(slice_10, slice_11) 5809 slice_10 = slice_11 = None 5810 slice_12 = torch.ops.aten.slice.Tensor(add_1, 0, 0, 1) 5811 add_1 = None 5812 cat_1 = torch.ops.aten.cat.default([slice_12, add_5]) 5813 slice_12 = add_5 = None 5814 unsqueeze_2 = torch.ops.aten.unsqueeze.default(cat_1, 1) 5815 cat_1 = None 5816 unsqueeze_3 = torch.ops.aten.unsqueeze.default(view, 1) 5817 view = None 5818 cat_2 = torch.ops.aten.cat.default([unsqueeze_2, unsqueeze_3], 1) 5819 unsqueeze_2 = unsqueeze_3 = None 5820 view_1 = torch.ops.aten.view.default(cat_2, [4]) 5821 cat_2 = None 5822 slice_13 = torch.ops.aten.slice.Tensor( 5823 rev, 0, 2, 9223372036854775807, 2 5824 ) 5825 add_6 = torch.ops.aten.add.Tensor(view_1, slice_13) 5826 slice_13 = None 5827 slice_14 = torch.ops.aten.slice.Tensor(rev, 0, 0, 1) 5828 rev = None 5829 cat_3 = torch.ops.aten.cat.default([slice_14, add_6]) 5830 slice_14 = add_6 = None 5831 constant_pad_nd = torch.ops.aten.constant_pad_nd.default( 5832 view_1, [0, 1], 0.0 5833 ) 5834 view_1 = None 5835 unsqueeze_4 = torch.ops.aten.unsqueeze.default(cat_3, 1) 5836 cat_3 = None 5837 unsqueeze_5 = torch.ops.aten.unsqueeze.default(constant_pad_nd, 1) 5838 constant_pad_nd = None 5839 cat_4 = torch.ops.aten.cat.default([unsqueeze_4, unsqueeze_5], 1) 5840 unsqueeze_4 = unsqueeze_5 = None 5841 view_2 = torch.ops.aten.view.default(cat_4, [10]) 5842 cat_4 = None 5843 slice_15 = torch.ops.aten.slice.Tensor(view_2, 0, 0, 9) 5844 view_2 = None 5845 rev_1 = torch.ops.prims.rev.default(slice_15, [0]) 5846 slice_15 = None 5847 return (rev_1,) 5848 5849 mod = Repro() 5850 x = torch.arange(9, device=torch.device("cuda")) 5851 5852 @torch.compile 5853 def f(x): 5854 return mod(x) 5855 5856 out = f(x) 5857 self.assertEqual(torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]), out[0]) 5858 5859 # https://github.com/pytorch/pytorch/issues/88813 5860 def test_return_value_duplication_tensor(self) -> None: 5861 def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 5862 return val * 2, val * 2 5863 5864 x = torch.randn(2, requires_grad=True) 5865 5866 expect = fn(x) 5867 self.assertNotEqual( 5868 expect[0].untyped_storage().data_ptr(), 5869 expect[1].untyped_storage().data_ptr(), 5870 ) 5871 5872 actual = torch.compile(fn, backend="aot_eager")(x) 5873 self.assertNotEqual( 5874 actual[0].untyped_storage().data_ptr(), 5875 actual[1].untyped_storage().data_ptr(), 5876 ) 5877 5878 # https://github.com/pytorch/pytorch/issues/114344 5879 def test_return_value_duplication_mixed_grad(self) -> None: 5880 def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 5881 with torch.no_grad(): 5882 out0 = val + 1 5883 out1 = val + 1 5884 return out0, out1 5885 5886 x = torch.randn(2, requires_grad=True) 5887 5888 with torch.enable_grad(): 5889 expect = fn(x) 5890 actual = torch.compile(fn, backend="aot_eager")(x) 5891 5892 self.assertEqual(expect[0].requires_grad, actual[0].requires_grad) 5893 self.assertEqual(expect[1].requires_grad, actual[1].requires_grad) 5894 5895 # https://github.com/pytorch/pytorch/pull/134726#discussion_r1738774371 5896 def test_return_value_duplication_scalar(self) -> None: 5897 def fn(val: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 5898 x, y = val * 2, val * 2 5899 return x[0], y[0] 5900 5901 x = torch.randn(2, requires_grad=True) 5902 5903 expect = fn(x) 5904 self.assertNotEqual( 5905 expect[0].untyped_storage().data_ptr(), 5906 expect[1].untyped_storage().data_ptr(), 5907 ) 5908 5909 actual = torch.compile(fn, backend="aot_eager")(x) 5910 self.assertNotEqual( 5911 actual[0].untyped_storage().data_ptr(), 5912 actual[1].untyped_storage().data_ptr(), 5913 ) 5914 5915 5916instantiate_parametrized_tests(ReproTests) 5917 5918 5919if __name__ == "__main__": 5920 from torch._dynamo.test_case import run_tests 5921 5922 run_tests() 5923