# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Helper functions for transforming the model so that it can run SpinQuant.
# See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant.


import torch

import torch.nn.functional as F

from executorch.examples.models.llama.llama_transformer import FeedForward
from torch import nn


def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
    """
    SpinQuant needs two Hadamard matrices: R3 and R4. Here we only inject R4 into the feed-forward layer.
    R3 also needs to be injected when KV cache quantization is enabled.
    """
    try:
        from fast_hadamard_transform import hadamard_transform
    except ImportError:
        raise ImportError(
            "Please install fast-hadamard-transform: pip install fast-hadamard-transform"
        )

    class FeedForwardCudaCustom(nn.Module):
        def __init__(self, w1, w2, w3):
            super().__init__()
            self.w1 = w1
            self.w2 = w2
            self.w3 = w3

        def forward(self, x):
            w = F.silu(self.w1(x)) * self.w3(x)
            n = w.shape[-1]
            # Apply the R4 rotation (Hadamard transform scaled by 1/sqrt(n))
            # before the w2 down-projection.
            return self.w2(hadamard_transform(w.contiguous()) / torch.tensor(n).sqrt())

    # Recursively replace every FeedForward submodule with the R4-injected version.
    for name, child in module.named_children():
        if isinstance(child, FeedForward):
            setattr(module, name, FeedForwardCudaCustom(child.w1, child.w2, child.w3))
        else:
            _inject_fast_hadamard_transform_cuda_for_spin_quant(child)


def inject_fast_hadamard_transform_cuda_for_spin_quant(
    module: torch.nn.Module,
) -> torch.nn.Module:
    _inject_fast_hadamard_transform_cuda_for_spin_quant(module)
    return module


def _inject_fast_hadamard_transform_native_for_spin_quant(module: torch.nn.Module):
    """
    SpinQuant needs two Hadamard matrices: R3 and R4. Here we only inject R4 into the feed-forward layer.
    R3 also needs to be injected when KV cache quantization is enabled.
    """

    class FeedForwardNativeCustom(nn.Module):
        def __init__(self, w1, w2, w3):
            super().__init__()
            self.w1 = w1
            self.w2 = w2
            self.w3 = w3

        def forward(self, x):
            # Apply the R4 rotation via the custom fast_hadamard_transform op
            # before the w2 down-projection.
            return self.w2(
                torch.ops.llama.fast_hadamard_transform(F.silu(self.w1(x)) * self.w3(x))
            )

    # Recursively replace every FeedForward submodule with the R4-injected version.
    for name, child in module.named_children():
        if isinstance(child, FeedForward):
            setattr(module, name, FeedForwardNativeCustom(child.w1, child.w2, child.w3))
        else:
            _inject_fast_hadamard_transform_native_for_spin_quant(child)


def inject_fast_hadamard_transform_native_for_spin_quant(
    module: torch.nn.Module,
) -> torch.nn.Module:
    _inject_fast_hadamard_transform_native_for_spin_quant(module)
    return module
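

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It
# assumes a llama `Transformer` model whose blocks contain `FeedForward`
# submodules; for the native variant, the custom op library that registers
# `torch.ops.llama.fast_hadamard_transform` must already be loaded.
#
#     model = Transformer(model_args)  # hypothetical model construction
#     model = inject_fast_hadamard_transform_native_for_spin_quant(model)
#     # Every FeedForward submodule is now a FeedForwardNativeCustom, so the
#     # R4 rotation is applied to the activations right before the w2
#     # down-projection.
# ---------------------------------------------------------------------------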