# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Example of showcasing registering custom operator through loading a shared
library that calls PyTorch C++ op registration API.
"""

import argparse

import torch
from examples.portable.scripts.export import export_to_exec_prog, save_pte_program
from executorch.exir import EdgeCompileConfig


# example model
class Model(torch.nn.Module):
    """Tiny model whose forward pass is a single call to ``my_ops::mul4``."""

    def forward(self, a):
        # The op is looked up on every call, so it must already be registered
        # (e.g. via ``torch.ops.load_library``) before the model is exported.
        return torch.ops.my_ops.mul4.default(a)


def _has_mul4_out() -> bool:
    """Return True if ``my_ops::mul4.out`` is currently registered with PyTorch."""
    try:
        # Attribute access on ``torch.ops`` raises AttributeError for an
        # unregistered namespace/op/overload; the value itself is not needed.
        torch.ops.my_ops.mul4.out
    except AttributeError:
        return False
    return True


def _ensure_custom_op_registered(so_library=None) -> None:
    """Make sure ``my_ops::mul4.out`` is registered, loading *so_library* if needed.

    Args:
        so_library: Path to the shared library that registers the op, or None.
            Only used when the op is not registered yet.

    Raises:
        RuntimeError: If the op is not registered and no library path was
            given, or if loading the library did not register the op.
    """
    if _has_mul4_out():
        return

    print("No registered custom op my_ops::mul4.out")
    if not so_library:
        raise RuntimeError(
            "Need to specify shared library path to register custom op my_ops::mul4.out into "
            "EXIR. The required shared library is defined as `custom_ops_aot_lib` in "
            "examples/portable/custom_ops/CMakeLists.txt if you are using CMake build, or `custom_ops_aot_lib_2` in "
            "examples/portable/custom_ops/targets.bzl for buck2. One example path would be cmake-out/examples/portable/custom_ops/"
            "libcustom_ops_aot_lib.[so|dylib]."
        )

    torch.ops.load_library(so_library)
    print(so_library)
    # Fail fast with a clear message instead of an opaque error during export
    # if the given library does not actually register the op.
    if not _has_mul4_out():
        raise RuntimeError(
            f"Loaded shared library {so_library} but custom op my_ops::mul4.out "
            "is still not registered."
        )


def _parse_args():
    """Parse command-line arguments for this example script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-s",
        "--so_library",
        default=None,
        help="Provide path to so library. E.g., cmake-out/examples/portable/custom_ops/libcustom_ops_aot_lib.so. "
        "Only needed if my_ops::mul4.out is not already registered.",
    )
    return parser.parse_args()


def main():
    """Export ``Model`` to an ExecuTorch program and save it as ``custom_ops_2``.

    The custom op ``my_ops::mul4`` must be registered before this is called.
    """
    m = Model()
    example_input = torch.randn(2, 3)

    # capture and lower
    model_name = "custom_ops_2"
    prog = export_to_exec_prog(
        m,
        (example_input,),
        # NOTE(review): IR validity checking is disabled, presumably because the
        # custom op would not pass the Edge IR validity check — confirm.
        edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    save_pte_program(prog, model_name)


if __name__ == "__main__":
    args = _parse_args()
    _ensure_custom_op_registered(args.so_library)
    main()