# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from collections import defaultdict
from typing import Callable, Dict

import torch
import torch._decomp as decomp
from torch._decomp import get_decompositions
from torch._ops import OpOverload


aten = torch.ops.aten

rng_decompositions: Dict[str, Dict[OpOverload, Callable]] = defaultdict(dict)


def register_rng_decomposition(aten_op):
    return decomp.register_decomposition(aten_op, rng_decompositions)


def throw_on_non_cuda(device):
    raise RuntimeError(
        f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not "
        f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is "
        "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
    )


# TODO - We have to register many more distributions here, as well as higher-level
# ops like dropout, which have fused implementations and can hide the rand calls inside.
@register_rng_decomposition(aten.rand)
def rand(shape, dtype=None, layout=torch.strided, device=None, pin_memory=False):
    if device and device.type != "cuda":
        throw_on_non_cuda(device)
    seed, offset = PhiloxStateTracker.get_state_as_tuple()
    dtype = dtype or torch.float32
    out, offset_jump = torch.ops.rngprims.philox_rand(
        shape, seed, offset, None, device, dtype
    )
    PhiloxStateTracker.advance_offset(offset_jump)
    return out


@register_rng_decomposition(aten.rand_like)
def rand_like(
    x: torch.Tensor,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=torch.preserve_format,
):
    device = device or x.device
    if device.type != "cuda":
        throw_on_non_cuda(device)
    dtype = dtype or x.dtype
    seed, offset = PhiloxStateTracker.get_state_as_tuple()
    out, offset_jump = torch.ops.rngprims.philox_rand(
        x.shape, seed, offset, None, device, dtype
    )
    PhiloxStateTracker.advance_offset(offset_jump)
    return out

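# Illustrative sketch (comments only, not executed): with these decompositions
# active, a traced torch.rand(4, device="cuda") lowers to the functional Philox
# primitive roughly as follows:
#
#     seed, offset = PhiloxStateTracker.get_state_as_tuple()
#     out, offset_jump = torch.ops.rngprims.philox_rand(
#         (4,), seed, offset, None, torch.device("cuda"), torch.float32
#     )
#     PhiloxStateTracker.advance_offset(offset_jump)
#
# so the RNG state becomes an explicit (seed, offset) input to the graph instead
# of hidden mutable state.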

class PhiloxState:
    """
    Represents a PhiloxRngState - (seed, offset) where offset = base_offset +
    relative_offset. seed and base_offset point to the RNG state just before
    tracing starts. relative_offset tracks the total offset consumed during
    tracing.
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self):
        self.seed = torch.tensor(())
        self.base_offset = torch.tensor(())
        self.relative_offset = 0
        self.offset_advanced_alteast_once = False

    def validate_state(self):
        assert self.seed.numel() != 0 and self.base_offset.numel() != 0

    def advance_offset(self, consumed_offset):
        self.offset_advanced_alteast_once = True
        self.relative_offset = self.relative_offset + consumed_offset

    def set_state(self, seed, base_offset, relative_offset=0):
        self.seed = seed
        self.base_offset = base_offset
        self.relative_offset = relative_offset

    def get_state_as_tuple(self):
        self.validate_state()
        return (self.seed, self.base_offset + self.relative_offset)

    def get_state_as_tensor(self):
        # Only needed because we override get_rng_state.
        self.validate_state()
        return torch.stack([self.seed, self.base_offset + self.relative_offset])

    def set_state_from_tensor(self, state):
        # Only needed because we override set_rng_state.
        self.seed, self.base_offset = torch.unbind(state)
        self.relative_offset = 0

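# Minimal usage sketch of PhiloxState (illustrative only; the values are made up):
#
#     state = PhiloxState()
#     state.set_state(seed=torch.tensor(123), base_offset=torch.tensor(16))
#     state.advance_offset(4)            # a rand op consumed 4 offsets
#     state.get_state_as_tuple()         # -> (tensor(123), tensor(20))
#     state.get_state_as_tensor()        # -> tensor([123, 20])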

class PhiloxStateTracker:
    """
    Singleton class to track the Philox RNG state during AOT Autograd tracing.
    For each AOT tracing instance, AOT Autograd resets this tracker and keeps
    track of both forward and backward offsets. At runtime, we only care about
    the total consumed forward and backward offsets. For dynamic shapes, these
    offsets are a function of input shapes. Therefore, the AOT-generated graphs
    have additional outputs that compute the total consumed forward and backward
    offsets.
    """

    running_state: PhiloxState
    fwd_state: PhiloxState
    bwd_state: PhiloxState

    def __enter__(self):
        PhiloxStateTracker.reset()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        PhiloxStateTracker.reset()

    @classmethod
    def reset(cls):
        cls.running_state = PhiloxState()
        cls.fwd_state = PhiloxState()
        cls.bwd_state = PhiloxState()

    @classmethod
    def mark_beginning_of_forward(cls):
        # Tells the tracker to use fwd_state as the running state
        cls.running_state = cls.fwd_state

    @classmethod
    def mark_beginning_of_backward(cls):
        # Tells the tracker to use bwd_state as the running state
        cls.running_state = cls.bwd_state

    @classmethod
    def record_state(cls, seed, offset, mode):
        # Records the seed and offset tensors. These tensors are used to invoke
        # the philox_rand functional primitives.
        if mode == "forward":
            cls.fwd_state.set_state(seed, offset)
            cls.mark_beginning_of_forward()
        else:
            assert mode == "backward"
            cls.bwd_state.set_state(seed, offset)

    @classmethod
    def get_state_as_tensor(cls):
        # The only reason this exists is because we override get_rng_state and
        # set_rng_state during tracing. get_rng_state expects a tensor output,
        # so returning a (seed, offset) tuple would upset other parts of the
        # program, like ctx.saved_tensors.

        # A bad consequence is that if the user saves and restores the rng state,
        # we have a little bit of ugliness in the generated code, where we first
        # concat the (seed, offset) to create a tensor for get_rng_state, and
        # then split it back into the (seed, offset) tuple in set_rng_state.

        # TODO: Investigate if there is a better way to wrap the tuple in a
        # fake Tensor object, and then desugar it later on.
        return cls.running_state.get_state_as_tensor()

    @classmethod
    def get_state_as_tuple(cls):
        return cls.running_state.get_state_as_tuple()

    @classmethod
    def set_state_from_tensor(cls, x):
        # This is only needed because we override set_rng_state. Look at the
        # comment in the get_state_as_tensor method.
        cls.running_state.set_state_from_tensor(x)

    @classmethod
    def advance_offset(cls, consumed_offset):
        cls.running_state.advance_offset(consumed_offset)

    @classmethod
    def get_current_relative_offset(cls):
        return cls.running_state.relative_offset

    @staticmethod
    def multiple_of_4(offset):
        # The torch CUDA RNG state offset must be a multiple of 4. For Inductor,
        # as we sum up all the numel, the result might not be a multiple of 4, so
        # this method rounds the offset up to the nearest multiple of 4.
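        # For example: multiple_of_4(0) == 0, multiple_of_4(1) == 4,
        # multiple_of_4(5) == 8, multiple_of_4(8) == 8.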
        return (offset + 3) // 4 * 4

    @classmethod
    def get_updated_fwd_offset(cls):
        # Short circuit if no rand ops were observed
        if not cls.fwd_state.offset_advanced_alteast_once:
            return cls.fwd_state.base_offset
        return cls.multiple_of_4(
            cls.fwd_state.base_offset + cls.fwd_state.relative_offset
        )

    @classmethod
    def get_updated_bwd_offset(cls):
        # Short circuit if no rand ops were observed
        if not cls.bwd_state.offset_advanced_alteast_once:
            return cls.bwd_state.base_offset
        return cls.multiple_of_4(
            cls.bwd_state.base_offset + cls.bwd_state.relative_offset
        )

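# Rough lifecycle sketch of how a tracer could drive PhiloxStateTracker
# (illustrative only; the real wiring lives in AOT Autograd, not in this module):
#
#     with PhiloxStateTracker():
#         PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward")
#         # ...trace the forward graph; the rand decomps call advance_offset()...
#         PhiloxStateTracker.record_state(bwd_seed, bwd_base_offset, "backward")
#         # ...trace the backward graph...
#         fwd_offset = PhiloxStateTracker.get_updated_fwd_offset()
#         bwd_offset = PhiloxStateTracker.get_updated_bwd_offset()
#
# fwd_seed, fwd_base_offset, etc. are placeholders for the seed/offset tensors
# captured from the CUDA RNG state just before tracing.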

# Add more decompositions which eventually use rand_like inside their decomps.
# Registering these in rng_decompositions ensures that the rand_like ops used in
# these decomps get functionalized. The list is copied from the Inductor
# codebase, which uses it for a similar purpose.
#
# Caution - These decomps do not have the same accuracy as eager. However, we
# can't just disable them with a config flag like fallback_random, because for
# functionalization of rng ops, we have to decompose these ops.
extra_random_decomps = get_decompositions(
    [
        aten.cauchy,
        aten.cauchy_,
        aten.exponential,
        aten.exponential_,
        aten.geometric,
        aten.geometric_,
        aten.native_dropout,
        aten.normal,
        aten.normal_,
        aten.normal_functional,
        aten.log_normal,
        aten.log_normal_,
        aten.rrelu_with_noise,
        aten.rrelu_with_noise_,
        aten.uniform_,
    ]
)
register_extra_random_decomp = functools.partial(
    decomp.register_decomposition, registry=extra_random_decomps
)


@register_extra_random_decomp([aten.bernoulli_])
def bernoulli_(self, p=0.5):
    if self.device == torch.device("cpu"):
        return NotImplemented
    return self.copy_(torch.rand_like(self, dtype=torch.float32) < p)


@register_extra_random_decomp([aten.bernoulli.p])
def bernoulli_p(self, p=0.5, *, generator=None):
    if self.device == torch.device("cpu"):
        return NotImplemented
    assert generator is None
    return torch.rand_like(self, dtype=torch.float32) < p


rng_decompositions.update(extra_random_decomps)  # type: ignore[arg-type]

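# Hedged usage sketch (illustrative, not executed here): these decompositions are
# consumed when RNG ops are functionalized during tracing, e.g. under AOT Autograd
# with torch._functorch.config.functionalize_rng_ops enabled:
#
#     import torch
#     import torch._functorch.config as functorch_config
#
#     functorch_config.functionalize_rng_ops = True
#
#     @torch.compile
#     def f(x):
#         return torch.nn.functional.dropout(x, p=0.5)
#
#     # f(torch.randn(8, device="cuda"))  # requires a CUDA device (Philox RNG)
#
# The exact entry point is an assumption of this sketch; only CUDA's
# Philox/counter-based RNG is supported, as throw_on_non_cuda above enforces.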