from collections import OrderedDict

import torch
from torch._C import _disabled_torch_function_impl


# Metaclass to combine _TensorMeta and the instance check override for Parameter.
class _ParameterMeta(torch._C._TensorMeta):
    """Metaclass that widens ``isinstance(t, Parameter)`` to flagged tensors."""

    # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag.
    def __instancecheck__(self, instance):
        """Return True for real instances and for tensors flagged ``_is_param``.

        The flag shortcut is only applied when the check is made against
        ``Parameter`` itself, not against one of its subclasses.
        """
        if self is Parameter:
            if isinstance(instance, torch.Tensor) and getattr(
                instance, "_is_param", False
            ):
                return True
        return super().__instancecheck__(instance)


class Parameter(torch.Tensor, metaclass=_ParameterMeta):
    r"""A kind of Tensor that is to be considered a module parameter.

    Parameters are :class:`~torch.Tensor` subclasses, that have a
    very special property when used with :class:`Module` s - when they're
    assigned as Module attributes they are automatically added to the list of
    its parameters, and will appear e.g. in :meth:`~Module.parameters` iterator.
    Assigning a Tensor doesn't have such effect. This is because one might
    want to cache some temporary state, like last hidden state of the RNN, in
    the model. If there was no such class as :class:`Parameter`, these
    temporaries would get registered too.

    Args:
        data (Tensor): parameter tensor.
        requires_grad (bool, optional): if the parameter requires gradient. Note that
            the torch.no_grad() context does NOT affect the default behavior of
            Parameter creation--the Parameter will still have `requires_grad=True` in
            :class:`~no_grad` mode. See :ref:`locally-disable-grad-doc` for more
            details. Default: `True`
    """

    def __new__(cls, data=None, requires_grad=True):
        """Build a parameter from ``data``.

        Plain tensors and plain Parameters are re-wrapped as ``cls``. Any other
        tensor type keeps its own type and is only flagged with ``_is_param``.

        Raises:
            RuntimeError: if ``data`` is a custom tensor whose ``detach()``
                returns an instance of a different type.
        """
        if data is None:
            data = torch.empty(0)
        if type(data) is torch.Tensor or type(data) is Parameter:
            # For ease of BC maintenance, keep this path for standard Tensor.
            # Eventually (tm), we should change the behavior for standard Tensor to match.
            return torch.Tensor._make_subclass(cls, data, requires_grad)

        # Path for custom tensors: set a flag on the instance to indicate parameter-ness.
        t = data.detach().requires_grad_(requires_grad)
        if type(t) is not type(data):
            raise RuntimeError(
                f"Creating a Parameter from an instance of type {type(data).__name__} "
                "requires that detach() returns an instance of the same type, but return "
                f"type {type(t).__name__} was found instead. To use the type as a "
                "Parameter, please correct the detach() semantics defined by "
                "its __torch_dispatch__() implementation."
            )
        t._is_param = True
        return t

    # Note: the 3 methods below only apply to standard Tensor. Parameters of custom tensor types
    # are still considered that custom tensor type and these methods will not be called for them.
    def __deepcopy__(self, memo):
        """Return a clone of this parameter, reusing any copy recorded in ``memo``."""
        key = id(self)
        if key in memo:
            return memo[key]
        result = type(self)(
            self.data.clone(memory_format=torch.preserve_format), self.requires_grad
        )
        memo[key] = result
        return result

    def __repr__(self):
        """Return the tensor repr prefixed with ``Parameter containing:``."""
        return "Parameter containing:\n" + super().__repr__()

    def __reduce_ex__(self, proto):
        """Return the pickle recipe: a rebuild function and its arguments.

        The variant that carries extra state is chosen only when
        ``torch._utils._get_obj_state`` returns a truthy value.
        """
        state = torch._utils._get_obj_state(self)

        # See Note [Don't serialize hooks]
        hooks = OrderedDict()
        if not state:
            return (
                torch._utils._rebuild_parameter,
                (self.data, self.requires_grad, hooks),
            )

        return (
            torch._utils._rebuild_parameter_with_state,
            (self.data, self.requires_grad, hooks, state),
        )

    __torch_function__ = _disabled_torch_function_impl


class UninitializedTensorMixin:
    """Shared behavior for tensors whose shape is not known yet.

    Subclasses must define ``cls_to_become``: the class an instance is
    switched to by :meth:`materialize`.
    """

    # Functions that may be dispatched on an uninitialized tensor; everything
    # else is rejected by ``__torch_function__`` below.
    _allowed_methods = [
        torch.Tensor.__hash__,
        torch.Tensor.size,
        torch.Tensor.copy_,
        torch.Tensor.is_complex,
        torch.Tensor.is_floating_point,
        torch.Tensor.half,
        torch.Tensor.float,
        torch.Tensor.double,
        torch.Tensor.char,
        torch.Tensor.short,
        torch.Tensor.int,
        torch.Tensor.long,
        torch.Tensor.cuda,
        torch.Tensor.cpu,
        torch.Tensor.to,
        torch.Tensor.get_device,
        torch._has_compatible_shallow_copy_type,
    ]

    def materialize(self, shape, device=None, dtype=None):
        r"""Create a Parameter or Tensor with the same properties of the uninitialized one.

        Given a shape, it materializes a parameter in the same device
        and with the same `dtype` as the current one or the specified ones in the
        arguments.

        Args:
            shape : (tuple): the shape for the materialized tensor.
            device (:class:`torch.device`): the desired device of the parameters
                and buffers in this module. Optional.
            dtype (:class:`torch.dtype`): the desired floating point type of
                the floating point parameters and buffers in this module. Optional.
        """
        if device is None:
            device = self.data.device
        if dtype is None:
            dtype = self.data.dtype
        self.data = torch.empty(shape, device=device, dtype=dtype)
        # Swap the class in place so existing references see the materialized type.
        self.__class__ = self.cls_to_become

    @property
    def shape(self):
        """Always raise: an uninitialized tensor has no shape to report."""
        raise RuntimeError(
            "Can't access the shape of an uninitialized parameter or buffer. "
            "This error usually happens in `load_state_dict` when trying to load "
            "an uninitialized parameter into an initialized one. "
            "Call `forward` to initialize the parameters before accessing their attributes."
        )

    def share_memory_(self):
        """Always raise: there is no storage to share before materialization."""
        raise RuntimeError(
            "Can't share memory on an uninitialized parameter or buffer. "
            "Call `forward` to initialize the parameters before calling "
            "`module.share_memory()`."
        )

    def __repr__(self):
        """Return ``<ClassName>``; no tensor contents are printed."""
        return f"<{self.__class__.__name__}>"

    def __reduce_ex__(self, proto):
        """Return a pickle recipe that rebuilds from the class and ``requires_grad`` only."""
        # See Note [Don't serialize hooks]
        return (self.__class__, (self.requires_grad,))

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Forward allowed functions to the next class in the MRO, reject the rest.

        Raises:
            ValueError: if ``func`` is neither listed in ``_allowed_methods``
                nor a ``method-wrapper``.
        """
        # method-wrapper is to detect access to Tensor properties that are
        # wrapped in descriptors
        if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper":
            if kwargs is None:
                kwargs = {}
            return super().__torch_function__(func, types, args, kwargs)
        raise ValueError(
            f"Attempted to use an uninitialized parameter in {func}. "
            "This error happens when you are using a `LazyModule` or "
            f"explicitly manipulating `torch.nn.parameter.{cls.__name__}` "
            "objects. When using LazyModules Call `forward` with a dummy batch "
            "to initialize the parameters before calling torch functions"
        )


def is_lazy(param) -> bool:
    """Return True if ``param`` is an instance of :class:`UninitializedTensorMixin`."""
    return isinstance(param, UninitializedTensorMixin)


class UninitializedParameter(UninitializedTensorMixin, Parameter):
    r"""A parameter that is not initialized.

    Uninitialized Parameters are a special case of :class:`torch.nn.Parameter`
    where the shape of the data is still unknown.

    Unlike a :class:`torch.nn.Parameter`, uninitialized parameters
    hold no data and attempting to access some properties, like their shape,
    will throw a runtime error. The only operations that can be performed on an uninitialized
    parameter are changing its datatype, moving it to a different device and
    converting it to a regular :class:`torch.nn.Parameter`.

    The default device or dtype to use when the parameter is materialized can be set
    during construction using e.g. ``device='cuda'``.
    """

    cls_to_become = Parameter

    def __new__(
        cls, requires_grad=True, device=None, dtype=None
    ) -> "UninitializedParameter":
        """Create the placeholder around an empty (0-element) tensor."""
        factory_kwargs = {"device": device, "dtype": dtype}
        data = torch.empty(0, **factory_kwargs)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __deepcopy__(self, memo):
        """Return a new placeholder with the same ``requires_grad``, device and dtype."""
        key = id(self)
        if key in memo:
            return memo[key]
        result = type(self)(self.requires_grad, self.data.device, self.data.dtype)
        memo[key] = result
        return result


# Metaclass to combine _TensorMeta and the instance check override for Buffer.
class _BufferMeta(torch._C._TensorMeta):
    """Metaclass that widens ``isinstance(t, Buffer)`` to flagged tensors."""

    # Make `isinstance(t, Buffer)` return True for custom tensor instances that have the _is_buffer flag.
    def __instancecheck__(self, instance):
        """Return True for real instances and for tensors flagged ``_is_buffer``.

        The flag shortcut is only applied when the check is made against
        ``Buffer`` itself, not against one of its subclasses.
        """
        if self is Buffer:
            if isinstance(instance, torch.Tensor) and getattr(
                instance, "_is_buffer", False
            ):
                return True
        return super().__instancecheck__(instance)


class Buffer(torch.Tensor, metaclass=_BufferMeta):
    r"""A kind of Tensor that should not be considered a model
    parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state.

    Buffers are :class:`~torch.Tensor` subclasses, that have a
    very special property when used with :class:`Module` s -- when they're
    assigned as Module attributes they are automatically added to the list of
    its buffers, and will appear e.g. in :meth:`~torch.nn.Module.buffers` iterator.
    Assigning a Tensor doesn't have such effect. One can still register a Tensor explicitly by using
    the :meth:`~torch.nn.Module.register_buffer` function.

    Args:
        data (Tensor): buffer tensor.
        persistent (bool, optional): whether the buffer is part of the module's
            :attr:`state_dict`. Default: ``True``
    """

    def __new__(cls, data=None, *, persistent=True):
        """Return a detached view of ``data`` flagged as a buffer.

        The result keeps the type returned by ``data.detach()``; it is tagged
        with ``persistent`` and ``_is_buffer`` rather than re-wrapped as ``cls``.
        """
        if data is None:
            data = torch.empty(0)

        t = data.detach().requires_grad_(data.requires_grad)
        t.persistent = persistent
        t._is_buffer = True
        return t

    __torch_function__ = _disabled_torch_function_impl


class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
    r"""A buffer that is not initialized.

    Uninitialized Buffer is a special case of :class:`torch.Tensor`
    where the shape of the data is still unknown.

    Unlike a :class:`torch.Tensor`, uninitialized buffers
    hold no data and attempting to access some properties, like their shape,
    will throw a runtime error. The only operations that can be performed on an uninitialized
    buffer are changing its datatype, moving it to a different device and
    converting it to a regular :class:`torch.Tensor`.

    The default device or dtype to use when the buffer is materialized can be set
    during construction using e.g. ``device='cuda'``.
    """

    cls_to_become = torch.Tensor

    def __new__(
        cls, requires_grad=False, device=None, dtype=None, persistent=True
    ) -> "UninitializedBuffer":
        """Create the placeholder around an empty (0-element) tensor and flag it as a buffer."""
        factory_kwargs = {"device": device, "dtype": dtype}
        data = torch.empty(0, **factory_kwargs)
        ret = torch.Tensor._make_subclass(cls, data, requires_grad)
        ret.persistent = persistent
        ret._is_buffer = True
        return ret