# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import re from typing import Any, Dict, Mapping def convert_model_state_dict( state_dict: Dict[str, Any], key_map: Mapping[str, str] ) -> Dict[str, Any]: """Convert a model state dictionary to fairseq2. :param state_dict: The original model state dictionary. :param key_map: A map of regex patterns to fairseq2 model keys. :returns: A converted model state dictionary that is compatible with fairseq2. """ new_state_dict = {} def get_new_key(old_key: str) -> str: for old_pattern, replacement in key_map.items(): if (new_key := re.sub(old_pattern, replacement, old_key)) != old_key: return new_key return old_key # Convert module keys from fairseq to fairseq2. for old_key in state_dict.keys(): new_key = get_new_key(old_key) new_state_dict[new_key] = state_dict[old_key] return new_state_dict def convert_to_llama_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]: """Convert a fairseq2 LLaMA checkpoint to the reference format.""" # state_dict = checkpoint["model"] key_map = { # fmt: off r"decoder.layers.([0-9]+).self_attn.q_proj.": r"layers.\1.attention.wq.", r"decoder.layers.([0-9]+).self_attn.k_proj.": r"layers.\1.attention.wk.", r"decoder.layers.([0-9]+).self_attn.v_proj.": r"layers.\1.attention.wv.", r"decoder.layers.([0-9]+).self_attn.output_proj.": r"layers.\1.attention.wo.", r"decoder.layers.([0-9]+).self_attn_layer_norm.": r"layers.\1.attention_norm.", r"decoder.layers.([0-9]+).ffn.gate_proj.": r"layers.\1.feed_forward.w1.", r"decoder.layers.([0-9]+).ffn.output_proj.": r"layers.\1.feed_forward.w2.", r"decoder.layers.([0-9]+).ffn.inner_proj.": r"layers.\1.feed_forward.w3.", r"decoder.layers.([0-9]+).ffn_layer_norm.": r"layers.\1.ffn_norm.", r"decoder.layer_norm.": r"norm.", r"decoder_frontend.embed.": r"tok_embeddings.", r"final_proj.": r"output.", # fmt: on } return convert_model_state_dict(checkpoint, key_map)