xref: /aosp_15_r20/external/executorch/examples/portable/custom_ops/custom_ops_2.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
"""Example showcasing how to register a custom operator by loading a shared
library that calls the PyTorch C++ op registration API.
"""
10
11import argparse
12
13import torch
14from examples.portable.scripts.export import export_to_exec_prog, save_pte_program
15from executorch.exir import EdgeCompileConfig
16
17
18# example model
19class Model(torch.nn.Module):
20    def forward(self, a):
21        return torch.ops.my_ops.mul4.default(a)
22
23
def main():
    """Export the example ``Model`` and save the result as ``custom_ops_2``.

    A random ``(2, 3)`` tensor is used as the sample input for export.
    """
    module = Model()
    sample_inputs = (torch.randn(2, 3),)

    # capture and lower
    # NOTE(review): IR validity checking is switched off here, presumably
    # because the custom op is not a core op -- confirm.
    compile_config = EdgeCompileConfig(_check_ir_validity=False)
    program = export_to_exec_prog(
        module,
        sample_inputs,
        edge_compile_config=compile_config,
    )

    output_name = "custom_ops_2"
    save_pte_program(program, output_name)
36
37
if __name__ == "__main__":
    # Script entry point: make sure the custom op my_ops::mul4.out is
    # registered (loading the shared library if it is not), then export the
    # example model via main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-s",
        "--so_library",
        # Optional: the library only has to be loaded when the op is not
        # already registered in this process. A missing path in that case is
        # reported by the RuntimeError below, which explains how to build it.
        required=False,
        default=None,
        help="Provide path to so library. E.g., cmake-out/examples/portable/custom_ops/libcustom_ops_aot_lib.so",
    )
    args = parser.parse_args()

    # See if we have custom op my_ops::mul4.out registered. The attribute
    # lookup itself is the probe; the resulting object is not needed.
    has_out_ops = True
    try:
        _ = torch.ops.my_ops.mul4.out
    except AttributeError:
        print("No registered custom op my_ops::mul4.out")
        has_out_ops = False

    if not has_out_ops:
        if not args.so_library:
            raise RuntimeError(
                "Need to specify shared library path to register custom op my_ops::mul4.out into "
                "EXIR. The required shared library is defined as `custom_ops_aot_lib` in "
                "examples/portable/custom_ops/CMakeLists.txt if you are using CMake build, or `custom_ops_aot_lib_2` in "
                "examples/portable/custom_ops/targets.bzl for buck2. One example path would be cmake-out/examples/portable/custom_ops/"
                "libcustom_ops_aot_lib.[so|dylib]."
            )
        torch.ops.load_library(args.so_library)

    # Echo the library path when one was supplied (skipped when the flag is
    # omitted so that "None" is not printed).
    if args.so_library:
        print(args.so_library)

    main()
68