# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn.functional as F
from torch import nn


class EncoderLayer(nn.Module):
    """A single transformer encoder layer (post-LayerNorm style).

    Multi-head self-attention (via ``F.multi_head_attention_forward`` with
    separate q/k/v projection weights) followed by a two-layer ReLU
    feed-forward network; each sub-block is wrapped with a residual
    connection and a LayerNorm applied after the addition.
    """

    def __init__(self, embed_dim, num_heads=2):
        """
        Args:
            embed_dim: model width (must be divisible by ``num_heads`` for
                ``F.multi_head_attention_forward`` to split heads).
            num_heads: number of attention heads.
        """
        super().__init__()
        self.embed_dim = embed_dim
        # Keys and values share the model width; kept as distinct attributes
        # to mirror the structure of the original mt_model.
        self.kdim = self.embed_dim
        self.vdim = self.embed_dim
        self.k_proj = nn.Linear(self.kdim, embed_dim)
        self.v_proj = nn.Linear(self.vdim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.num_heads = num_heads

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

        ffn_embed_dim = (
            2 * embed_dim
        )  # for simplicity we just hardcode ffn_embed_dim to be 2x of embed_dim
        self.fc1 = nn.Linear(embed_dim, ffn_embed_dim)
        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim)

    def forward(self, x):
        """Apply self-attention + FFN to ``x``.

        Args:
            x: tensor of shape (T, B, C) — time-major, as expected by
                ``F.multi_head_attention_forward``.

        Returns:
            Tensor of the same shape (T, B, C).
        """
        residual = x
        # Self-attention: query, key and value all come from the input.
        query = key = value = x
        x, _ = F.multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            out_proj_weight=self.out_proj.weight,
            out_proj_bias=self.out_proj.bias,
            add_zero_attn=False,
            dropout_p=0.0,
            # Separate q/k/v weights are used instead of one packed in_proj.
            use_separate_proj_weight=True,
            in_proj_weight=None,
            in_proj_bias=None,
            # is non None value really needed for bias_k, bias_v?
            bias_k=None,
            bias_v=None,
        )
        x = residual + x
        x = self.self_attn_layer_norm(x)

        residual = x
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = residual + x
        x = self.final_layer_norm(x)
        return x


# NOTE(review): decorating the *class* with @torch.no_grad() only disables
# grad while the constructor runs and replaces the class object with a
# wrapper callable; it does NOT make encode() grad-free. Kept as-is to
# preserve the existing interface -- confirm this is intended.
@torch.no_grad()
class Transformer(nn.Module):
    """
    A simplified implementation of mt_model that does not have all those heavy
    dependencies but still be similar enough to the original model.

    Suitable to be put in exir end2end tests. E.g., we can use it to ease the
    testing of memory planning for dynamic shapes on REAL models.

    Some of the simplifications recorded here:
    1. the original model will reset the embedding to a 0 vector for padding token.
       We skip that.
    2. skip various configurations in the original model. E.g., original model
       has a config cfg.no_scale_embedding to control if the token embedding
       should be scaled or not. We just always scale the embedding.
    """

    def __init__(self, inp_vocab_size=10, model_dim=32, num_encoder_layers=2):
        """
        Args:
            inp_vocab_size: size of the input token vocabulary.
            model_dim: embedding / model width.
            num_encoder_layers: number of stacked EncoderLayer modules.
        """
        super().__init__()
        self.inp_vocab_size = inp_vocab_size
        # BUGFIX: was hardcoded to 32, silently ignoring the model_dim
        # argument. The default (32) preserves the previous behavior.
        self.model_dim = model_dim
        self.token_embed_table = nn.Embedding(self.inp_vocab_size, self.model_dim)
        self.embed_scale = math.sqrt(self.model_dim)
        # BUGFIX: use nn.ModuleList instead of a plain Python list so the
        # encoder layers are registered as submodules and are visible to
        # parameters(), state_dict(), .to(), etc. Iteration is unchanged.
        self.encoder_layers = nn.ModuleList(
            EncoderLayer(embed_dim=self.model_dim) for _ in range(num_encoder_layers)
        )

    def encode(self, src_tokens):
        """Embed ``src_tokens`` (B, T) and run the encoder stack.

        Returns a tensor of shape (T, B, model_dim).
        """
        # embed = self.token_embed_table(src_tokens) * self.embed_scale # fail in runtime because of lacking broadcasting
        embed = self.token_embed_table(src_tokens)
        # TODO: add the support for positional embedding

        # BxTxC -> TxBxC
        x = embed.transpose(0, 1)

        for layer in self.encoder_layers:
            x = layer(x)

        return x

    def get_random_inputs(self, method):
        """Return random example inputs for the given entry point.

        Raises:
            AssertionError: if ``method`` is not supported.
        """
        if method == "encode":
            seqlen = 10  # TODO: make the sequence length dynamic
            return torch.randint(
                low=0,
                high=self.inp_vocab_size,
                size=(
                    1,
                    seqlen,
                ),
            )
        else:
            raise AssertionError(f"method {method} is not supported yet")