# xref: /aosp_15_r20/external/executorch/examples/models/llama/source_transformation/rope.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

7import torch
8
9from ..llama_transformer import Transformer
10
11
def materialze_broadcast_of_rope_freq_cis(
    module: torch.nn.Module,
) -> torch.nn.Module:
    """Materialize the RoPE cos/sin tables with an explicit, dense heads axis.

    ``module.freqs_cos`` and ``module.freqs_sin`` must be 2-D tensors of the
    same shape ``(dim0, dim1)``. Each is viewed as ``(dim0, 1, dim1)``,
    expanded to ``(dim0, num_heads, dim1)`` and made contiguous, so the
    result is a real copy rather than a zero-stride broadcast view.
    ``num_heads`` is read from the first layer's attention block.

    All checks run before either table is modified. A failed check
    therefore leaves ``module`` unchanged.

    Args:
        module: The model to transform; must be a ``Transformer``.

    Returns:
        The same ``module`` object, with ``freqs_cos`` and ``freqs_sin``
        replaced by the expanded tensors.

    Raises:
        AssertionError: if ``module`` is not a ``Transformer``, if either
            table is not 2-D, if the two tables differ in shape, or if the
            first layer's attention uses a different number of query heads
            than key/value heads.
    """
    assert isinstance(module, Transformer)
    assert module.freqs_cos.dim() == 2
    dim0 = module.freqs_cos.size(0)
    dim1 = module.freqs_cos.size(1)

    # Only layer 0 is inspected. NOTE(review): assumes every layer shares
    # the same head configuration — confirm if mixed-head models are possible.
    module_attention = module.layers[0].attention
    assert (
        module_attention.n_local_kv_heads == module_attention.n_local_heads
    ), f"For rope freqs to be materialized for broadcast, q, k, v num heads must match. For q got {module_attention.n_local_heads} for k got {module_attention.n_local_kv_heads} and v got {module_attention.n_local_kv_heads}"
    num_heads = module_attention.n_local_heads

    # Validate the sin table before touching either tensor. Otherwise a
    # mismatch here would leave freqs_cos expanded but freqs_sin untouched.
    assert module.freqs_sin.dim() == 2
    assert dim0 == module.freqs_sin.size(
        0
    ), f"sin and cos freq table sizes must match. Mismatch found at dim 0: {dim0} vs {module.freqs_sin.size(0)}"
    assert dim1 == module.freqs_sin.size(
        1
    ), f"sin and cos freq table sizes must match. Mismatch found at dim 1: {dim1} vs {module.freqs_sin.size(1)}"

    # view -> expand yields a stride-0 view over the heads axis; contiguous()
    # copies it into dense storage of shape (dim0, num_heads, dim1).
    module.freqs_cos = (
        module.freqs_cos.view(dim0, 1, dim1)
        .expand(dim0, num_heads, dim1)
        .contiguous()
    )
    module.freqs_sin = (
        module.freqs_sin.view(dim0, 1, dim1)
        .expand(dim0, num_heads, dim1)
        .contiguous()
    )
    return module
36