xref: /aosp_15_r20/external/pytorch/torchgen/operator_versions/gen_mobile_upgraders.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2
3from __future__ import annotations
4
5import os
6from enum import Enum
7from operator import itemgetter
8from pathlib import Path
9from typing import Any
10
11import torch
12from torch.jit.generate_bytecode import generate_upgraders_bytecode
13from torchgen.code_template import CodeTemplate
14from torchgen.operator_versions.gen_mobile_upgraders_constant import (
15    MOBILE_UPGRADERS_HEADER_DESCRIPTION,
16)
17
18
class ByteCode(Enum):
    """Table names that may appear in one upgrader's bytecode dict.

    The generator resolves each table name with ``ByteCode[table_name]``
    (an unrecognized name raises ``KeyError``) and then dispatches on the
    member to the matching ``construct_*`` renderer. Members are only looked
    up by name here; the integer values carry no meaning of their own.
    """

    instructions = 1
    constants = 2
    types = 3
    operators = 4
    register_size = 5
25
26
# Operators skipped when rendering the operator version map (checked in
# construct_version_maps).
# TODO: remove these skips once the schemas of these operators are fixed.
EXCLUDED_OP_SET = [
    "aten::full.names",
    "aten::full.out",
    "aten::full",
]

# Upgrader functions skipped both when assigning bytecode indices
# (get_upgrader_bytecode_function_to_index_map) and when emitting the bytecode
# list (write_cpp), so the two stay aligned.
# NOTE(review): despite the "_SET" suffix this is a list, and "EXCLUE" looks
# like a typo for "EXCLUDE"; renaming requires updating every reference.
EXCLUE_UPGRADER_SET = ["full_0_4", "full_out_0_4"]
34
# Templates for the bytecode tables of a single upgrader. Each ONE_* template
# renders one element; most start with a newline, and the construct_*
# helpers strip the leading newline of the joined result. Each *_LIST template
# wraps the joined elements in a C++ std::vector initializer, and the *_EMPTY
# strings are used instead when a table has no entries.

# One instruction: `operator_name` is emitted after `OpCode::` (i.e. it is the
# opcode enumerator name); X and N are the instruction's two operands.
ONE_INSTRUCTION = CodeTemplate(
    """
    Instruction{OpCode::${operator_name}, ${X}, ${N}},"""
)

INSTRUCTION_LIST = CodeTemplate(
    """std::vector<Instruction>({
        ${instruction_list}
    }), // instructions list"""
)

# One constant; `constant` is already C++ source text (a quoted string, true/
# false, an integer literal, or empty for a None IValue).
ONE_CONSTANT = CodeTemplate(
    """
    c10::IValue(${constant}),"""
)

CONSTANT_LIST = CodeTemplate(
    """std::vector<c10::IValue>({
        ${constant_list}
    }), // constants list"""
)

CONSTANTS_LIST_EMPTY = """std::vector<c10::IValue>(), // constants list"""

# One type, parsed at runtime from its string form. Unlike the other ONE_*
# templates it has no leading newline, so all types render on one line.
ONE_TYPE = CodeTemplate("""c10::parseType("${type_str}"),""")

TYPE_LIST = CodeTemplate(
    """std::vector<c10::TypePtr>({
        ${type_list}
    }), // types list"""
)

TYPE_LIST_EMPTY = """std::vector<c10::TypePtr>(), // types list"""
68
# One operator reference: qualified name, overload name, and the number of
# arguments. NOTE(review): "OPERATOTR" is a typo kept because the name is
# referenced by construct_operators.
ONE_OPERATOTR_STRING = CodeTemplate(
    """
    OperatorString({"${operator_name}", "${overload_name}", ${num_of_args}}),"""
)

# Wraps the joined operator references of one upgrader. There is no *_EMPTY
# counterpart: an empty operator table still renders through this template.
OPERATOR_STRING_LIST = CodeTemplate(
    """
    std::vector<OperatorString>({
        ${operator_string_list}
    }), // operators list"""
)
80
# A mobile::Function built from one upgrader's rendered instruction,
# constant, and type tables plus its register size.
ONE_UPGRADER_FUNCTION = CodeTemplate(
    """
    mobile::Function::registerFunc(
        "${upgrader_name}",
        ${instruction_list},
        ${constant_list},
        ${type_list},
        ${register_size}
    )"""
)

# One element of the generated upgrader bytecode list: the function above
# paired with the operator table that is appended to it at runtime.
ONE_UPGRADER_SRC = CodeTemplate(
    """
    ByteCodeFunctionWithOperator({
        ${bytecode_function},
        ${operator_string_list}
    }),"""
)
99
100
# One upgrader entry for an operator: the version range it applies to, its
# name, and its index into the upgrader bytecode list. It carries no trailing
# comma, so any separator between several entries must be supplied by the
# caller.
ONE_UPGRADER_IN_VERSION_MAP = CodeTemplate(
    """Upgrader({${upgrader_min_version}, ${upgrader_max_version}, "${upgrader_name}", ${bytecode_func_index}})"""
)  # noqa: E501

# One operator -> list-of-upgraders entry of the version map.
ONE_OPERATOR_IN_VERSION_MAP = CodeTemplate(
    """
    {std::string("${operator_name}"),
        std::vector<Upgrader>({
            ${upgrader_list_in_version_map}
        })},"""
)


# Definition of getOperatorVersionMapForMobile(). The map is built once into
# a function-local static; the function's return type is a (const) map by
# value, so each call returns a copy.
OPERATOR_VERSION_MAP = CodeTemplate(
    """
const std::unordered_map<std::string, std::vector<Upgrader>>
getOperatorVersionMapForMobile() {
  static std::unordered_map<std::string, std::vector<Upgrader>>
        operatorVersionMapForMobile({
            ${operator_list_in_version_map}
      });
  return operatorVersionMapForMobile;
}
"""
)
126
127
# Complete contents of the generated C++ file: the shared header text
# (MOBILE_UPGRADERS_HEADER_DESCRIPTION) followed by the rendered operator
# version map and the upgrader bytecode list. In the generated
# getUpgraderBytecodeList(), the list is built once into a function-local
# static, and each upgrader's operators are appended to its function at that
# time.
UPGRADER_CPP_SRC = CodeTemplate(
    MOBILE_UPGRADERS_HEADER_DESCRIPTION
    + """
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/mobile/upgrader_mobile.h>

namespace c10 {
TypePtr parseType(const std::string& pythonStr);
} // namespace c10

namespace torch {
namespace jit {

// clang-format off

// From operator_versions_map
${operator_version_map}

const std::vector<ByteCodeFunctionWithOperator>& getUpgraderBytecodeList() {
  auto generate_upgrader_bytecode_list = []() {
    std::vector<ByteCodeFunctionWithOperator> upgrader_function_list({
               ${upgrader_bytecode}
            });
    for (const auto& upgrader_function : upgrader_function_list) {
      for (const auto& op : upgrader_function.operators) {
        upgrader_function.function.append_operator(
            op.name,
            op.overload_name,
            op.num_specified_args);
      }
    }
    return upgrader_function_list;
  };
  static std::vector<ByteCodeFunctionWithOperator> upgraderBytecodeList =
      generate_upgrader_bytecode_list();
  return upgraderBytecodeList;
}

// clang-format on

} // namespace jit
} // namespace torch
"""
)

# File name of the generated source, created inside the directory given to
# write_cpp.
UPGRADER_MOBILE_FILE_NAME = "upgrader_mobile.cpp"
174
# NOTE(review): neither UPGRADER_ELEMENT nor PER_OPERATOR_UPGRADER_LIST is
# referenced by the code in this file; the version map is rendered with
# ONE_UPGRADER_IN_VERSION_MAP / ONE_OPERATOR_IN_VERSION_MAP instead. Confirm
# whether anything else imports them before relying on or removing them.
UPGRADER_ELEMENT = CodeTemplate(
    """\
Upgrader({${min_version}, ${max_version}, ${operator_name}, ${index}}),
"""
)

# NOTE(review): the `;` after the vector initializer and the missing `,` after
# the closing brace would not compile as an element of a C++ initializer list
# if this template were ever used.
PER_OPERATOR_UPGRADER_LIST = CodeTemplate(
    """\
{
  std::string(${operator_name}),
  std::vector<Upgrader>({${upgrader_list}});
}
"""
)
189
190
def construct_instruction(instruction_list_from_yaml: list[Any]) -> str:
    """Render one upgrader's instruction table as a C++ vector initializer.

    Every entry is indexed positionally: element 0 is the opcode name
    (emitted after ``OpCode::``), elements 1 and 2 are the X and N operands.
    """
    rendered_entries = [
        ONE_INSTRUCTION.substitute(
            N=entry[2],
            X=entry[1],
            operator_name=entry[0],
        )
        for entry in instruction_list_from_yaml
    ]
    # Each rendered entry starts with "\n"; drop the one before the first.
    joined = "".join(rendered_entries).lstrip("\n")
    return INSTRUCTION_LIST.substitute(instruction_list=joined)
204
205
def _to_cpp_string_literal(value: str) -> str:
    """Return *value* as a double-quoted C++ string literal.

    Backslashes, double quotes, and line breaks are escaped, so the literal is
    valid single-line C++ whose runtime value equals *value*. Strings that
    contain none of these characters are simply wrapped in quotes.
    """
    escaped = (
        value.replace("\\", "\\\\")  # must run first so later escapes survive
        .replace('"', '\\"')
        .replace("\n", "\\n")
        .replace("\r", "\\r")
    )
    return f'"{escaped}"'


def construct_constants(constants_list_from_yaml: list[Any]) -> str:
    """Render one upgrader's constants table as a C++ vector initializer.

    Each constant becomes ``c10::IValue(<text>)`` where ``<text>`` is:
      * an escaped C++ string literal for ``str``,
      * ``true`` / ``false`` for ``bool``,
      * empty (i.e. ``c10::IValue()``) for ``None``,
      * ``str(value)`` for ``int``.

    Returns:
        ``CONSTANTS_LIST_EMPTY`` when there are no constants, otherwise the
        ``CONSTANT_LIST`` rendering.

    Raises:
        ValueError: for a constant of any other type (e.g. ``float``).
    """
    constants_list_part = []
    for constant_from_yaml in constants_list_from_yaml:
        convert_constant: str
        if isinstance(constant_from_yaml, str):
            # Escape so quotes/backslashes/newlines cannot break the C++ source.
            convert_constant = _to_cpp_string_literal(constant_from_yaml)
        # bool must be tested before int: bool is a subclass of int.
        elif isinstance(constant_from_yaml, bool):
            convert_constant = "true" if constant_from_yaml else "false"
        elif constant_from_yaml is None:
            convert_constant = ""
        elif isinstance(constant_from_yaml, int):
            convert_constant = str(constant_from_yaml)
        else:
            raise ValueError(
                f"The type of {constant_from_yaml} is {type(constant_from_yaml)}. "
                "Please add change in construct_constants function in gen_mobile_upgraders.py."
            )
        constants_list_part.append(ONE_CONSTANT.substitute(constant=convert_constant))
    if not constants_list_part:
        return CONSTANTS_LIST_EMPTY
    return CONSTANT_LIST.substitute(
        constant_list="".join(constants_list_part).lstrip("\n")
    )
230
231
def construct_operators(operator_list_from_yaml: list[Any]) -> str:
    """Render one upgrader's operator table as a C++ vector initializer.

    Every entry is indexed positionally as (operator name, overload name,
    number of arguments) and becomes one ``OperatorString({...}),`` line.
    An empty table still renders through ``OPERATOR_STRING_LIST``.
    """
    rendered_ops = []
    for op_entry in operator_list_from_yaml:
        qualified_name = op_entry[0]
        overload = op_entry[1]
        arg_count = op_entry[2]
        rendered_ops.append(
            ONE_OPERATOTR_STRING.substitute(
                num_of_args=arg_count,
                overload_name=overload,
                operator_name=qualified_name,
            )
        )
    # Each rendered entry starts with "\n"; drop the one before the first.
    body = "".join(rendered_ops).lstrip("\n")
    return OPERATOR_STRING_LIST.substitute(operator_string_list=body)
245
246
def construct_types(types_tr_list_from_yaml: list[Any]) -> str:
    """Render one upgrader's types table as a C++ vector initializer.

    Each type string is wrapped in ``c10::parseType("...")``. When the table
    is empty, ``TYPE_LIST_EMPTY`` is returned instead.
    """
    parsed_types = [
        ONE_TYPE.substitute(type_str=type_text)
        for type_text in types_tr_list_from_yaml
    ]
    if not parsed_types:
        return TYPE_LIST_EMPTY
    joined = "".join(parsed_types).lstrip("\n")
    return TYPE_LIST.substitute(type_list=joined)
254
255
def construct_register_size(register_size_from_yaml: int) -> str:
    """Render an upgrader's register size as C++ source text.

    Args:
        register_size_from_yaml: the register size; must be a plain ``int``.

    Returns:
        ``str(register_size_from_yaml)``.

    Raises:
        ValueError: if the value is not an ``int``. ``bool`` is rejected as
            well: it is an ``int`` subclass, but ``str(True)`` would emit
            ``True``, which is not a C++ integer literal.
    """
    if isinstance(register_size_from_yaml, bool) or not isinstance(
        register_size_from_yaml, int
    ):
        # Both halves are f-strings so the type is actually interpolated.
        raise ValueError(
            f"Input register size is {register_size_from_yaml} and "
            f"its type is {type(register_size_from_yaml)}. An int type is expected."
        )
    return str(register_size_from_yaml)
263
264
def construct_version_maps(
    upgrader_bytecode_function_to_index_map: dict[str, Any]
) -> str:
    """Render the C++ ``getOperatorVersionMapForMobile()`` definition.

    Reads the operator version map and per-operator upgrader ranges from
    ``torch._C`` and emits one entry per operator, in operator-name order,
    skipping operators listed in ``EXCLUDED_OP_SET``.

    Args:
        upgrader_bytecode_function_to_index_map: maps each upgrader name to
            its index in the generated upgrader bytecode list.

    Returns:
        The C++ source rendered through ``OPERATOR_VERSION_MAP``.

    Raises:
        RuntimeError: if an operator's upgrader ranges and upgrader entries
            differ in length.
        KeyError: if an upgrader named in the version map has no bytecode
            index.
    """
    version_map = torch._C._get_operator_version_map()
    sorted_version_map = sorted(version_map.items(), key=itemgetter(0))  # type: ignore[no-any-return]

    operator_list_in_version_map_part = []
    for op_name, upgrader_entries in sorted_version_map:
        # TODO: remove the skip after these two operators schemas are fixed
        if op_name in EXCLUDED_OP_SET:
            continue
        upgrader_ranges = torch._C._get_upgrader_ranges(op_name)
        # Explicit check rather than `assert` so it still runs under `python -O`.
        if len(upgrader_ranges) != len(upgrader_entries):
            raise RuntimeError(
                f"Operator {op_name} has {len(upgrader_ranges)} upgrader ranges "
                f"but {len(upgrader_entries)} upgrader entries."
            )
        upgraders_in_version_map_part = []
        for upgrader_range, upgrader_entry in zip(upgrader_ranges, upgrader_entries):
            upgrader_name = upgrader_entry.upgrader_name
            try:
                bytecode_function_index = upgrader_bytecode_function_to_index_map[
                    upgrader_name
                ]
            except KeyError as err:
                raise KeyError(
                    f"No bytecode function index for upgrader {upgrader_name} "
                    f"of operator {op_name}"
                ) from err
            upgraders_in_version_map_part.append(
                ONE_UPGRADER_IN_VERSION_MAP.substitute(
                    upgrader_min_version=upgrader_range.min_version,
                    upgrader_max_version=upgrader_range.max_version,
                    upgrader_name=upgrader_name,
                    bytecode_func_index=bytecode_function_index,
                )
            )
        operator_list_in_version_map_part.append(
            ONE_OPERATOR_IN_VERSION_MAP.substitute(
                operator_name=op_name,
                # ONE_UPGRADER_IN_VERSION_MAP has no trailing comma, so entries
                # must be comma-separated to form a valid C++ initializer list.
                # With a single upgrader this is identical to a plain join.
                upgrader_list_in_version_map=",\n".join(
                    upgraders_in_version_map_part
                ),
            )
        )
    return OPERATOR_VERSION_MAP.substitute(
        operator_list_in_version_map="".join(operator_list_in_version_map_part).lstrip(
            "\n"
        )
    )
305
306
def get_upgrader_bytecode_function_to_index_map(
    upgrader_dict: list[dict[str, Any]]
) -> dict[str, Any]:
    """Number every kept upgrader by its position in the bytecode list.

    Upgrader names are visited in order across all dicts in ``upgrader_dict``
    and numbered consecutively from 0. Names in ``EXCLUE_UPGRADER_SET`` are
    skipped and do not consume an index, matching the entries that
    ``write_cpp`` emits.
    """
    kept_names = [
        name
        for bytecode_by_name in upgrader_dict
        for name in bytecode_by_name
        if name not in EXCLUE_UPGRADER_SET
    ]
    return {name: position for position, name in enumerate(kept_names)}
319
320
def _construct_one_upgrader_src(upgrader_name: str, bytecode: dict[str, Any]) -> str:
    """Render one ``ByteCodeFunctionWithOperator({...}),`` list element.

    ``bytecode`` maps ``ByteCode`` member names to table contents; each table
    is rendered by its ``construct_*`` helper. A table absent from
    ``bytecode`` is rendered as an empty string.

    Raises:
        KeyError: if a table name is not a ``ByteCode`` member.
    """
    instruction_list_str = ""
    constant_list_str = ""
    type_list_str = ""
    register_size_str = ""
    operator_list_str = ""
    for table_name, contents in bytecode.items():
        element = ByteCode[table_name]
        if element is ByteCode.instructions:
            instruction_list_str = construct_instruction(contents)
        elif element is ByteCode.constants:
            constant_list_str = construct_constants(contents)
        elif element is ByteCode.operators:
            operator_list_str = construct_operators(contents)
        elif element is ByteCode.types:
            type_list_str = construct_types(contents)
        elif element is ByteCode.register_size:
            register_size_str = construct_register_size(contents)

    one_upgrader_function_string = ONE_UPGRADER_FUNCTION.substitute(
        upgrader_name=upgrader_name,
        instruction_list=instruction_list_str,
        constant_list=constant_list_str,
        type_list=type_list_str,
        register_size=register_size_str,
    )
    return ONE_UPGRADER_SRC.substitute(
        bytecode_function=one_upgrader_function_string.lstrip("\n"),
        operator_string_list=operator_list_str.lstrip("\n"),
    )


def write_cpp(cpp_path: str, upgrader_dict: list[dict[str, Any]]) -> None:
    """Generate ``upgrader_mobile.cpp`` inside the directory ``cpp_path``.

    Args:
        cpp_path: directory to write the file into.
        upgrader_dict: list of ``{upgrader_name: bytecode_tables}`` dicts, in
            the order the upgraders should appear in the generated bytecode
            list. Upgraders in ``EXCLUE_UPGRADER_SET`` are skipped.

    The file is written as UTF-8 bytes, replacing any existing file.
    """
    upgrader_bytecode_function_to_index_map = (
        get_upgrader_bytecode_function_to_index_map(upgrader_dict)
    )
    version_map_src = construct_version_maps(upgrader_bytecode_function_to_index_map)
    all_upgrader_src_string = [
        _construct_one_upgrader_src(upgrader_name, bytecode)
        for upgrader_bytecode in upgrader_dict
        for upgrader_name, bytecode in upgrader_bytecode.items()
        # TODO: remove the skip after these two operators schemas are fixed
        if upgrader_name not in EXCLUE_UPGRADER_SET
    ]

    upgrader_file_content = UPGRADER_CPP_SRC.substitute(
        operator_version_map=version_map_src,
        upgrader_bytecode="".join(all_upgrader_src_string).lstrip("\n"),
    )
    # Use one joined path for both the log line and the actual write.
    out_path = os.path.join(cpp_path, UPGRADER_MOBILE_FILE_NAME)
    print("writing file to : ", out_path)
    with open(out_path, "wb") as out_file:
        out_file.write(upgrader_file_content.encode("utf-8"))
374
375
def sort_upgrader(upgrader_list: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return a new list of the upgrader dicts ordered by each dict's first key."""

    def first_upgrader_name(upgrader: dict[str, Any]) -> str:
        # The first key of each dict is the upgrader name used for ordering.
        return next(iter(upgrader))

    return sorted(upgrader_list, key=first_upgrader_name)
381
382
def main() -> None:
    """Regenerate the mobile upgrader C++ source.

    Builds the upgrader bytecode, orders it by upgrader name (printing each
    name), and writes the result into ``torch/csrc/jit/mobile`` under the
    directory three levels above this script file.
    """
    ordered_upgraders = sort_upgrader(generate_upgraders_bytecode())
    for entry in ordered_upgraders:
        first_name = next(iter(entry))
        print("after sort upgrader : ", first_name)

    repo_root = Path(__file__).resolve().parents[2]
    output_dir = repo_root.joinpath("torch", "csrc", "jit", "mobile")
    write_cpp(str(output_dir), ordered_upgraders)
392
393
# Run the generator (writing upgrader_mobile.cpp) only when executed as a script.
if __name__ == "__main__":
    main()
396