# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Helper functions for transforming the model so that it can load checkpoints with
# LoRA adaptors. See https://arxiv.org/abs/2106.09685 for more details about LoRA.

from typing import Any

import torch
from torch import nn
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter


class LoRAAdaptorLinear(nn.Module):
    """
    LoRA adaptor for linear layers.

    This class implements Low-Rank Adaptation (LoRA) for linear layers.
    See https://arxiv.org/abs/2106.09685 for more details about LoRA.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int,
        scale: float = 2.0,
        dtype=torch.float32,
        device=None,
    ) -> None:
        super().__init__()
        self.scale = scale
        self.A = nn.Linear(in_features, rank, bias=False, dtype=dtype, device=device)
        self.B = nn.Linear(rank, out_features, bias=False, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.scale * self.B(self.A(x))  # pyre-ignore[7]


class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
    """
    Int8DynActInt4WeightLinear with LoRA adaptor.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        lora_rank: int,
        bias=True,
        device=None,
        groupsize: int = 256,
        precision: torch.dtype = torch.float32,
        scales_precision: torch.dtype = torch.float32,
        lora_adaptor_precision: torch.dtype = torch.bfloat16,
        lora_scale: float = 2.0,
    ) -> None:
        super().__init__(
            in_features,
            out_features,
            bias=bias,
            device=device,
            groupsize=groupsize,
            precision=precision,
            scales_precision=scales_precision,
        )
        # TODO(lunwenh): Remove this once TorchAO's commit pin in ExecuTorch is
        # updated to include this PR
        self.zeros = torch.zeros_like(self.zeros)
        self.adaptor = LoRAAdaptorLinear(
            in_features,
            out_features,
            lora_rank,
            scale=lora_scale,
            dtype=lora_adaptor_precision,
            device=device,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return super().forward(input) + self.adaptor(input).to(dtype=self.precision)


def _replace_linear_8da4w_for_lora(
    module: torch.nn.Module,
    checkpoint: Any,
    lora_rank: int,
):
    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
        # Only replace linear layers for which the checkpoint contains explicit
        # adaptor weights.
        adaptor_A_key = f"{cur_fqn}.adaptor.A.weight"
        adaptor_B_key = f"{cur_fqn}.adaptor.B.weight"
        if (
            isinstance(child, Int8DynActInt4WeightLinear)
            and adaptor_A_key in checkpoint
            and adaptor_B_key in checkpoint
        ):
            assert checkpoint[adaptor_A_key].dtype == torch.bfloat16
            assert checkpoint[adaptor_A_key].shape[0] == lora_rank
            assert checkpoint[adaptor_A_key].shape[1] == child.in_features
            assert checkpoint[adaptor_B_key].dtype == torch.bfloat16
            assert checkpoint[adaptor_B_key].shape[0] == child.out_features
            assert checkpoint[adaptor_B_key].shape[1] == lora_rank
            return True
        return False

    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
        new_linear = Int8DynActInt4WeightLinearLoRA(
            # pyre-fixme[6]: For 1st argument expected `int` but got `Union[Module,
            #  Tensor]`.
            child.in_features,
            # pyre-fixme[6]: For 2nd argument expected `int` but got `Union[Module,
            #  Tensor]`.
            child.out_features,
            lora_rank=lora_rank,
            bias=False,
            device=child.weight.device,
            # pyre-fixme[6]: For 6th argument expected `int` but got `Union[Module,
            #  Tensor]`.
            groupsize=child.groupsize,
            # pyre-fixme[6]: For 7th argument expected `dtype` but got
            #  `Union[Module, Tensor]`.
            precision=child.precision,
            # pyre-fixme[6]: For 8th argument expected `dtype` but got
            #  `Union[Module, dtype, Tensor]`.
            scales_precision=child.scales.dtype,
        )
        return new_linear

    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_linear_for_lora_after_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    lora_rank: int,
) -> torch.nn.Module:
    """
    Transform the model so that it can load checkpoints with LoRA adaptors.

    The model should already have been transformed to load pre-quantized
    checkpoints, and the checkpoint should contain both the pre-quantized
    weights and the LoRA adaptor weights.
    """
    _replace_linear_8da4w_for_lora(
        module,
        checkpoint,
        lora_rank,
    )
    return module
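

# ---------------------------------------------------------------------------
# Illustrative usage sketch (commented out; not part of this module's API).
# The names here are assumptions for the example only: `model` is any module
# whose linear layers have already been swapped to Int8DynActInt4WeightLinear,
# and "consolidated.pth" stands in for a checkpoint that stores both the
# quantized weights and the `*.adaptor.A.weight` / `*.adaptor.B.weight`
# tensors produced by LoRA fine-tuning. Adjust paths and the rank to match
# your checkpoint.
#
#     checkpoint = torch.load("consolidated.pth", map_location="cpu", mmap=True)
#     model = transform_linear_for_lora_after_quantization(
#         model, checkpoint, lora_rank=16
#     )
#     model.load_state_dict(checkpoint, strict=False, assign=True)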