# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F
from torch import nn


class EncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads=2):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.k_proj = nn.Linear(self.kdim, embed_dim)
        self.v_proj = nn.Linear(self.vdim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.num_heads = num_heads

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

        # For simplicity, hardcode ffn_embed_dim to be 2x embed_dim.
        ffn_embed_dim = 2 * embed_dim
        self.fc1 = nn.Linear(embed_dim, ffn_embed_dim)
        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim)

    def forward(self, x):
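        # Self-attention block with a residual connection, then LayerNorm.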
        residual = x
        query = key = value = x
        x, _ = F.multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            out_proj_weight=self.out_proj.weight,
            out_proj_bias=self.out_proj.bias,
            add_zero_attn=False,
            dropout_p=0.0,
            use_separate_proj_weight=True,
            in_proj_weight=None,
            in_proj_bias=None,
            # is a non-None value really needed for bias_k and bias_v?
            bias_k=None,
            bias_v=None,
        )
        x = residual + x
        x = self.self_attn_layer_norm(x)

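        # Feed-forward block with a residual connection, then LayerNorm.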
        residual = x
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = residual + x
        x = self.final_layer_norm(x)
        return x


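# A minimal, hedged usage sketch (not part of the original test model). It only
# illustrates that EncoderLayer expects sequence-first input of shape
# (seq_len, batch, embed_dim), since F.multi_head_attention_forward assumes
# sequence-first tensors by default. The helper name and sizes are illustrative.
def _encoder_layer_smoke_check(seq_len=5, batch=2, embed_dim=32):
    layer = EncoderLayer(embed_dim=embed_dim)
    x = torch.randn(seq_len, batch, embed_dim)
    out = layer(x)
    # The layer preserves the input shape.
    assert out.shape == (seq_len, batch, embed_dim)
    return out

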
@torch.no_grad()
class Transformer(nn.Module):
    """
    A simplified implementation of mt_model that avoids the original's heavy
    dependencies while staying similar enough to the original model.

    Suitable for exir end-to-end tests, e.g. to ease the testing of memory
    planning for dynamic shapes on REAL models.

    Some of the simplifications made here:
    1. The original model resets the embedding to a zero vector for the padding
       token. We skip that.
    2. We skip various configurations of the original model. E.g., the original
       model has a config, cfg.no_scale_embedding, that controls whether the
       token embedding is scaled; we always scale the embedding.
    """

    def __init__(self, inp_vocab_size=10, model_dim=32, num_encoder_layers=2):
        super().__init__()
        self.inp_vocab_size = inp_vocab_size
        self.model_dim = model_dim
        self.token_embed_table = nn.Embedding(self.inp_vocab_size, self.model_dim)
        self.embed_scale = math.sqrt(self.model_dim)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(embed_dim=self.model_dim) for _ in range(num_encoder_layers)]
        )

    def encode(self, src_tokens):
        # embed = self.token_embed_table(src_tokens) * self.embed_scale  # fails at runtime because broadcasting is not supported
        embed = self.token_embed_table(src_tokens)
        # TODO: add support for positional embedding

        # BxTxC -> TxBxC
        x = embed.transpose(0, 1)

        for layer in self.encoder_layers:
            x = layer(x)

        return x

    def get_random_inputs(self, method):
        if method == "encode":
            seqlen = 10  # TODO: make the sequence length dynamic
            return torch.randint(
                low=0,
                high=self.inp_vocab_size,
                size=(
                    1,
                    seqlen,
                ),
            )
        else:
            raise AssertionError(f"method {method} is not supported yet")
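

# A minimal, hedged usage sketch (assumed, not part of the original module): build
# the Transformer, draw random token ids with get_random_inputs, and run encode.
# The encoder output is sequence-first, i.e. (seqlen, batch, model_dim).
if __name__ == "__main__":
    model = Transformer(inp_vocab_size=10, model_dim=32, num_encoder_layers=2)
    src_tokens = model.get_random_inputs("encode")  # token ids of shape (1, seqlen)
    encoded = model.encode(src_tokens)
    print(encoded.shape)  # expected: torch.Size([10, 1, 32])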