# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import functools
from collections import defaultdict
from typing import Callable, Dict

import torch
import torch._decomp as decomp
from torch._decomp import get_decompositions
from torch._ops import OpOverload


aten = torch.ops.aten

rng_decompositions: Dict[str, Dict[OpOverload, Callable]] = defaultdict(dict)


def register_rng_decomposition(aten_op):
    return decomp.register_decomposition(aten_op, rng_decompositions)


def throw_on_non_cuda(device):
    raise RuntimeError(
        f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not "
        f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is "
        "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
    )


# TODO - We have to register many more distributions here, and also higher-level
# ops like dropout, which have fused implementations and can hide the rand inside.
@register_rng_decomposition(aten.rand)
def rand(shape, dtype=None, layout=torch.strided, device=None, pin_memory=False):
    if device and device.type != "cuda":
        throw_on_non_cuda(device)
    seed, offset = PhiloxStateTracker.get_state_as_tuple()
    dtype = dtype or torch.float32
    out, offset_jump = torch.ops.rngprims.philox_rand(
        shape, seed, offset, None, device, dtype
    )
    PhiloxStateTracker.advance_offset(offset_jump)
    return out


@register_rng_decomposition(aten.rand_like)
def rand_like(
    x: torch.Tensor,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=False,
    memory_format=torch.preserve_format,
):
    device = device or x.device
    if device.type != "cuda":
        throw_on_non_cuda(device)
    dtype = dtype or x.dtype
    seed, offset = PhiloxStateTracker.get_state_as_tuple()
    out, offset_jump = torch.ops.rngprims.philox_rand(
        x.shape, seed, offset, None, device, dtype
    )
    PhiloxStateTracker.advance_offset(offset_jump)
    return out


class PhiloxState:
    """
    Represents a Philox RNG state - (seed, offset) where offset = base_offset +
    relative_offset. seed and base_offset point to the RNG state just before
    tracing starts, while relative_offset tracks the total offset consumed at
    trace time.
    """

    def __init__(self) -> None:
        self.reset()

    def reset(self):
        self.seed = torch.tensor(())
        self.base_offset = torch.tensor(())
        self.relative_offset = 0
        self.offset_advanced_alteast_once = False

    def validate_state(self):
        assert self.seed.numel() != 0 and self.base_offset.numel() != 0

    def advance_offset(self, consumed_offset):
        self.offset_advanced_alteast_once = True
        self.relative_offset = self.relative_offset + consumed_offset

    def set_state(self, seed, base_offset, relative_offset=0):
        self.seed = seed
        self.base_offset = base_offset
        self.relative_offset = relative_offset

    def get_state_as_tuple(self):
        self.validate_state()
        return (self.seed, self.base_offset + self.relative_offset)

    def get_state_as_tensor(self):
        # Only needed because we override get_rng_state.
        self.validate_state()
        return torch.stack([self.seed, self.base_offset + self.relative_offset])

    def set_state_from_tensor(self, state):
        # Only needed because we override set_rng_state.
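        # `state` is the stacked [seed, base_offset + relative_offset] tensor
        # produced by get_state_as_tensor above; unbind splits it back into
        # its seed and offset halves.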
        self.seed, self.base_offset = torch.unbind(state)
        self.relative_offset = 0


class PhiloxStateTracker:
    """
    Singleton class to track the Philox RNG state during AOT Autograd tracing.
    For each AOT tracing instance, AOT Autograd resets this tracker and keeps
    track of both forward and backward offsets. At runtime, we only care about
    the total consumed forward and backward offsets. For dynamic shapes, these
    offsets are a function of input shapes. Therefore, the AOT-generated graphs
    have additional outputs that compute the total consumed forward and
    backward offsets.
    """

    running_state: PhiloxState
    fwd_state: PhiloxState
    bwd_state: PhiloxState

    def __enter__(self):
        PhiloxStateTracker.reset()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        PhiloxStateTracker.reset()

    @classmethod
    def reset(cls):
        cls.running_state = PhiloxState()
        cls.fwd_state = PhiloxState()
        cls.bwd_state = PhiloxState()

    @classmethod
    def mark_beginning_of_forward(cls):
        # Tells the tracker to use fwd_state as the running state
        cls.running_state = cls.fwd_state

    @classmethod
    def mark_beginning_of_backward(cls):
        # Tells the tracker to use bwd_state as the running state
        cls.running_state = cls.bwd_state

    @classmethod
    def record_state(cls, seed, offset, mode):
        # Records the seed and offset tensors. These tensors are used to invoke
        # the philox_rand functional primitives.
        if mode == "forward":
            cls.fwd_state.set_state(seed, offset)
            cls.mark_beginning_of_forward()
        else:
            assert mode == "backward"
            cls.bwd_state.set_state(seed, offset)

    @classmethod
    def get_state_as_tensor(cls):
        # The only reason this exists is because we override get_rng_state and
        # set_rng_state during tracing. get_rng_state expects a tensor output,
        # so returning a (seed, offset) tuple would upset other parts of the
        # program, like ctx.saved_tensors.

        # A bad consequence is that if the user saves and restores the rng
        # state, there is a little bit of ugliness in the generated code, where
        # we first concat the (seed, offset) to create a tensor for
        # get_rng_state, and then split it back into a (seed, offset) tuple in
        # set_rng_state.

        # TODO: Investigate if there is a better way to wrap the tuple in a
        # false Tensor object, and then desugar it later on.
        return cls.running_state.get_state_as_tensor()

    @classmethod
    def get_state_as_tuple(cls):
        return cls.running_state.get_state_as_tuple()

    @classmethod
    def set_state_from_tensor(cls, x):
        # This is only needed because we override set_rng_state. See the
        # comment in the get_state_as_tensor method.
        cls.running_state.set_state_from_tensor(x)

    @classmethod
    def advance_offset(cls, consumed_offset):
        cls.running_state.advance_offset(consumed_offset)

    @classmethod
    def get_current_relative_offset(cls):
        return cls.running_state.relative_offset

    @staticmethod
    def multiple_of_4(offset):
        # The torch cuda rng state offset must be a multiple of 4. For
        # inductor, as we sum up all the numel, the result might not be a
        # multiple of 4. This method rounds the offset up to the next
        # multiple of 4.
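        # For example, 5 -> 8 and 8 -> 8 (values that are already a multiple
        # of 4 are left unchanged).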
        return (offset + 3) // 4 * 4

    @classmethod
    def get_updated_fwd_offset(cls):
        # Short circuit if no rand ops were observed
        if not cls.fwd_state.offset_advanced_alteast_once:
            return cls.fwd_state.base_offset
        return cls.multiple_of_4(
            cls.fwd_state.base_offset + cls.fwd_state.relative_offset
        )

    @classmethod
    def get_updated_bwd_offset(cls):
        # Short circuit if no rand ops were observed
        if not cls.bwd_state.offset_advanced_alteast_once:
            return cls.bwd_state.base_offset
        return cls.multiple_of_4(
            cls.bwd_state.base_offset + cls.bwd_state.relative_offset
        )


# Adding more decompositions which eventually use rand_like inside the decomps.
# Adding these to rng_decompositions ensures the functionalization of the
# rand_like ops used in these decomps. The list is copied from the inductor
# codebase, which uses it for a similar purpose.
#
# Caution - These decomps do not have the same accuracy as eager. However, we
# can't just disable them with a config flag like fallback_random, because for
# functionalization of rng ops, we have to decompose these ops.
extra_random_decomps = get_decompositions(
    [
        aten.cauchy,
        aten.cauchy_,
        aten.exponential,
        aten.exponential_,
        aten.geometric,
        aten.geometric_,
        aten.native_dropout,
        aten.normal,
        aten.normal_,
        aten.normal_functional,
        aten.log_normal,
        aten.log_normal_,
        aten.rrelu_with_noise,
        aten.rrelu_with_noise_,
        aten.uniform_,
    ]
)
register_extra_random_decomp = functools.partial(
    decomp.register_decomposition, registry=extra_random_decomps
)


@register_extra_random_decomp([aten.bernoulli_])
def bernoulli_(self, p=0.5):
    if self.device == torch.device("cpu"):
        return NotImplemented
    return self.copy_(torch.rand_like(self, dtype=torch.float32) < p)


@register_extra_random_decomp([aten.bernoulli.p])
def bernoulli_p(self, p=0.5, *, generator=None):
    if self.device == torch.device("cpu"):
        return NotImplemented
    assert generator is None
    return torch.rand_like(self, dtype=torch.float32) < p


rng_decompositions.update(extra_random_decomps)  # type: ignore[arg-type]
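

# A minimal, illustrative sketch (not part of this module's API) of how a
# caller such as AOT Autograd might drive PhiloxStateTracker during tracing.
# The fwd_seed and fwd_base_offset names are hypothetical placeholders for the
# seed and offset tensors the caller feeds in as graph inputs:
#
#     with PhiloxStateTracker():
#         PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward")
#         # While the forward graph is traced, each decomposed rand/rand_like
#         # call above reads the state via get_state_as_tuple() and advances
#         # it via advance_offset(), so randomness is consumed functionally.
#         new_fwd_offset = PhiloxStateTracker.get_updated_fwd_offset()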