# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import argparse

import os
from typing import Optional, Sequence

import torch
from executorch.exir import to_edge

from executorch.extension.training.examples.XOR.model import Net, TrainingNet
from torch.export import export
from torch.export.experimental import _export_forward_backward


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Export ``TrainingNet(Net())`` to ``<outdir>/xor.pte``.

    The model is exported with its joint forward and backward graph, lowered
    to the edge dialect, then to ExecuTorch, and the serialized program is
    written to ``xor.pte`` inside ``--outdir`` (created if missing).

    Args:
        argv: Command-line arguments to parse, excluding the program name.
            When ``None`` (the default), ``argparse`` reads ``sys.argv[1:]``,
            so running this as a script behaves exactly as before.
    """
    # Seed the RNG so the randomly drawn example input below (and any random
    # parameter initialization inside Net) is reproducible across runs.
    torch.manual_seed(0)
    parser = argparse.ArgumentParser(
        prog="export_model",
        description="Exports an nn.Module model to ExecuTorch .pte files",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="Path to the directory to write xor.pte files to",
    )
    args = parser.parse_args(argv)

    net = TrainingNet(Net())
    # Example inputs used only for tracing: a (1, 2) float sample and a (1,)
    # int64 tensor (presumably the target label consumed by TrainingNet).
    x = torch.randn(1, 2)
    label = torch.ones(1, dtype=torch.int64)

    # Capture the forward graph; at this point it looks similar to the model
    # definition. Will move to export_for_training soon, which is the API
    # planned to be supported in the long term.
    forward_program = export(net, (x, label))
    # Capture the backward graph as well: the result contains the joint
    # forward and backward graph.
    joint_program = _export_forward_backward(forward_program)
    # Lower the graph to the edge dialect.
    edge_program = to_edge(joint_program)
    # Lower the edge-dialect program to ExecuTorch.
    executorch_program = edge_program.to_executorch()

    # Write out the serialized .pte file.
    os.makedirs(args.outdir, exist_ok=True)
    outfile = os.path.join(args.outdir, "xor.pte")
    with open(outfile, "wb") as fp:
        fp.write(executorch_program.buffer)


if __name__ == "__main__":
    main()