# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Example script for exporting simple models to flatbuffer

import logging

import torch

from executorch.backends.cadence.aot.ops_registrations import *  # noqa

from typing import Tuple

from executorch.backends.cadence.aot.export_example import export_model


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


if __name__ == "__main__":

    class Predictor(torch.nn.Module):
        def __init__(
            self,
            num_symbols: int,
            symbol_embedding_dim: int,
        ) -> None:
            super().__init__()
            self.embedding = torch.nn.Embedding(num_symbols, symbol_embedding_dim)
            self.relu = torch.nn.ReLU()
            self.linear = torch.nn.Linear(symbol_embedding_dim, symbol_embedding_dim)
            self.layer_norm = torch.nn.LayerNorm(symbol_embedding_dim)

        def forward(
            self,
            input: torch.Tensor,
            lengths: torch.Tensor,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            input_tb = input.permute(1, 0)
            embedding_out = self.embedding(input_tb)
            relu_out = self.relu(embedding_out)
            linear_out = self.linear(relu_out)
            layer_norm_out = self.layer_norm(linear_out)
            return layer_norm_out.permute(1, 0, 2), lengths

    # Predictor
    model = Predictor(128, 256)
    model.eval()

    # Batch size
    batch_size = 1

    num_symbols = 128
    max_target_length = 10

    # Dummy inputs
    predictor_input = torch.randint(0, num_symbols, (batch_size, max_target_length))
    predictor_lengths = torch.randint(1, max_target_length + 1, (batch_size,))

    example_inputs = (
        predictor_input,
        predictor_lengths,
    )

    export_model(model, example_inputs)
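
    # Optional eager-mode sanity check (illustrative addition, not part of the
    # original script): run the model once on the dummy inputs to confirm the
    # shapes the exported program consumes and produces. The int64 token tensor
    # of shape (batch, time) maps to an output of shape (batch, time, embed),
    # i.e. (1, 10, 256) here, with the lengths tensor passed through unchanged.
    with torch.no_grad():
        eager_out, eager_lengths = model(*example_inputs)
    logging.info(
        "Eager output shapes: %s, %s", eager_out.shape, eager_lengths.shape
    )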