# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from ..llama_transformer import Transformer


def materialze_broadcast_of_rope_freq_cis(
    module: torch.nn.Module,
):
    """Replace the 2-D RoPE cos/sin tables of ``module`` with per-head 3-D copies.

    ``module.freqs_cos`` and ``module.freqs_sin`` must both be 2-D tensors of
    the same shape ``(dim0, dim1)``. Each is viewed as ``(dim0, 1, dim1)``,
    expanded to ``(dim0, num_heads, dim1)`` and made contiguous, so the
    broadcast over the head dimension is physically materialized instead of
    being left to the consumer of the tables.

    ``num_heads`` is read from ``module.layers[0].attention.n_local_heads``;
    only the first layer is inspected.

    Args:
        module: The model to transform. Must be a ``Transformer`` instance.

    Returns:
        The same ``module`` object, modified in place.

    Raises:
        AssertionError: If ``module`` is not a ``Transformer``, if either table
            is not 2-D, if the cos and sin table shapes differ, or if the first
            layer's ``n_local_heads`` and ``n_local_kv_heads`` differ.
    """
    assert isinstance(module, Transformer)
    assert module.freqs_cos.dim() == 2
    dim0 = module.freqs_cos.size(0)
    dim1 = module.freqs_cos.size(1)

    module_attention = module.layers[0].attention
    num_heads = module_attention.n_local_heads
    num_kv_heads = module_attention.n_local_kv_heads
    # The message reports exactly the two values being compared. q is labelled
    # with n_local_heads and k/v with n_local_kv_heads.
    # NOTE(review): this labelling assumes k and v both use n_local_kv_heads -
    # confirm against the attention implementation.
    assert (
        num_kv_heads == num_heads
    ), f"For rope freqs to be materialized for broadcast, q, k, v num heads must match. For q got {num_heads} for k got {num_kv_heads} and v got {num_kv_heads}"

    # Validate the sin table *before* touching the cos table so that a failed
    # check never leaves the module with only one of the two tables expanded.
    assert module.freqs_sin.dim() == 2
    assert dim0 == module.freqs_sin.size(
        0
    ), f"sin and cos freq table sizes must match. Mismatch found at dim 0: {dim0} vs {module.freqs_sin.size(0)}"
    assert dim1 == module.freqs_sin.size(
        1
    ), f"sin and cos freq table sizes must match. Mismatch found at dim 1: {dim1} vs {module.freqs_sin.size(1)}"

    # expand() only creates a broadcast view; contiguous() forces a real copy
    # with one row of frequencies per head.
    module.freqs_cos = module.freqs_cos.view(dim0, 1, dim1)
    module.freqs_cos = module.freqs_cos.expand(dim0, num_heads, dim1).contiguous()
    module.freqs_sin = module.freqs_sin.view(dim0, 1, dim1)
    module.freqs_sin = module.freqs_sin.expand(dim0, num_heads, dim1).contiguous()
    return module