#!/usr/bin/env python3
"""Generate ``decomposition_registry_util.cpp`` from ``decomposition_table``.

The emitted C++ file embeds the source code of every entry in
``torch.jit._decompositions.decomposition_table`` as one raw string literal,
followed by an ``OperatorMap`` that maps each table key to the name of the
corresponding scripted function.
"""
import os
from pathlib import Path

from torch.jit._decompositions import decomposition_table


# from torchgen.code_template import CodeTemplate

# Everything up to (and including) the opening of the C++ raw string literal
# that will hold the serialized decomposition functions.
DECOMP_HEADER = r"""
/**
 * @generated
 * This is an auto-generated file. Please do not modify it by hand.
 * To re-generate, please run:
 * cd ~/pytorch && python torchgen/decompositions/gen_jit_decompositions.py
 */
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/decomposition_registry_util.h>

namespace torch {
namespace jit {


const std::string decomp_funcs =
R"("""


# Closes the raw string literal, defines the accessor for it, and opens the
# initializer list of the schema -> function-name map.
DECOMP_CENTER = r"""
)";

const std::string& GetSerializedDecompositions() {
  return decomp_funcs;
}

const OperatorMap<std::string>& GetDecompositionMapping() {
  // clang-format off
  static const OperatorMap<std::string> decomposition_mapping {
"""

# Closes the map initializer, the accessor function and both namespaces.
DECOMP_END = r"""
  };
  // clang-format on

  return decomposition_mapping;
}

} // namespace jit
} // namespace torch
"""


DECOMPOSITION_UTIL_FILE_NAME = "decomposition_registry_util.cpp"

# Location of the generated file relative to the repository root.
_RUNTIME_SUBDIR = Path("torch", "csrc", "jit", "runtime")


def gen_serialized_decompisitions() -> str:
    """Return the ``.code`` of every table entry, joined by newlines.

    The result is placed verbatim inside the C++ raw string literal opened by
    ``DECOMP_HEADER``.
    """
    return "\n".join(
        scripted_func.code  # type: ignore[misc]
        for scripted_func in decomposition_table.values()
    )


def gen_decomposition_mappings() -> str:
    """Return one C++ initializer entry per table item, joined by newlines.

    Each entry has the form ``{"<table key>", "<scripted function name>"},``.
    Keys and names are inserted as-is, without C++ string escaping.
    """
    return "\n".join(
        '    {"' + schema + '", "' + scripted_func.name + '"},'  # type: ignore[operator]
        for schema, scripted_func in decomposition_table.items()
    )


def write_decomposition_util_file(path: str) -> None:
    """Write the generated C++ source into the directory ``path``.

    Args:
        path: Existing directory that receives
            ``DECOMPOSITION_UTIL_FILE_NAME``. An existing file of that name
            is overwritten.

    Raises:
        OSError: If the file cannot be opened for writing (for example when
            ``path`` does not exist).
    """
    # Build the destination once so the log line and the actual write agree.
    out_path = os.path.join(path, DECOMPOSITION_UTIL_FILE_NAME)
    file_components = [
        DECOMP_HEADER,
        gen_serialized_decompisitions(),
        DECOMP_CENTER,
        gen_decomposition_mappings(),
        DECOMP_END,
    ]
    print("writing file to : ", out_path)
    # Binary mode keeps the "\n" line endings untouched on every platform.
    with open(out_path, "wb") as out_file:
        out_file.write("".join(file_components).encode("utf-8"))


def _find_pytorch_root(script_path: Path) -> Path:
    """Return the nearest ancestor of ``script_path`` holding the JIT runtime dir.

    The generated header advertises this script as living under
    ``torchgen/decompositions/``; from there a fixed ``parents[3]`` points one
    level above the checkout. Searching for the directory that actually
    contains ``torch/csrc/jit/runtime`` works from either depth. When no
    ancestor matches, the historical ``parents[3]`` is returned so behavior is
    unchanged in that case (including the ``IndexError`` for shallow paths).
    """
    for candidate in script_path.parents:
        if (candidate / _RUNTIME_SUBDIR).is_dir():
            return candidate
    return script_path.parents[3]


def main() -> None:
    """Regenerate the decomposition registry source inside the source tree."""
    pytorch_dir = _find_pytorch_root(Path(__file__).resolve())
    upgrader_path = pytorch_dir / "torch" / "csrc" / "jit" / "runtime"
    write_decomposition_util_file(str(upgrader_path))


if __name__ == "__main__":
    main()