# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F
from torch import nn


class EncoderLayer(nn.Module):
    """
    A single post-LN Transformer encoder layer: multi-head self-attention
    followed by a two-layer feed-forward block, each wrapped with a residual
    connection and LayerNorm.
    """

    def __init__(self, embed_dim, num_heads=2):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.k_proj = nn.Linear(self.kdim, embed_dim)
        self.v_proj = nn.Linear(self.vdim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.num_heads = num_heads

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

        # For simplicity we just hardcode ffn_embed_dim to be 2x of embed_dim.
        ffn_embed_dim = 2 * embed_dim
        self.fc1 = nn.Linear(embed_dim, ffn_embed_dim)
        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim)

    def forward(self, x):
        # Self-attention block with residual connection and LayerNorm.
        residual = x
        query = key = value = x
        x, _ = F.multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            out_proj_weight=self.out_proj.weight,
            out_proj_bias=self.out_proj.bias,
            add_zero_attn=False,
            dropout_p=0.0,
            use_separate_proj_weight=True,
            in_proj_weight=None,
            in_proj_bias=None,
            # Is a non-None value really needed for bias_k/bias_v?
            bias_k=None,
            bias_v=None,
        )
        x = residual + x
        x = self.self_attn_layer_norm(x)

        # Feed-forward block with residual connection and LayerNorm.
        residual = x
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = residual + x
        x = self.final_layer_norm(x)
        return x


class Transformer(nn.Module):
    """
    A simplified implementation of mt_model that does not have all those heavy
    dependencies but is still similar enough to the original model.

    Suitable to be put in exir end2end tests. E.g., we can use it to ease the
    testing of memory planning for dynamic shapes on REAL models.

    Some of the simplifications are recorded here:
    1. The original model resets the embedding to a zero vector for the padding
       token. We skip that.
    2. We skip various configurations in the original model. E.g., the original
       model has a config cfg.no_scale_embedding to control whether the token
       embedding should be scaled. We just always scale the embedding (though
       the scaling is currently commented out in encode() because the runtime
       lacks broadcasting support).
    """

    def __init__(self, inp_vocab_size=10, model_dim=32, num_encoder_layers=2):
        super().__init__()
        self.inp_vocab_size = inp_vocab_size
        self.model_dim = model_dim
        self.token_embed_table = nn.Embedding(self.inp_vocab_size, self.model_dim)
        self.embed_scale = math.sqrt(self.model_dim)
        # Use nn.ModuleList so the encoder layers are registered as submodules
        # and their parameters show up in parameters()/state_dict().
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(embed_dim=self.model_dim) for _ in range(num_encoder_layers)]
        )

    # This test model is inference-only, so run the encoder without gradients.
    @torch.no_grad()
    def encode(self, src_tokens):
        # embed = self.token_embed_table(src_tokens) * self.embed_scale  # fails at runtime because broadcasting is not supported
        embed = self.token_embed_table(src_tokens)
        # TODO: add support for positional embeddings

        # BxTxC -> TxBxC
        x = embed.transpose(0, 1)

        for layer in self.encoder_layers:
            x = layer(x)

        return x

    def get_random_inputs(self, method):
        if method == "encode":
            seqlen = 10  # TODO: make the sequence length dynamic
            return torch.randint(
                low=0,
                high=self.inp_vocab_size,
                size=(1, seqlen),
            )
        else:
            raise AssertionError(f"method {method} is not supported yet")
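

# A minimal usage sketch, assuming the default constructor arguments: build
# the Transformer, draw random token ids via get_random_inputs("encode"), and
# run the encoder once. encode() transposes BxTxC -> TxBxC, so the output
# shape is (seqlen, batch, model_dim).
if __name__ == "__main__":
    model = Transformer(inp_vocab_size=10, model_dim=32, num_encoder_layers=2)
    src_tokens = model.get_random_inputs("encode")  # shape: (1, 10)
    encoded = model.encode(src_tokens)
    print(encoded.shape)  # expected: torch.Size([10, 1, 32])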