# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Helper functions for transforming the model so it can load checkpoints with
# LoRA adaptors. See https://arxiv.org/abs/2106.09685 for more details about LoRA.

from typing import Any

import torch
from torch import nn
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter


class LoRAAdaptorLinear(nn.Module):
    """
    LoRA adaptor for linear layers.

    This class implements Low-Rank Adaptation (LoRA) for linear layers.
    See https://arxiv.org/abs/2106.09685 for more details about LoRA.
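
    Example (a minimal sketch; the sizes and rank below are illustrative):

        adaptor = LoRAAdaptorLinear(in_features=16, out_features=32, rank=4)
        x = torch.randn(2, 16)
        delta = adaptor(x)  # scale * B(A(x)), a rank-4 update of shape (2, 32)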
26    """
27
28    def __init__(
29        self,
30        in_features: int,
31        out_features: int,
32        rank: int,
33        scale: float = 2.0,
34        dtype=torch.float32,
35        device=None,
36    ) -> None:
37        super().__init__()
38        self.scale = scale
39        self.A = nn.Linear(in_features, rank, bias=False, dtype=dtype, device=device)
40        self.B = nn.Linear(rank, out_features, bias=False, dtype=dtype, device=device)
41
42    def forward(self, x: torch.Tensor) -> torch.Tensor:
43        return self.scale * self.B(self.A(x))  # pyre-ignore[7]
44
45
class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
    """
    Int8DynActInt4WeightLinear with LoRA adaptor.
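
    Construction sketch (sizes, rank, and group size are illustrative; bias=False
    mirrors what replacement_fn in this file passes to the constructor):

        linear = Int8DynActInt4WeightLinearLoRA(
            in_features=512,
            out_features=512,
            lora_rank=8,
            bias=False,
            groupsize=128,
        )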
49    """
50
51    def __init__(
52        self,
53        in_features: int,
54        out_features: int,
55        lora_rank: int,
56        bias=True,
57        device=None,
58        groupsize: int = 256,
59        precision: torch.dtype = torch.float32,
60        scales_precision: torch.dtype = torch.float32,
61        lora_adaptor_precision: torch.dtype = torch.bfloat16,
62        lora_scale: float = 2.0,
63    ) -> None:
64        super().__init__(
65            in_features,
66            out_features,
67            bias=bias,
68            device=device,
69            groupsize=groupsize,
70            precision=precision,
71            scales_precision=scales_precision,
72        )
73        # TODO(lunwenh): Remove this once TorchAO's commit pin in ExecuTorch is updated to include this PR
74        self.zeros = torch.zeros_like(self.zeros)
75        self.adaptor = LoRAAdaptorLinear(
76            in_features,
77            out_features,
78            lora_rank,
79            scale=lora_scale,
80            dtype=lora_adaptor_precision,
81            device=device,
82        )
83
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Sum of the quantized linear output and the LoRA adaptor output, cast to
        # the module's base precision.
        return super().forward(input) + self.adaptor(input).to(dtype=self.precision)


def _replace_linear_8da4w_for_lora(
    module: torch.nn.Module,
    checkpoint: Any,
    lora_rank: int,
):
    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
        # Only replace linear layers where the checkpoint contains explicit adaptors
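        # (i.e. the checkpoint has entries named "<fqn>.adaptor.A.weight" and
        # "<fqn>.adaptor.B.weight" for this module's fully qualified name).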
        adaptor_A_key = f"{cur_fqn}.adaptor.A.weight"
        adaptor_B_key = f"{cur_fqn}.adaptor.B.weight"
        if (
            isinstance(child, Int8DynActInt4WeightLinear)
            and adaptor_A_key in checkpoint
            and adaptor_B_key in checkpoint
        ):
            assert checkpoint[adaptor_A_key].dtype == torch.bfloat16
            assert checkpoint[adaptor_A_key].shape[0] == lora_rank
            assert checkpoint[adaptor_A_key].shape[1] == child.in_features
            assert checkpoint[adaptor_B_key].dtype == torch.bfloat16
            assert checkpoint[adaptor_B_key].shape[0] == child.out_features
            assert checkpoint[adaptor_B_key].shape[1] == lora_rank
            return True
        return False

    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
        new_linear = Int8DynActInt4WeightLinearLoRA(
            # pyre-fixme[6]: For 1st argument expected `int` but got `Union[Module,
            #  Tensor]`.
            child.in_features,
            # pyre-fixme[6]: For 2nd argument expected `int` but got `Union[Module,
            #  Tensor]`.
            child.out_features,
            lora_rank=lora_rank,
            bias=False,
            device=child.weight.device,
            # pyre-fixme[6]: For 6th argument expected `int` but got `Union[Module,
            #  Tensor]`.
            groupsize=child.groupsize,
            # pyre-fixme[6]: For 7th argument expected `dtype` but got
            #  `Union[Module, Tensor]`.
            precision=child.precision,
            # pyre-fixme[6]: For 8th argument expected `dtype` but got
            #  `Union[Module, dtype, Tensor]`.
            scales_precision=child.scales.dtype,
        )
        return new_linear

    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)


def transform_linear_for_lora_after_quantization(
    module: torch.nn.Module,
    checkpoint: Any,
    lora_rank: int,
) -> torch.nn.Module:
    """
    Transform the model so it can load checkpoints with LoRA adaptors.
    The model should already have been transformed to load pre-quantized
    checkpoints, and the checkpoint should contain pre-quantized weights plus
    LoRA adaptor weights.
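
    Usage sketch (the checkpoint path and rank are hypothetical; model is assumed
    to already contain Int8DynActInt4WeightLinear layers):

        checkpoint = torch.load("llama_lora_checkpoint.pth", map_location="cpu")
        model = transform_linear_for_lora_after_quantization(
            model, checkpoint, lora_rank=16
        )
        model.load_state_dict(checkpoint, strict=False)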
147    """
148    _replace_linear_8da4w_for_lora(
149        module,
150        checkpoint,
151        lora_rank,
152    )
153    return module
154