"""Serialize a TorchScript module that calls a custom operator.

The script loads the compiled custom-op shared library, constructs ``Model``
(whose scripted ``forward`` calls ``torch.ops.custom.op_with_defaults``), and
saves it to the path given by ``--export-script-module-to``.
"""

import argparse
import os.path
import sys

import torch


def get_custom_op_library_path(build_dir="build"):
    """Return the absolute path of the compiled custom-op shared library.

    The library file name depends on the platform: ``custom_ops.dll`` on
    Windows, ``libcustom_ops.dylib`` on macOS, ``libcustom_ops.so`` otherwise.

    Args:
        build_dir: Directory that contains the library. A relative path is
            resolved against the current working directory. Defaults to
            ``"build"``.

    Returns:
        The absolute path of the library file.

    Raises:
        FileNotFoundError: If no file exists at the computed path.
    """
    if sys.platform.startswith("win32"):
        library_filename = "custom_ops.dll"
    elif sys.platform.startswith("darwin"):
        library_filename = "libcustom_ops.dylib"
    else:
        library_filename = "libcustom_ops.so"
    path = os.path.abspath(os.path.join(build_dir, library_filename))
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
    # which would let a missing library slip through to load_library.
    if not os.path.exists(path):
        raise FileNotFoundError(f"custom op library not found: {path}")
    return path


class Model(torch.jit.ScriptModule):
    """Minimal ScriptModule whose forward pass calls a custom operator.

    ``main`` loads the custom-op library before constructing this class;
    presumably this is required so TorchScript can resolve
    ``torch.ops.custom.op_with_defaults`` when compiling ``forward`` — confirm
    before constructing ``Model`` elsewhere without loading the library first.
    """

    def __init__(self) -> None:
        super().__init__()
        # Not referenced by forward; it is part of the module's state and is
        # therefore saved with it. NOTE(review): presumably kept so the
        # serialized module carries a parameter — verify against consumers.
        self.p = torch.nn.Parameter(torch.eye(5))

    @torch.jit.script_method
    def forward(self, input):
        # Take element 0 of the custom op's result and add 1.
        return torch.ops.custom.op_with_defaults(input)[0] + 1


def main(argv=None):
    """Load the custom-op library, build ``Model``, and save it to disk.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the original command-line behavior.
    """
    parser = argparse.ArgumentParser(
        description="Serialize a script module with custom ops"
    )
    parser.add_argument("--export-script-module-to", required=True)
    options = parser.parse_args(argv)

    # Register the custom op before constructing Model, whose scripted
    # forward references torch.ops.custom.op_with_defaults.
    torch.ops.load_library(get_custom_op_library_path())

    model = Model()
    model.save(options.export_script_module_to)


if __name__ == "__main__":
    main()