# mypy: allow-untyped-defs
import threading
from contextlib import contextmanager

import torch


doc = """
This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly
with AOTAutograd. We instead create a placeholder torch.nn.Parameter before the graph, which
becomes a graph arg and has no storage backing it. At the point in the graph where the parameter
actually should be created we mutate this sacrificial placeholder into it. This allows gradients
to flow into the parameter as if it were an input to the graph (which is the only thing we are
allowed to compute gradients on).
""".strip()


class TracableCreateParameter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, placeholder):
        assert not tensor.requires_grad
        # Swap the computed data into the sacrificial placeholder (see `doc` above).
        return placeholder.set_(tensor)

    @staticmethod
    def backward(ctx, grad):
        return None, grad  # grad flows to placeholder


def tracable_create_parameter(tensor, placeholder):
    # The custom Function is only recorded by autograd when the placeholder requires grad.
    with torch.set_grad_enabled(placeholder.requires_grad):
        out = TracableCreateParameter.apply(tensor, placeholder)
    return out


def new_parameter_placeholder(size, dtype, device, requires_grad):
    """Create a placeholder to be passed to the above functions."""
    result = torch.nn.Parameter(
        torch.empty(size, dtype=dtype, device=device), requires_grad=requires_grad
    )
    # TODO(jansel): alloc followed by free is inefficient, need a way to allocate an unbacked tensor.
    # Allocating a zero tensor would cause assert failures in autograd.
    result.untyped_storage().resize_(0)
    return result


# Thread-local flag controlling whether parameter creation is converted to the tracable form.
_TLS = threading.local()


@contextmanager
def do_not_convert_to_tracable_parameter():
    # Temporarily disable the conversion on the current thread, restoring the prior value on exit.
    old_flag = getattr(_TLS, "convert_tracable_parameter", True)
    _TLS.convert_tracable_parameter = False
    try:
        yield False
    finally:
        _TLS.convert_tracable_parameter = old_flag


def can_convert_to_tracable_parameter():
    return getattr(_TLS, "convert_tracable_parameter", True)
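

# A minimal eager-mode sketch of how the pieces above fit together. In normal use these
# calls are inserted into a dynamo-traced graph; this standalone example (the shapes,
# variable names, and the __main__ guard are illustrative assumptions, not part of the
# traced path) just shows gradients flowing into the sacrificial placeholder as if it
# were a graph input.
if __name__ == "__main__":
    # Data computed "inside the graph"; it must not require grad (see the assert in forward).
    data = torch.randn(3, 4)

    # Storage-less placeholder that stands in for the parameter as a graph input.
    placeholder = new_parameter_placeholder(
        size=(3, 4), dtype=data.dtype, device=data.device, requires_grad=True
    )

    # Mutate the placeholder into the real parameter via the autograd-visible op.
    param = tracable_create_parameter(data, placeholder)

    # Gradients computed against `param` are routed to the placeholder by backward().
    param.sum().backward()
    assert placeholder.grad is not None
    assert placeholder.grad.shape == (3, 4)

    # The thread-local flag can be used to opt out of the conversion temporarily.
    assert can_convert_to_tracable_parameter()
    with do_not_convert_to_tracable_parameter():
        assert not can_convert_to_tracable_parameter()
    assert can_convert_to_tracable_parameter()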