# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F
from torch import nn


class EncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads=2):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.k_proj = nn.Linear(self.kdim, embed_dim)
        self.v_proj = nn.Linear(self.vdim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.num_heads = num_heads

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

        # For simplicity, we just hardcode ffn_embed_dim to be 2x embed_dim.
        ffn_embed_dim = 2 * embed_dim
        self.fc1 = nn.Linear(embed_dim, ffn_embed_dim)
        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim)

    def forward(self, x):
        # Self-attention block with a residual connection.
        residual = x
        query = key = value = x
        x, _ = F.multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            out_proj_weight=self.out_proj.weight,
            out_proj_bias=self.out_proj.bias,
            add_zero_attn=False,
            dropout_p=0.0,
            use_separate_proj_weight=True,
            in_proj_weight=None,
            in_proj_bias=None,
            # Is a non-None value really needed for bias_k, bias_v?
            bias_k=None,
            bias_v=None,
        )
        x = residual + x
        x = self.self_attn_layer_norm(x)

        # Feed-forward block with a residual connection.
        residual = x
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = residual + x
        x = self.final_layer_norm(x)
        return x

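# A minimal usage sketch (illustrative, not exercised by the original tests):
# EncoderLayer operates on TxBxC (seq_len, batch, embed_dim) tensors, the layout
# F.multi_head_attention_forward expects; Transformer.encode below transposes its
# BxTxC embeddings accordingly.
#
#     layer = EncoderLayer(embed_dim=32)
#     out = layer(torch.randn(10, 1, 32))  # -> torch.Size([10, 1, 32])
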
@torch.no_grad()
class Transformer(nn.Module):
    """
    A simplified implementation of mt_model that does not have all of the heavy
    dependencies but is still similar enough to the original model.

    Suitable to be put in exir end2end tests. E.g., we can use it to ease the
    testing of memory planning for dynamic shapes on REAL models.

    Some of the simplifications made here:
    1. The original model resets the embedding to a zero vector for the padding
       token. We skip that.
    2. We skip various configurations of the original model. E.g., the original
       model has a config cfg.no_scale_embedding to control whether the token
       embedding should be scaled or not. We just always scale the embedding.
    """

    def __init__(self, inp_vocab_size=10, model_dim=32, num_encoder_layers=2):
        super().__init__()
        self.inp_vocab_size = inp_vocab_size
        self.model_dim = model_dim
        self.token_embed_table = nn.Embedding(self.inp_vocab_size, self.model_dim)
        self.embed_scale = math.sqrt(self.model_dim)
        # Use nn.ModuleList so the encoder layers' parameters are registered
        # with this module.
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(embed_dim=self.model_dim) for _ in range(num_encoder_layers)]
        )

    def encode(self, src_tokens):
        # embed = self.token_embed_table(src_tokens) * self.embed_scale  # fails at runtime due to lack of broadcasting support
        embed = self.token_embed_table(src_tokens)
        # TODO: add support for positional embedding

        # BxTxC -> TxBxC
        x = embed.transpose(0, 1)

        for layer in self.encoder_layers:
            x = layer(x)

        return x

    def get_random_inputs(self, method):
        if method == "encode":
            seqlen = 10  # TODO: make the sequence length dynamic
            return torch.randint(
                low=0,
                high=self.inp_vocab_size,
                size=(1, seqlen),
            )
        else:
            raise AssertionError(f"method {method} is not supported yet")
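

# A minimal smoke test, added as an illustrative sketch (not part of the original
# file): build the model with its default sizes and run one encode() pass on
# randomly generated tokens.
if __name__ == "__main__":
    model = Transformer()
    src_tokens = model.get_random_inputs("encode")  # shape (1, 10), int64 token ids
    encoded = model.encode(src_tokens)
    # encode() returns a TxBxC tensor: (seqlen, batch, model_dim) == (10, 1, 32).
    print(encoded.shape)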