xref: /aosp_15_r20/external/pytorch/test/custom_backend/backend.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import argparse
2import os.path
3import sys
4
5import torch
6
7
def get_custom_backend_library_path(build_dir="build"):
    """
    Get the path to the library containing the custom backend.

    The library file name depends on the platform:

    * ``custom_backend.dll`` on Windows,
    * ``libcustom_backend.dylib`` on macOS,
    * ``libcustom_backend.so`` everywhere else.

    It is looked up inside ``build_dir``. A relative ``build_dir`` is resolved
    against the current working directory, exactly as the previous hard-coded
    ``"build/"`` location was.

    Args:
        build_dir: directory expected to contain the built library. Defaults
            to ``"build"``, the original hard-coded location.

    Return:
        The absolute path to the custom backend library.

    Raises:
        FileNotFoundError: if no library exists at the computed path. The
            exception message is the offending path.
    """
    if sys.platform.startswith("win32"):
        library_filename = "custom_backend.dll"
    elif sys.platform.startswith("darwin"):
        library_filename = "libcustom_backend.dylib"
    else:
        library_filename = "libcustom_backend.so"
    path = os.path.abspath(os.path.join(build_dir, library_filename))
    # Use an explicit raise rather than ``assert``: asserts are stripped when
    # Python runs with -O, which would let a missing library slip through to
    # a more confusing failure in torch.ops.load_library.
    if not os.path.exists(path):
        raise FileNotFoundError(path)
    return path
24
25
def to_custom_backend(module):
    """
    Lower a ScriptModule to the backend registered as ``"custom_backend"``.

    This is a helper that wraps ``torch._C._jit_to_backend`` and compiles
    only the ``forward`` method, with an empty compile spec (``{"": ""}``).

    NOTE(review): the ``"custom_backend"`` backend is presumably registered by
    loading the custom backend library (as ``main`` does with
    ``torch.ops.load_library``) before this is called — confirm.

    Args:
        module: input ScriptModule.

    Returns:
        The module, lowered so that it can run on the ``custom_backend``
        backend.
    """
    # Only "forward" appears in the method -> compile-spec mapping, so only
    # that method is compiled for the backend.
    lowered_module = torch._C._jit_to_backend(
        "custom_backend", module, {"forward": {"": ""}}
    )
    return lowered_module
41
42
class Model(torch.nn.Module):
    """
    Minimal two-input module for exercising the to_backend API.

    It is used to check that a module lowered to a backend can be saved,
    loaded and executed from C++.
    """

    def forward(self, a, b):
        """Return the pair ``(a + b, a - b)``."""
        total = a + b
        difference = a - b
        return total, difference
51
52
def main():
    """
    Command-line entry point: lower ``Model`` to the custom backend and export it.

    Parses the required ``--export-module-to`` option, then:

    1. loads the custom backend library found by
       ``get_custom_backend_library_path``;
    2. scripts a fresh ``Model`` and lowers it with ``to_custom_backend``;
    3. saves the lowered module with ``torch.jit.save`` to the requested path.
    """
    arg_parser = argparse.ArgumentParser(
        description="Lower a Module to a custom backend"
    )
    arg_parser.add_argument("--export-module-to", required=True)
    args = arg_parser.parse_args()

    # Load the shared library that provides the custom backend before any
    # lowering happens, and sanity-check that torch recorded it as loaded.
    lib_path = get_custom_backend_library_path()
    torch.ops.load_library(lib_path)
    assert lib_path in torch.ops.loaded_libraries

    # Script a fresh Model, lower it to the custom backend, and write the
    # result to the user-chosen location.
    scripted_model = torch.jit.script(Model())
    torch.jit.save(to_custom_backend(scripted_model), args.export_module_to)


if __name__ == "__main__":
    main()
71