#!/usr/bin/env python3
"""Generate ``upgrader_mobile.cpp`` from the upgrader bytecode.

The script obtains the upgrader bytecode tables from
``generate_upgraders_bytecode()``, renders them through the C++ code
templates defined below and writes the result to
``torch/csrc/jit/mobile/upgrader_mobile.cpp`` (relative to the repository
root derived from this file's location).
"""

from __future__ import annotations

import os
from enum import Enum
from operator import itemgetter
from pathlib import Path
from typing import Any

import torch
from torch.jit.generate_bytecode import generate_upgraders_bytecode
from torchgen.code_template import CodeTemplate
from torchgen.operator_versions.gen_mobile_upgraders_constant import (
    MOBILE_UPGRADERS_HEADER_DESCRIPTION,
)


class ByteCode(Enum):
    """Names of the tables found in one upgrader's bytecode dictionary."""

    instructions = 1
    constants = 2
    types = 3
    operators = 4
    register_size = 5


# Operators that are skipped when the operator version map is rendered.
EXCLUDED_OP_SET = [
    "aten::full.names",
    "aten::full.out",
    "aten::full",
]

# Upgraders that are skipped both when indices are assigned and when the
# bytecode list is rendered.
EXCLUE_UPGRADER_SET = ["full_0_4", "full_out_0_4"]

ONE_INSTRUCTION = CodeTemplate(
    """
    Instruction{OpCode::${operator_name}, ${X}, ${N}},"""
)

INSTRUCTION_LIST = CodeTemplate(
    """std::vector<Instruction>({
        ${instruction_list}
    }), // instructions list"""
)

ONE_CONSTANT = CodeTemplate(
    """
    c10::IValue(${constant}),"""
)

CONSTANT_LIST = CodeTemplate(
    """std::vector<c10::IValue>({
        ${constant_list}
    }), // constants list"""
)

CONSTANTS_LIST_EMPTY = """std::vector<c10::IValue>(), // constants list"""

ONE_TYPE = CodeTemplate("""c10::parseType("${type_str}"),""")

TYPE_LIST = CodeTemplate(
    """std::vector<c10::TypePtr>({
        ${type_list}
    }), // types list"""
)

TYPE_LIST_EMPTY = """std::vector<c10::TypePtr>(), // types list"""

ONE_OPERATOTR_STRING = CodeTemplate(
    """
    OperatorString({"${operator_name}", "${overload_name}", ${num_of_args}}),"""
)

OPERATOR_STRING_LIST = CodeTemplate(
    """
    std::vector<OperatorString>({
        ${operator_string_list}
    }), // operators list"""
)

ONE_UPGRADER_FUNCTION = CodeTemplate(
    """
    mobile::Function::registerFunc(
        "${upgrader_name}",
        ${instruction_list},
        ${constant_list},
        ${type_list},
        ${register_size}
    )"""
)

ONE_UPGRADER_SRC = CodeTemplate(
    """
    ByteCodeFunctionWithOperator({
        ${bytecode_function},
        ${operator_string_list}
    }),"""
)


ONE_UPGRADER_IN_VERSION_MAP = CodeTemplate(
    """Upgrader({${upgrader_min_version}, ${upgrader_max_version}, "${upgrader_name}", ${bytecode_func_index}})"""
)  # noqa: E501

ONE_OPERATOR_IN_VERSION_MAP = CodeTemplate(
    """
    {std::string("${operator_name}"),
        std::vector<Upgrader>({
            ${upgrader_list_in_version_map}
        })},"""
)


OPERATOR_VERSION_MAP = CodeTemplate(
    """
const std::unordered_map<std::string, std::vector<Upgrader>>
getOperatorVersionMapForMobile() {
  static std::unordered_map<std::string, std::vector<Upgrader>>
        operatorVersionMapForMobile({
            ${operator_list_in_version_map}
      });
  return operatorVersionMapForMobile;
}
"""
)


UPGRADER_CPP_SRC = CodeTemplate(
    MOBILE_UPGRADERS_HEADER_DESCRIPTION
    + """
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/mobile/upgrader_mobile.h>

namespace c10 {
TypePtr parseType(const std::string& pythonStr);
} // namespace c10

namespace torch {
namespace jit {

// clang-format off

// From operator_versions_map
${operator_version_map}

const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
  auto generate_upgrader_bytecode_list = []() {
    std::vector<ByteCodeFunctionWithOperator> upgrader_function_list({
               ${upgrader_bytecode}
            });
    for (const auto& upgrader_function : upgrader_function_list) {
      for (const auto& op : upgrader_function.operators) {
        upgrader_function.function.append_operator(
            op.name,
            op.overload_name,
            op.num_specified_args);
      }
    }
    return upgrader_function_list;
  };
  static std::vector<ByteCodeFunctionWithOperator> upgraderBytecodeList =
      generate_upgrader_bytecode_list();
  return upgraderBytecodeList;
}

// clang-format on

} // namespace jit
} // namespace torch
"""
)

UPGRADER_MOBILE_FILE_NAME = "upgrader_mobile.cpp"

UPGRADER_ELEMENT = CodeTemplate(
    """\
Upgrader({${min_version}, ${max_version}, ${operator_name}, ${index}}),
"""
)

PER_OPERATOR_UPGRADER_LIST = CodeTemplate(
    """\
{
  std::string(${operator_name}),
  std::vector<Upgrader>({${upgrader_list}});
}
"""
)


def construct_instruction(instruction_list_from_yaml: list[Any]) -> str:
    """Render the ``std::vector<Instruction>`` initializer for one upgrader.

    Each entry is indexed positionally: element 0 is substituted as the
    ``OpCode`` member name, elements 1 and 2 as the ``X`` and ``N`` fields.
    """
    rendered = [
        ONE_INSTRUCTION.substitute(
            operator_name=instruction[0],
            X=instruction[1],
            N=instruction[2],
        )
        for instruction in instruction_list_from_yaml
    ]
    return INSTRUCTION_LIST.substitute(
        instruction_list="".join(rendered).lstrip("\n")
    )


def construct_constants(constants_list_from_yaml: list[Any]) -> str:
    """Render the ``std::vector<c10::IValue>`` initializer for one upgrader.

    Supported constant types are ``str`` (wrapped in double quotes), ``bool``
    (``true``/``false``), ``None`` (an empty ``c10::IValue()``) and ``int``.
    An empty input yields ``CONSTANTS_LIST_EMPTY``.

    Raises:
        ValueError: if a constant has any other type.
    """
    constants_list_part = []
    for constant_from_yaml in constants_list_from_yaml:
        if isinstance(constant_from_yaml, str):
            # Add quotes if it's string
            convert_constant = f'"{constant_from_yaml}"'
        elif isinstance(constant_from_yaml, bool):
            # bool must be tested before int: bool is a subclass of int.
            convert_constant = "true" if constant_from_yaml else "false"
        elif constant_from_yaml is None:
            convert_constant = ""
        elif isinstance(constant_from_yaml, int):
            convert_constant = str(constant_from_yaml)
        else:
            raise ValueError(
                f"The type of {constant_from_yaml} is {type(constant_from_yaml)}. "
                "Please add change in construct_constants function in gen_mobile_upgraders.py."
            )
        constants_list_part.append(ONE_CONSTANT.substitute(constant=convert_constant))
    if not constants_list_part:
        return CONSTANTS_LIST_EMPTY
    return CONSTANT_LIST.substitute(
        constant_list="".join(constants_list_part).lstrip("\n")
    )


def construct_operators(operator_list_from_yaml: list[Any]) -> str:
    """Render the ``std::vector<OperatorString>`` initializer for one upgrader.

    Each entry is indexed positionally: operator name, overload name and
    number of arguments.
    """
    rendered = [
        ONE_OPERATOTR_STRING.substitute(
            operator_name=operator[0],
            overload_name=operator[1],
            num_of_args=operator[2],
        )
        for operator in operator_list_from_yaml
    ]
    return OPERATOR_STRING_LIST.substitute(
        operator_string_list="".join(rendered).lstrip("\n")
    )


def construct_types(types_tr_list_from_yaml: list[Any]) -> str:
    """Render the ``std::vector<c10::TypePtr>`` initializer for one upgrader.

    Each entry becomes a ``c10::parseType("...")`` call; an empty input yields
    ``TYPE_LIST_EMPTY``.
    """
    rendered = [
        ONE_TYPE.substitute(type_str=types_tr) for types_tr in types_tr_list_from_yaml
    ]
    if not rendered:
        return TYPE_LIST_EMPTY
    return TYPE_LIST.substitute(type_list="".join(rendered).lstrip("\n"))


def construct_register_size(register_size_from_yaml: int) -> str:
    """Return the register size as a string.

    Raises:
        ValueError: if the input is not an ``int``.
    """
    if not isinstance(register_size_from_yaml, int):
        # Both fragments must be f-strings, otherwise the type placeholder is
        # emitted literally instead of being interpolated.
        raise ValueError(
            f"Input register size is {register_size_from_yaml} and "
            f"its type is {type(register_size_from_yaml)}. An int type is expected."
        )
    return str(register_size_from_yaml)


def construct_version_maps(
    upgrader_bytecode_function_to_index_map: dict[str, Any]
) -> str:
    """Render the ``getOperatorVersionMapForMobile`` C++ function.

    The operator version map is read from
    ``torch._C._get_operator_version_map()`` and iterated in sorted operator
    name order; operators in ``EXCLUDED_OP_SET`` are skipped. For every
    upgrader entry, the version range comes from
    ``torch._C._get_upgrader_ranges`` and the bytecode function index from
    ``upgrader_bytecode_function_to_index_map`` (keyed by upgrader name).
    """
    version_map = torch._C._get_operator_version_map()
    sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0))  # type: ignore[no-any-return]
    sorted_version_map = dict(sorted_version_map_)

    operator_list_in_version_map_part = []
    for op_name in sorted_version_map:
        # TODO: remove the skip after these two operators schemas are fixed
        if op_name in EXCLUDED_OP_SET:
            continue
        upgrader_ranges = torch._C._get_upgrader_ranges(op_name)
        upgrader_entries = sorted_version_map[op_name]
        assert len(upgrader_ranges) == len(upgrader_entries)
        upgraders_in_version_map_part = []
        for idx, upgrader_entry in enumerate(upgrader_entries):
            upgrader_name = upgrader_entry.upgrader_name
            bytecode_function_index = upgrader_bytecode_function_to_index_map[
                upgrader_name
            ]
            upgraders_in_version_map_part.append(
                ONE_UPGRADER_IN_VERSION_MAP.substitute(
                    upgrader_min_version=upgrader_ranges[idx].min_version,
                    upgrader_max_version=upgrader_ranges[idx].max_version,
                    upgrader_name=upgrader_name,
                    bytecode_func_index=bytecode_function_index,
                )
            )
        operator_list_in_version_map_part.append(
            ONE_OPERATOR_IN_VERSION_MAP.substitute(
                operator_name=op_name,
                # The per-upgrader template carries no trailing comma, so the
                # entries are comma-separated here; with a single upgrader the
                # output is unchanged, with several it stays a valid C++
                # initializer list.
                upgrader_list_in_version_map=",\n".join(
                    upgraders_in_version_map_part
                ),
            )
        )
    return OPERATOR_VERSION_MAP.substitute(
        operator_list_in_version_map="".join(operator_list_in_version_map_part).lstrip(
            "\n"
        )
    )


def get_upgrader_bytecode_function_to_index_map(
    upgrader_dict: list[dict[str, Any]]
) -> dict[str, Any]:
    """Map each upgrader name to its position in the rendered bytecode list.

    Names are numbered consecutively from 0 in iteration order; upgraders in
    ``EXCLUE_UPGRADER_SET`` are skipped and do not consume an index.
    """
    upgrader_bytecode_function_to_index_map = {}
    index = 0
    for upgrader_bytecode in upgrader_dict:
        for upgrader_name in upgrader_bytecode:
            if upgrader_name in EXCLUE_UPGRADER_SET:
                continue
            upgrader_bytecode_function_to_index_map[upgrader_name] = index
            index += 1
    return upgrader_bytecode_function_to_index_map


def write_cpp(cpp_path: str, upgrader_dict: list[dict[str, Any]]) -> None:
    """Render all upgraders and write ``upgrader_mobile.cpp`` into ``cpp_path``.

    Args:
        cpp_path: directory that receives ``UPGRADER_MOBILE_FILE_NAME``.
        upgrader_dict: list of ``{upgrader_name: bytecode_tables}`` dicts,
            where the table names are members of ``ByteCode``.

    Raises:
        KeyError: if a bytecode table name is not a ``ByteCode`` member.
    """
    upgrader_bytecode_function_to_index_map = (
        get_upgrader_bytecode_function_to_index_map(upgrader_dict)
    )
    version_map_src = construct_version_maps(upgrader_bytecode_function_to_index_map)
    all_upgrader_src_string = []
    for upgrader_bytecode in upgrader_dict:
        for upgrader_name, bytecode in upgrader_bytecode.items():
            # TODO: remove the skip after these two operators schemas are fixed
            if upgrader_name in EXCLUE_UPGRADER_SET:
                continue
            instruction_list_str = ""
            constant_list_str = ""
            type_list_str = ""
            register_size_str = ""
            operator_list_str = ""
            for table_name, contents in bytecode.items():
                element = ByteCode[table_name]
                if element is ByteCode.instructions:
                    instruction_list_str = construct_instruction(contents)
                elif element is ByteCode.constants:
                    constant_list_str = construct_constants(contents)
                elif element is ByteCode.operators:
                    operator_list_str = construct_operators(contents)
                elif element is ByteCode.types:
                    type_list_str = construct_types(contents)
                elif element is ByteCode.register_size:
                    register_size_str = construct_register_size(contents)

            one_upgrader_function_string = ONE_UPGRADER_FUNCTION.substitute(
                upgrader_name=upgrader_name,
                instruction_list=instruction_list_str,
                constant_list=constant_list_str,
                type_list=type_list_str,
                register_size=register_size_str,
            )
            one_upgrader_src_string = ONE_UPGRADER_SRC.substitute(
                bytecode_function=one_upgrader_function_string.lstrip("\n"),
                operator_string_list=operator_list_str.lstrip("\n"),
            )
            all_upgrader_src_string.append(one_upgrader_src_string)

    upgrader_file_content = UPGRADER_CPP_SRC.substitute(
        operator_version_map=version_map_src,
        upgrader_bytecode="".join(all_upgrader_src_string).lstrip("\n"),
    )
    # Compute the destination once so the logged path is the one written.
    out_path = os.path.join(cpp_path, UPGRADER_MOBILE_FILE_NAME)
    print("writing file to : ", out_path)
    # Binary mode keeps the encoded bytes exactly as rendered.
    with open(out_path, "wb") as out_file:
        out_file.write(upgrader_file_content.encode("utf-8"))


def sort_upgrader(upgrader_list: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return the upgrader dicts sorted by the first key of each dict."""
    sorted_upgrader_list = sorted(
        upgrader_list, key=lambda one_upgrader: next(iter(one_upgrader))
    )
    return sorted_upgrader_list


def main() -> None:
    """Generate the upgrader bytecode, sort it and write the C++ source."""
    upgrader_list = generate_upgraders_bytecode()
    sorted_upgrader_list = sort_upgrader(upgrader_list)
    for up in sorted_upgrader_list:
        print("after sort upgrader : ", next(iter(up)))

    # The repository root is two directory levels above this file's directory.
    pytorch_dir = Path(__file__).resolve().parents[2]
    upgrader_path = pytorch_dir / "torch" / "csrc" / "jit" / "mobile"
    write_cpp(str(upgrader_path), sorted_upgrader_list)


if __name__ == "__main__":
    main()