# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch._functorch.utils import exposed_in


@exposed_in("torch.func")
def functional_call(
    module: "torch.nn.Module",
    parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
    args: Union[Any, Tuple],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    tie_weights: bool = True,
    strict: bool = False,
):
    r"""Performs a functional call on the module by replacing the module parameters
    and buffers with the provided ones.

    .. note:: If the module has active parametrizations, passing a value in the
        :attr:`parameter_and_buffer_dicts` argument with the name set to the regular parameter
        name will completely disable the parametrization.
        If you want to apply the parametrization function to the value passed,
        please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``.
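
        For example, here is a minimal sketch of keeping a parametrization active by
        keying on the underlying tensor (the ``Symmetric`` parametrization below is
        purely illustrative and is assumed to have been registered via
        ``torch.nn.utils.parametrize.register_parametrization``):

        .. code-block:: python

            import torch
            import torch.nn as nn
            import torch.nn.utils.parametrize as parametrize
            from torch.func import functional_call

            class Symmetric(nn.Module):
                def forward(self, X):
                    # illustrative parametrization: symmetrize the weight
                    return X.triu() + X.triu(1).transpose(-1, -2)

            lin = nn.Linear(3, 3)
            parametrize.register_parametrization(lin, "weight", Symmetric())

            # Passing "weight" directly would bypass the parametrization entirely;
            # keying on the ".original" tensor applies Symmetric to the passed value.
            raw = {"parametrizations.weight.original": torch.randn(3, 3)}
            out = functional_call(lin, raw, torch.randn(2, 3))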

    .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected
        in the ``parameter_and_buffer_dicts`` input.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo()  # does self.foo = self.foo + 1
            >>> print(mod.foo)  # tensor(0.)
            >>> functional_call(mod, a, torch.ones(()))
            >>> print(mod.foo)  # tensor(0.)
            >>> print(a['foo'])  # tensor(1.)

    .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
        ``tie_weights`` flag.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
            >>> print(mod.foo)  # tensor(1.)
            >>> mod(torch.zeros(()))  # tensor(2.)
            >>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
            >>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.)--self.foo_tied is not updated
            >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
            >>> functional_call(mod, new_a, torch.zeros(()))  # tensor(0.)

    An example of passing multiple dictionaries:

    .. code-block:: python

        a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)})  # two separate dictionaries
        mod = nn.Bar(1, 1)  # returns self.weight @ x + self.buffer
        print(mod.weight)  # tensor(...)
        print(mod.buffer)  # tensor(...)
        x = torch.randn((1, 1))
        print(x)
        functional_call(mod, a, x)  # same as x
        print(mod.weight)  # same as before functional_call

    And here is an example of applying the grad transform over the parameters
    of a model.

    .. code-block:: python

        import torch
        import torch.nn as nn
        from torch.func import functional_call, grad

        x = torch.randn(4, 3)
        t = torch.randn(4, 3)
        model = nn.Linear(3, 3)

        def compute_loss(params, x, t):
            y = functional_call(model, params, x)
            return nn.functional.mse_loss(y, t)

        grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)

    .. note:: If the user does not need grad tracking outside of grad transforms, they can detach all of the
        parameters for better performance and memory usage.

        Example::

            >>> detached_params = {k: v.detach() for k, v in model.named_parameters()}
            >>> grad_weights = grad(compute_loss)(detached_params, x, t)
            >>> grad_weights['weight'].grad_fn  # None--it's not tracking gradients outside of grad

        This means that the user cannot call ``grad_weights['weight'].backward()``. However, if they don't
        need autograd tracking outside of the transforms, this will result in less memory usage and faster speeds.

    Args:
        module (torch.nn.Module): the module to call
        parameter_and_buffer_dicts (Dict[str, Tensor] or tuple of Dict[str, Tensor]): the parameters that will be used in
            the module call. If given a tuple of dictionaries, they must have distinct keys so that all dictionaries can
            be used together
        args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
        kwargs (dict): keyword arguments to be passed to the module call
        tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
            tied in the reparameterized version. Therefore, if True and different values are passed for the tied
            parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
            buffers unless the values passed for both weights are the same. Default: True.
        strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and
            buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will
            error. Default: False.

    Returns:
        Any: the result of calling ``module``.
    """
    if isinstance(parameter_and_buffer_dicts, dict):
        parameters_and_buffers = parameter_and_buffer_dicts
    elif isinstance(parameter_and_buffer_dicts, Sequence):
        if not all(isinstance(d, dict) for d in parameter_and_buffer_dicts):
            raise ValueError(
                "Expected all elements of parameter_and_buffer_dicts to be dictionaries"
            )
        all_keys = [k for d in parameter_and_buffer_dicts for k in d.keys()]
        all_keys_counter: Dict[str, int] = {}
        for k in all_keys:
            v = all_keys_counter.get(k, 0)
            all_keys_counter[k] = v + 1
        repeated_keys = [key for key, n in all_keys_counter.items() if n > 1]
        if len(repeated_keys) > 0:
            raise ValueError(
                f"{repeated_keys} appeared in multiple dictionaries; behavior of functional call is ambiguous"
            )
        parameters_and_buffers = {
            k: v for d in parameter_and_buffer_dicts for k, v in d.items()
        }
    else:
        raise ValueError(
            f"Expected parameter_and_buffer_dicts to be a dict, or a list/tuple of dicts, "
            f"but got {type(parameter_and_buffer_dicts)}"
        )

    return nn.utils.stateless._functional_call(
        module,
        parameters_and_buffers,
        args,
        kwargs,
        tie_weights=tie_weights,
        strict=strict,
    )


@exposed_in("torch.func")
def stack_module_state(
    models: List[nn.Module],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """stack_module_state(models) -> params, buffers

    Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.

    Given a list of ``M`` ``nn.Modules`` of the same class, returns two dictionaries
    that stack all of their parameters and buffers together, indexed by name.
    The stacked parameters are optimizable (i.e. they are new leaf nodes in the
    autograd history that are unrelated to the original parameters and can be
    passed directly to an optimizer).

    Here's an example of how to ensemble over a very simple model:

    .. code-block:: python

        num_models = 5
        batch_size = 64
        in_features, out_features = 3, 3
        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
        data = torch.randn(batch_size, 3)

        def wrapper(params, buffers, data):
            return torch.func.functional_call(models[0], (params, buffers), data)

        params, buffers = stack_module_state(models)
        output = vmap(wrapper, (0, 0, None))(params, buffers, data)

        assert output.shape == (num_models, batch_size, out_features)
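
    The stacked parameters also compose with other transforms. The following is a
    minimal sketch of computing per-model gradients for an ensemble with
    ``vmap(grad(...))`` (the ``loss`` function and ``targets`` here are illustrative,
    not part of this API):

    .. code-block:: python

        import torch
        import torch.nn as nn
        from torch.func import functional_call, grad, stack_module_state, vmap

        num_models, batch_size = 5, 64
        in_features, out_features = 3, 3
        models = [nn.Linear(in_features, out_features) for _ in range(num_models)]
        params, buffers = stack_module_state(models)

        data = torch.randn(batch_size, in_features)
        targets = torch.randn(batch_size, out_features)

        def loss(params, buffers, data, targets):
            # reuse models[0] as the stateless "base" module
            preds = functional_call(models[0], (params, buffers), (data,))
            return nn.functional.mse_loss(preds, targets)

        # one gradient entry per model, stacked along dim 0
        grads = vmap(grad(loss), in_dims=(0, 0, None, None))(params, buffers, data, targets)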

    When there are submodules, this follows the state dict naming conventions:

    .. code-block:: python

        import torch.nn as nn

        class Foo(nn.Module):
            def __init__(self, in_features, out_features):
                super().__init__()
                hidden = 4
                self.l1 = nn.Linear(in_features, hidden)
                self.l2 = nn.Linear(hidden, out_features)

            def forward(self, x):
                return self.l2(self.l1(x))

        num_models = 5
        in_features, out_features = 3, 3
        models = [Foo(in_features, out_features) for i in range(num_models)]
        params, buffers = stack_module_state(models)
        print(list(params.keys()))  # "l1.weight", "l1.bias", "l2.weight", "l2.bias"

    .. warning::
        All of the modules being stacked together must be the same (except for
        the values of their parameters/buffers). For example, they should be in the
        same mode (training vs eval).
    """
    if len(models) == 0:
        raise RuntimeError("stack_module_state: Expected at least one model, got 0.")
    if not (all(m.training for m in models) or all(not m.training for m in models)):
        raise RuntimeError(
            "stack_module_state: Expected all models to have the same training/eval mode."
        )
    model0_typ = type(models[0])
    if not all(type(m) == model0_typ for m in models):
        raise RuntimeError(
            "stack_module_state: Expected all models to be of the same class."
        )
    all_params = [dict(model.named_parameters()) for model in models]
    params = {
        k: construct_stacked_leaf(tuple(params[k] for params in all_params), k)
        for k in all_params[0]
    }
    all_buffers = [dict(model.named_buffers()) for model in models]
    buffers = {
        k: construct_stacked_leaf(tuple(buffers[k] for buffers in all_buffers), k)
        for k in all_buffers[0]
    }

    return params, buffers


def construct_stacked_leaf(
    tensors: Union[Tuple[Tensor, ...], List[Tensor]], name: str
) -> Tensor:
    all_requires_grad = all(t.requires_grad for t in tensors)
    none_requires_grad = all(not t.requires_grad for t in tensors)
    if not all_requires_grad and not none_requires_grad:
        raise RuntimeError(
            f"Expected {name} from each model to have the same .requires_grad"
        )
    result = torch.stack(tensors)
    if all_requires_grad:
        result = result.detach().requires_grad_()
    return result