# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import torch.nn as nn
from torch._functorch.utils import exposed_in


def batch_norm_without_running_stats(module: nn.Module) -> None:
    """Disable running-statistics tracking on a single batch-norm module.

    If ``module`` is an instance of ``nn.modules.batchnorm._BatchNorm`` and
    currently has ``track_running_stats`` enabled, its ``running_mean``,
    ``running_var`` and ``num_batches_tracked`` attributes are set to ``None``
    and ``track_running_stats`` is set to ``False``. Any other module
    (including a batch-norm module that already has tracking disabled) is
    left untouched.

    Args:
        module: The module to update in place.
    """
    if not isinstance(module, nn.modules.batchnorm._BatchNorm):
        return
    if not module.track_running_stats:
        return

    module.running_mean = None
    module.running_var = None
    module.num_batches_tracked = None
    module.track_running_stats = False


@exposed_in("torch.func")
def replace_all_batch_norm_modules_(root: nn.Module) -> nn.Module:
    """
    In place updates :attr:`root` by setting the ``running_mean``, ``running_var`` and
    ``num_batches_tracked`` to be None and setting ``track_running_stats`` to be False
    for any nn.BatchNorm module in :attr:`root` (including :attr:`root` itself).

    Args:
        root: The module whose batch-norm modules are updated in place.

    Returns:
        The same :attr:`root` object that was passed in.
    """
    # ``Module.modules()`` yields ``root`` itself before its descendants, so
    # no separate call for the root module is needed.
    for submodule in root.modules():
        batch_norm_without_running_stats(submodule)
    return root