xref: /aosp_15_r20/external/pytorch/torchgen/decompositions/gen_jit_decompositions.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2import os
3from pathlib import Path
4
5from torch.jit._decompositions import decomposition_table
6
7
8# from torchgen.code_template import CodeTemplate
9
# Leading part of the generated C++ file: the "@generated" banner, the
# #include lines, the opening of the torch / jit namespaces, and the start of
# the `decomp_funcs` definition. The template ends right after the opening of
# a C++ raw string literal (`R"(`), so the text concatenated after it (the
# serialized decompositions) becomes the body of that literal.
DECOMP_HEADER = r"""
/**
 * @generated
 * This is an auto-generated file. Please do not modify it by hand.
 * To re-generate, please run:
 * cd ~/pytorch && python torchgen/decompositions/gen_jit_decompositions.py
 */
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/decomposition_registry_util.h>

namespace torch {
namespace jit {


const std::string decomp_funcs =
R"("""
28
29
# Middle part of the generated C++ file: closes the raw string literal that
# DECOMP_HEADER leaves open, defines GetSerializedDecompositions(), and opens
# the brace initializer of the static OperatorMap inside
# GetDecompositionMapping(). The mapping entries are concatenated after it.
DECOMP_CENTER = r"""
)";

const std::string& GetSerializedDecompositions() {
  return decomp_funcs;
}

const OperatorMap<std::string>& GetDecompositionMapping() {
  // clang-format off
 static const OperatorMap<std::string> decomposition_mapping {
"""
41
# Trailing part of the generated C++ file: closes the OperatorMap initializer
# left open by DECOMP_CENTER, returns the map, and closes the jit / torch
# namespaces.
DECOMP_END = r"""
  };
  // clang-format on

  return decomposition_mapping;
}

} // namespace jit
} // namespace torch
"""
52
53
# File name of the generated C++ source; write_decomposition_util_file() joins
# it onto the target directory it is given.
DECOMPOSITION_UTIL_FILE_NAME = "decomposition_registry_util.cpp"
55
56
def gen_serialized_decompisitions() -> str:
    """Return the source text of every registered decomposition.

    Reads the ``code`` attribute of each value in ``decomposition_table``, in
    the table's iteration order, and joins the pieces with a single newline.

    Returns:
        The newline-joined ``code`` strings (empty string for an empty table).
    """
    code_chunks = (
        entry.code  # type: ignore[misc]
        for entry in decomposition_table.values()
    )
    return "\n".join(code_chunks)
61
62
def gen_decomposition_mappings() -> str:
    """Return the C++ initializer entries for the decomposition map.

    Each ``(key, value)`` pair of ``decomposition_table`` becomes one line of
    the form ``    {"<key>", "<value.name>"},``. Keys and names are inserted
    verbatim (no escaping is applied), and lines follow the table's iteration
    order.

    Returns:
        The entries joined with newlines (empty string for an empty table).
    """
    entries = [
        '    {"' + op_schema + '", "' + fn.name + '"},'  # type: ignore[operator]
        for op_schema, fn in decomposition_table.items()
    ]
    return "\n".join(entries)
70
71
def write_decomposition_util_file(path: str) -> None:
    """Assemble the generated C++ source and write it into directory ``path``.

    The file content is the plain concatenation (no separators) of
    ``DECOMP_HEADER``, the serialized decompositions, ``DECOMP_CENTER``, the
    mapping entries, and ``DECOMP_END``. It is written UTF-8 encoded to
    ``DECOMPOSITION_UTIL_FILE_NAME`` inside ``path``, replacing any existing
    file of that name.

    Args:
        path: Directory that receives the generated file. It is not created
            here, so it must already exist.

    Raises:
        OSError: If the destination file cannot be opened for writing.
    """
    decomposition_str = gen_serialized_decompisitions()
    decomposition_mappings = gen_decomposition_mappings()
    final_output = "".join(
        (
            DECOMP_HEADER,
            decomposition_str,
            DECOMP_CENTER,
            decomposition_mappings,
            DECOMP_END,
        )
    )

    # Build the destination once so the logged path is exactly the path that
    # gets opened. Previously the message used `path + "/" + name`, which can
    # differ from os.path.join's result (trailing separator, Windows) and
    # fails for path-like arguments.
    out_path = os.path.join(path, DECOMPOSITION_UTIL_FILE_NAME)
    print("writing file to : ", out_path)

    # Binary mode: the UTF-8 bytes are written verbatim, with no platform
    # newline translation.
    with open(out_path, "wb") as out_file:
        out_file.write(final_output.encode("utf-8"))
86
87
def main() -> None:
    """Regenerate the decomposition registry file inside the source tree.

    Walks up from this script's resolved location and writes the generated
    file into the first ancestor that contains ``torch/csrc/jit/runtime``.

    Raises:
        FileNotFoundError: If no ancestor directory contains
            ``torch/csrc/jit/runtime``.
    """
    runtime_subdir = Path("torch", "csrc", "jit", "runtime")
    script_path = Path(__file__).resolve()

    # The previous hard-coded `parents[3]` is one level above the repository
    # root when this script lives at torchgen/decompositions/ (the location
    # given in the generated file's banner), so the output directory did not
    # resolve inside the checkout. Searching the ancestors avoids depending on
    # how deep the script sits in the tree.
    for ancestor in script_path.parents:
        upgrader_path = ancestor / runtime_subdir
        if upgrader_path.is_dir():
            write_decomposition_util_file(str(upgrader_path))
            return

    raise FileNotFoundError(
        "could not find directory '{}' in any parent of '{}'".format(
            runtime_subdir, script_path
        )
    )
92
93
# Script entry point: run the generator only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
96