# xref: /aosp_15_r20/external/pytorch/benchmarks/gpt_fast/model.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# flake8: noqa: E266, C417, B950
2*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass
3*da0073e9SAndroid Build Coastguard Workerfrom typing import Optional
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn
7*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor
8*da0073e9SAndroid Build Coastguard Workerfrom torch.nn import functional as F
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker
def find_multiple(n: int, k: int) -> int:
    """Round ``n`` up to a multiple of ``k``.

    Args:
        n: Value to round up.
        k: Step to round to. For positive ``k`` the result is the smallest
            multiple of ``k`` that is not less than ``n``.

    Returns:
        ``n`` itself when it is already divisible by ``k``, otherwise ``n``
        increased by the distance to the next multiple.

    Raises:
        ZeroDivisionError: If ``k`` is zero (from the modulo operation).
    """
    remainder = n % k
    if remainder == 0:
        return n
    return n + k - remainder
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker
@dataclass
class ModelArgs:
    """Hyper-parameters describing one transformer configuration.

    Values that are left at their sentinel defaults are filled in by
    ``__post_init__``; see the comments on the individual fields.
    """

    # Passed as the sequence length when the rotary table is precomputed.
    block_size: int = 2048
    # Number of rows of the token embedding / columns of the output projection.
    vocab_size: int = 32000
    # Number of TransformerBlock layers.
    n_layer: int = 32
    # Number of query heads.
    n_head: int = 32
    # Model (embedding) width.
    dim: int = 4096
    # Width of the feed-forward hidden layer; ``None`` means "derive from dim".
    intermediate_size: Optional[int] = None
    # Number of key/value heads; ``-1`` means "same as n_head".
    n_local_heads: int = -1
    # Always recomputed as ``dim // n_head`` in ``__post_init__``; a value
    # passed by the caller is overwritten.
    head_dim: int = 64
    # Base used when the rotary frequencies are precomputed.
    rope_base: float = 10000
    # Epsilon handed to the RMSNorm layers.
    norm_eps: float = 1e-5

    def __post_init__(self):
        """Resolve sentinel defaults and derive ``head_dim``."""
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            # 2/3 of 4 * dim, truncated, then rounded up to a multiple of 256.
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        self.head_dim = self.dim // self.n_head

    @classmethod
    def from_name(cls, name: str):
        """Build the arguments for the configuration identified by ``name``.

        An exact key of ``transformer_configs`` wins. Otherwise every key that
        occurs as a substring of ``name`` (or of its upper-cased form) is a
        candidate and the longest candidate is chosen.

        Args:
            name: Configuration key, or any string containing one (it is
                converted with ``str`` before the substring search).

        Returns:
            A ``ModelArgs`` built from the selected configuration.

        Raises:
            IndexError: If no configuration key matches ``name``.
            AssertionError: If the two longest candidates have equal length,
                i.e. there is no unique best match.
        """
        if name in transformer_configs:
            return cls(**transformer_configs[name])

        # fuzzy search
        text = str(name)
        upper_text = text.upper()
        matches = [
            key for key in transformer_configs if key in upper_text or key in text
        ]

        if not matches:
            # Indexing the empty candidate list used to raise a bare
            # IndexError here; the exception type is kept so existing callers
            # behave the same, but the message now says what went wrong.
            raise IndexError(
                f"No transformer config matches model name {name!r}; "
                f"known configs: {sorted(transformer_configs)}"
            )

        # We may have two or more configs matched (e.g. "7B" and "Mistral-7B").
        # Pick the longest name, since it has the most characters matched.
        if len(matches) > 1:
            matches.sort(key=len, reverse=True)
            assert len(matches[0]) != len(
                matches[1]
            ), name  # make sure only one 'best' match

        return cls(**transformer_configs[matches[0]])
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker
# Keyword arguments for ModelArgs, keyed by configuration name. The keys are
# also used for substring matching in ModelArgs.from_name. Fields that are not
# listed fall back to the ModelArgs defaults.
transformer_configs = {
    "CodeLlama-7b-Python-hf": {
        "block_size": 16384,
        "vocab_size": 32000,
        "n_layer": 32,
        "dim": 4096,
        "rope_base": 1000000,
    },
    "7B": {"n_layer": 32, "n_head": 32, "dim": 4096},
    "13B": {"n_layer": 40, "n_head": 40, "dim": 5120},
    "30B": {"n_layer": 60, "n_head": 52, "dim": 6656},
    "34B": {
        "n_layer": 48,
        "n_head": 64,
        "dim": 8192,
        "vocab_size": 32000,
        "n_local_heads": 8,
        "intermediate_size": 22016,
        "rope_base": 1000000,
    },  # CodeLlama-34B-Python-hf
    "70B": {
        "n_layer": 80,
        "n_head": 64,
        "dim": 8192,
        "n_local_heads": 8,
        "intermediate_size": 28672,
    },
    "Mistral-7B": {
        "n_layer": 32,
        "n_head": 32,
        "n_local_heads": 8,
        "dim": 4096,
        "intermediate_size": 14336,
        "vocab_size": 32000,
    },
}
89*da0073e9SAndroid Build Coastguard Worker
90*da0073e9SAndroid Build Coastguard Worker
class KVCache(nn.Module):
    """Preallocated key and value buffers that are filled in by position.

    Both buffers have shape
    ``(max_batch_size, n_heads, max_seq_length, head_dim)`` and start out as
    zeros. They are registered as module buffers named ``k_cache`` and
    ``v_cache``.
    """

    def __init__(
        self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16
    ):
        """Allocate the zero-initialised key/value buffers.

        Args:
            max_batch_size: Size of the batch dimension of the buffers.
            max_seq_length: Number of sequence positions the buffers hold.
            n_heads: Size of the head dimension of the buffers.
            head_dim: Per-head feature size.
            dtype: Element type of both buffers.
        """
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        """Write new keys/values at ``input_pos`` and return the full buffers.

        Args:
            input_pos: Index tensor of shape ``[S]`` with the sequence
                positions to overwrite.
            k_val: Keys of shape ``[B, H, S, D]``.
            v_val: Values; written with the same indexing as ``k_val``.

        Returns:
            The tuple ``(k_cache, v_cache)``: the module's own buffers (not
            copies), modified in place.
        """
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        # Index along the sequence axis (dim 2) of the cached tensors.
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out
110*da0073e9SAndroid Build Coastguard Worker
111*da0073e9SAndroid Build Coastguard Worker
class Transformer(nn.Module):
    """Decoder stack: token embedding, ``n_layer`` blocks, final norm, output head.

    ``setup_caches`` must be called before ``forward``; it creates the
    per-layer KV caches, the rotary table and the causal mask.
    """

    def __init__(self, config: ModelArgs) -> None:
        """Create the layers described by ``config``; caches are not built here."""
        super().__init__()
        self.config = config

        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList(
            TransformerBlock(config) for _ in range(config.n_layer)
        )
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        # Filled in by setup_caches. Note that the mask itself is stored as
        # ``self.causal_mask`` there; ``mask_cache`` is not read in this class.
        self.freqs_cis: Optional[Tensor] = None
        self.mask_cache: Optional[Tensor] = None
        # -1 guarantees that the first setup_caches call allocates.
        self.max_batch_size = -1
        self.max_seq_length = -1

    def setup_caches(self, max_batch_size, max_seq_length):
        """Allocate KV caches, the rotary table and the causal mask.

        Does nothing when the caches built by an earlier call are already at
        least as large as requested in both batch size and sequence length.

        Args:
            max_batch_size: Batch size the KV caches must hold.
            max_seq_length: Sequence length the KV caches must hold; rounded
                up to a multiple of 8 before use.
        """
        if (
            self.max_seq_length >= max_seq_length
            and self.max_batch_size >= max_batch_size
        ):
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        for b in self.layers:
            # Replaces any cache from a previous, smaller setup.
            b.attention.kv_cache = KVCache(
                max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
            )

        # The rotary table covers block_size positions, independent of the
        # requested max_seq_length.
        self.freqs_cis = precompute_freqs_cis(
            self.config.block_size,
            head_dim,
            self.config.rope_base,
        )
        # Lower-triangular boolean mask over the (rounded) sequence length.
        self.causal_mask = torch.tril(
            torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
        )

    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        """Run the model and return the output-head logits.

        Args:
            idx: Token ids fed to the embedding table.
            input_pos: Positions of the tokens in ``idx``; used to select rows
                of the causal mask and of the rotary table and forwarded to
                every layer. NOTE(review): the default is ``None``, but the
                indexing below and the KV-cache update appear to need an index
                tensor; confirm against callers.

        Returns:
            The result of applying the final norm and the output projection to
            the last layer's activations.

        Raises:
            AssertionError: If ``setup_caches`` has not been called.
        """
        assert self.freqs_cis is not None, "Caches must be initialized first"
        # Two leading singleton axes are added in front of the selected rows.
        mask = self.causal_mask[None, None, input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        x = self.tok_embeddings(idx)

        for layer in self.layers:
            x = layer(x, input_pos, freqs_cis, mask)
        x = self.norm(x)
        logits = self.output(x)
        return logits

    @classmethod
    def from_name(cls, name: str):
        """Construct a model from the configuration resolved by ``ModelArgs.from_name``."""
        return cls(ModelArgs.from_name(name))
168*da0073e9SAndroid Build Coastguard Worker
169*da0073e9SAndroid Build Coastguard Worker
class TransformerBlock(nn.Module):
    """One layer: normalised attention and normalised feed-forward, each with a residual add."""

    def __init__(self, config: ModelArgs) -> None:
        """Create the attention, feed-forward and the two RMSNorm sub-modules."""
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(
        self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
    ) -> Tensor:
        """Apply the block to ``x``.

        Args:
            x: Input activations.
            input_pos: Token positions, forwarded to the attention module.
            freqs_cis: Rotary table rows, forwarded to the attention module.
            mask: Attention mask, forwarded to the attention module.

        Returns:
            ``h + feed_forward(ffn_norm(h))`` where
            ``h = x + attention(attention_norm(x), ...)``.
        """
        # Norm is applied to the sub-layer input; the residual uses the
        # un-normalised tensor.
        attn_out = self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        h = x + attn_out
        out = h + self.feed_forward(self.ffn_norm(h))
        return out
184*da0073e9SAndroid Build Coastguard Worker
185*da0073e9SAndroid Build Coastguard Worker
class Attention(nn.Module):
    """Multi-head attention with a fused QKV projection and an optional KV cache.

    Queries use ``n_head`` heads while keys and values use ``n_local_heads``
    heads; keys/values are repeated along the head axis to match the queries
    before attention is computed.
    """

    def __init__(self, config: ModelArgs):
        """Create the projections; the KV cache starts out as ``None``.

        Raises:
            AssertionError: If ``config.dim`` is not divisible by
                ``config.n_head``.
        """
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        # Assigned from outside (see Transformer.setup_caches in this file).
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        """Fuse separate ``wq``/``wk``/``wv`` weights into ``wqkv`` before loading.

        When ``state_dict`` contains ``<prefix>wq.weight``, the three separate
        weights are removed from it and replaced by their concatenation (in
        q, k, v order, along the first axis) under ``<prefix>wqkv.weight``.
        ``state_dict`` is modified in place.
        """
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: Tensor,
        freqs_cis: Tensor,
        mask: Tensor,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """Compute attention for ``x``.

        Args:
            x: Activations of shape ``[bsz, seqlen, dim]``.
            freqs_cis: Rotary table rows passed to ``apply_rotary_emb``.
            mask: Attention mask handed to ``scaled_dot_product_attention``.
            input_pos: Positions at which the KV cache is updated; only used
                when a KV cache has been attached.

        Returns:
            Tensor of shape ``[bsz, seqlen, dim]`` after the output projection.
        """
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        # Rotary embedding is applied to queries and keys only.
        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        # [bsz, seqlen, heads, head_dim] -> [bsz, heads, seqlen, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        if self.kv_cache is not None:
            # From here on k and v are the full cache buffers, not just the
            # current step's keys/values.
            k, v = self.kv_cache.update(input_pos, k, v)

        # Repeat each key/value head so the head count matches the queries.
        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        y = self.wo(y)
        return y
242*da0073e9SAndroid Build Coastguard Worker
243*da0073e9SAndroid Build Coastguard Worker
class FeedForward(nn.Module):
    """Gated feed-forward layer: ``w2(silu(w1(x)) * w3(x))``, all without bias."""

    def __init__(self, config: ModelArgs) -> None:
        """Create the three projections between ``dim`` and ``intermediate_size``."""
        super().__init__()
        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Return the gated projection of ``x`` back to width ``dim``."""
        gate = F.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
253*da0073e9SAndroid Build Coastguard Worker
254*da0073e9SAndroid Build Coastguard Worker
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-5):
        """Create the layer.

        Args:
            dim: Length of the learned scale vector (initialised to ones).
            eps: Constant added to the mean square before the reciprocal
                square root.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last axis."""
        mean_square = torch.mean(x * x, dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        """Normalise ``x`` and multiply by the learned scale.

        The normalisation itself is computed in float32 and cast back to the
        dtype of ``x`` before the multiplication with ``weight``.
        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
267*da0073e9SAndroid Build Coastguard Worker
268*da0073e9SAndroid Build Coastguard Worker
def precompute_freqs_cis(
    seq_len: int,
    n_elem: int,
    base: int = 10000,
    dtype: torch.dtype = torch.bfloat16,
) -> Tensor:
    """Precompute the rotary-embedding table for ``seq_len`` positions.

    Args:
        seq_len: Number of positions (rows) in the table.
        n_elem: Number of feature elements that are rotated; the table holds
            ``n_elem // 2`` frequencies per position.
        base: Base of the geometric frequency progression.
        dtype: Element type of the returned table. Defaults to bfloat16,
            which was previously hard-coded.

    Returns:
        Tensor of shape ``(seq_len, n_elem // 2, 2)`` whose last axis holds
        the real and imaginary parts of the unit complex number with angle
        ``position * frequency`` (i.e. its cosine and sine).
    """
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    # angle[p, i] = p * freqs[i]
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=dtype)
278*da0073e9SAndroid Build Coastguard Worker
279*da0073e9SAndroid Build Coastguard Worker
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate consecutive feature pairs of ``x`` by the entries of ``freqs_cis``.

    Args:
        x: 4-D tensor; axis 1 is matched against the rows of ``freqs_cis``
            and the last axis is split into consecutive pairs.
        freqs_cis: Table that can be viewed as
            ``(1, x.size(1), 1, x.size(-1) // 2, 2)``; the last axis holds the
            two rotation coefficients (real and imaginary part, as produced
            by ``precompute_freqs_cis``).

    Returns:
        Tensor with the shape and dtype of ``x``. The arithmetic is done in
        float32 and the result is cast back.
    """
    # [..., D] -> [..., D/2, 2] so that each pair can be treated as (re, im).
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    # Singleton axes so the table broadcasts over axes 0 and 2 of x.
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)

    x_re = xshaped[..., 0]
    x_im = xshaped[..., 1]
    f_re = freqs_cis[..., 0]
    f_im = freqs_cis[..., 1]

    # Complex multiplication (x_re + i*x_im) * (f_re + i*f_im).
    x_out2 = torch.stack(
        [
            x_re * f_re - x_im * f_im,
            x_im * f_re + x_re * f_im,
        ],
        -1,
    )

    # Merge the pair axis back into the feature axis.
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
293