# xref: /aosp_15_r20/external/pytorch/test/custom_operator/model.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2import os.path
3import sys
4
5import torch
6
7
def get_custom_op_library_path():
    """Return the absolute path of the compiled custom-op shared library.

    The library is looked up under ``build/`` relative to the current working
    directory, using the platform-specific file name: ``custom_ops.dll`` on
    Windows, ``libcustom_ops.dylib`` on macOS, ``libcustom_ops.so`` otherwise.

    Raises:
        FileNotFoundError: if the library file does not exist. This is an
            explicit raise rather than ``assert`` so the check still runs
            when Python is started with ``-O``.
    """
    if sys.platform.startswith("win32"):
        library_filename = "custom_ops.dll"
    elif sys.platform.startswith("darwin"):
        library_filename = "libcustom_ops.dylib"
    else:
        library_filename = "libcustom_ops.so"
    path = os.path.abspath(os.path.join("build", library_filename))
    if not os.path.exists(path):
        raise FileNotFoundError(f"custom op library not found: {path}")
    return path
18
19
class Model(torch.jit.ScriptModule):
    """Minimal TorchScript module whose ``forward`` calls a custom operator.

    The module holds one parameter, ``p``, a 5x5 identity matrix. ``forward``
    never reads it; presumably it exists so that the saved module carries a
    parameter (TODO confirm).
    """

    def __init__(self) -> None:
        super().__init__()
        identity = torch.eye(5)
        self.p = torch.nn.Parameter(identity)

    @torch.jit.script_method
    def forward(self, input):
        # The ``custom`` op namespace is expected to come from the shared
        # library that main() loads before constructing Model.
        # Only ``input`` is passed, so op_with_defaults presumably supplies
        # defaults for its other arguments.
        # Its result is indexed with [0], so it presumably returns a sequence
        # of tensors. Confirm both against the C++ op registration.
        first = torch.ops.custom.op_with_defaults(input)[0]
        return first + 1
28
29
def main():
    """Save a ``Model`` instance to the path given on the command line.

    Command-line arguments:
        --export-script-module-to PATH  (required) destination file passed
                                        to ``Model.save``.
    """
    cli = argparse.ArgumentParser(
        description="Serialize a script module with custom ops"
    )
    cli.add_argument("--export-script-module-to", required=True)
    args = cli.parse_args()

    # Load the shared library that provides torch.ops.custom.*, which
    # Model.forward references, before the model is constructed.
    library_path = get_custom_op_library_path()
    torch.ops.load_library(library_path)

    Model().save(args.export_script_module_to)


# Script entry point, e.g.:
#   python model.py --export-script-module-to model.pt
if __name__ == "__main__":
    main()
45