xref: /aosp_15_r20/external/pytorch/torch/_dynamo/create_parameter_op.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import threading
3from contextlib import contextmanager
4
5import torch
6
7
# Human-readable explanation of the "sacrificial placeholder" technique that
# TracableCreateParameter / tracable_create_parameter / new_parameter_placeholder
# implement. NOTE(review): presumably consumed by other dynamo modules (e.g. in
# messages or docs) — confirm before changing the text, since it is a runtime value.
doc = """
This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly
with AOTAutograd.  We instead create a placeholder torch.nn.Parameter before the graph, which
becomes a graph arg and has no storage backing it.  At the point in the graph where the parameter
actually should be created we mutate this sacrificial placeholder into it.  This allows gradients
to flow into the parameter as if it were an input to the graph (which is the only thing we are
allowed to compute gradients on).
""".strip()
16
17
class TracableCreateParameter(torch.autograd.Function):
    """Autograd function that rebinds a placeholder parameter onto ``tensor``.

    Forward mutates ``placeholder`` in place via ``set_`` so it shares
    ``tensor``'s storage, size and strides, then returns it. Backward sends the
    whole incoming gradient to ``placeholder`` and none to ``tensor``.
    """

    @staticmethod
    def forward(ctx, tensor, placeholder):
        # ``tensor`` only supplies data; backward never produces a gradient
        # for it, so it must not be tracked by autograd.
        assert not tensor.requires_grad
        # NOTE(review): placeholder is modified in place without
        # ctx.mark_dirty; presumably deliberate because placeholder is
        # typically a leaf that requires grad — confirm before changing.
        rebound = placeholder.set_(tensor)
        return rebound

    @staticmethod
    def backward(ctx, grad):
        # Gradients are returned in forward's input order: (tensor, placeholder).
        tensor_grad = None
        return tensor_grad, grad
27
28
def tracable_create_parameter(tensor, placeholder):
    """Materialize ``placeholder`` from ``tensor`` via ``TracableCreateParameter``.

    Grad mode is set to ``placeholder.requires_grad`` while the autograd
    function runs. The result is therefore recorded in the autograd graph
    only when the placeholder wants gradients.
    """
    with torch.set_grad_enabled(placeholder.requires_grad):
        return TracableCreateParameter.apply(tensor, placeholder)
33
34
def new_parameter_placeholder(size, dtype, device, requires_grad):
    """Create a storage-less ``torch.nn.Parameter`` for ``tracable_create_parameter``.

    The returned parameter has the requested ``size``, ``dtype``, ``device``
    and ``requires_grad`` flag. Its untyped storage is shrunk to zero bytes,
    so it holds no data until it is rebound (e.g. by ``set_``).
    """
    backing = torch.empty(size, dtype=dtype, device=device)
    placeholder = torch.nn.Parameter(backing, requires_grad=requires_grad)
    # TODO(jansel): allocating and then immediately freeing is inefficient; we
    # need a way to allocate an unbacked tensor directly. Allocating a
    # zero-size tensor instead would cause assert failures in autograd.
    placeholder.untyped_storage().resize_(0)
    return placeholder
44
45
_TLS = threading.local()

# Attribute on _TLS that stores this thread's "convert to tracable parameter"
# flag; a thread that never set it is treated as True.
_CONVERT_FLAG_ATTR = "convert_tracable_parameter"


def can_convert_to_tracable_parameter():
    """Return this thread's conversion flag (True unless disabled)."""
    return getattr(_TLS, _CONVERT_FLAG_ATTR, True)


@contextmanager
def do_not_convert_to_tracable_parameter():
    """Disable tracable-parameter conversion on the current thread.

    The previous flag value is restored on exit, including when the body
    raises. The context manager yields ``False``.
    """
    previous = can_convert_to_tracable_parameter()
    setattr(_TLS, _CONVERT_FLAG_ATTR, False)
    try:
        yield False
    finally:
        setattr(_TLS, _CONVERT_FLAG_ATTR, previous)
61