1#!/usr/bin/env python3 2 3from __future__ import annotations 4 5import argparse 6import json 7import os 8import sys 9from functools import reduce 10from typing import Any 11 12import yaml 13from tools.lite_interpreter.gen_selected_mobile_ops_header import ( 14 write_selected_mobile_ops, 15) 16 17from torchgen.selective_build.selector import ( 18 combine_selective_builders, 19 SelectiveBuilder, 20) 21 22 23def extract_all_operators(selective_builder: SelectiveBuilder) -> set[str]: 24 return set(selective_builder.operators.keys()) 25 26 27def extract_training_operators(selective_builder: SelectiveBuilder) -> set[str]: 28 ops = [] 29 for op_name, op in selective_builder.operators.items(): 30 if op.is_used_for_training: 31 ops.append(op_name) 32 return set(ops) 33 34 35def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> None: 36 ops = [] 37 for op_name, op in selective_builder.operators.items(): 38 if op.include_all_overloads: 39 ops.append(op_name) 40 if ops: 41 raise Exception( # noqa: TRY002 42 ( 43 "Operators that include all overloads are " 44 + "not allowed since --allow-include-all-overloads " 45 + "was specified: {}" 46 ).format(", ".join(ops)) 47 ) 48 49 50def gen_supported_mobile_models(model_dicts: list[Any], output_dir: str) -> None: 51 supported_mobile_models_source = """/* 52 * Generated by gen_oplist.py 53 */ 54#include "fb/supported_mobile_models/SupportedMobileModels.h" 55 56 57struct SupportedMobileModelCheckerRegistry {{ 58 SupportedMobileModelCheckerRegistry() {{ 59 auto& ref = facebook::pytorch::supported_model::SupportedMobileModelChecker::singleton(); 60 ref.set_supported_md5_hashes(std::unordered_set<std::string>{{ 61 {supported_hashes_template} 62 }}); 63 }} 64}}; 65 66// This is a global object, initializing which causes the registration to happen. 67SupportedMobileModelCheckerRegistry register_model_versions; 68 69 70""" 71 72 # Generate SupportedMobileModelsRegistration.cpp 73 md5_hashes = set() 74 for model_dict in model_dicts: 75 if "debug_info" in model_dict: 76 debug_info = json.loads(model_dict["debug_info"][0]) 77 if debug_info["is_new_style_rule"]: 78 for asset_info in debug_info["asset_info"].values(): 79 md5_hashes.update(asset_info["md5_hash"]) 80 81 supported_hashes = "" 82 for md5 in md5_hashes: 83 supported_hashes += f'"{md5}",\n' 84 with open( 85 os.path.join(output_dir, "SupportedMobileModelsRegistration.cpp"), "wb" 86 ) as out_file: 87 source = supported_mobile_models_source.format( 88 supported_hashes_template=supported_hashes 89 ) 90 out_file.write(source.encode("utf-8")) 91 92 93def main(argv: list[Any]) -> None: 94 """This binary generates 3 files: 95 96 1. selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function 97 dtypes captured by tracing 98 2. selected_operators.yaml: Selected root and non-root operators (either via tracing or static analysis) 99 """ 100 parser = argparse.ArgumentParser(description="Generate operator lists") 101 parser.add_argument( 102 "--output-dir", 103 "--output_dir", 104 help=( 105 "The directory to store the output yaml files (selected_mobile_ops.h, " 106 + "selected_kernel_dtypes.h, selected_operators.yaml)" 107 ), 108 required=True, 109 ) 110 parser.add_argument( 111 "--model-file-list-path", 112 "--model_file_list_path", 113 help=( 114 "Path to a file that contains the locations of individual " 115 + "model YAML files that contain the set of used operators. This " 116 + "file path must have a leading @-symbol, which will be stripped " 117 + "out before processing." 118 ), 119 required=True, 120 ) 121 parser.add_argument( 122 "--allow-include-all-overloads", 123 "--allow_include_all_overloads", 124 help=( 125 "Flag to allow operators that include all overloads. " 126 + "If not set, operators registered without using the traced style will" 127 + "break the build." 128 ), 129 action="store_true", 130 default=False, 131 required=False, 132 ) 133 options = parser.parse_args(argv) 134 135 if os.path.isfile(options.model_file_list_path): 136 print("Processing model file: ", options.model_file_list_path) 137 model_dicts = [] 138 model_dict = yaml.safe_load(open(options.model_file_list_path)) 139 model_dicts.append(model_dict) 140 else: 141 print("Processing model directory: ", options.model_file_list_path) 142 assert options.model_file_list_path[0] == "@" 143 model_file_list_path = options.model_file_list_path[1:] 144 145 model_dicts = [] 146 with open(model_file_list_path) as model_list_file: 147 model_file_names = model_list_file.read().split() 148 for model_file_name in model_file_names: 149 with open(model_file_name, "rb") as model_file: 150 model_dict = yaml.safe_load(model_file) 151 model_dicts.append(model_dict) 152 153 selective_builders = [SelectiveBuilder.from_yaml_dict(m) for m in model_dicts] 154 155 # While we have the model_dicts generate the supported mobile models api 156 gen_supported_mobile_models(model_dicts, options.output_dir) 157 158 # We may have 0 selective builders since there may not be any viable 159 # pt_operator_library rule marked as a dep for the pt_operator_registry rule. 160 # This is potentially an error, and we should probably raise an assertion 161 # failure here. However, this needs to be investigated further. 162 selective_builder = SelectiveBuilder.from_yaml_dict({}) 163 if len(selective_builders) > 0: 164 selective_builder = reduce( 165 combine_selective_builders, 166 selective_builders, 167 ) 168 169 if not options.allow_include_all_overloads: 170 throw_if_any_op_includes_overloads(selective_builder) 171 with open( 172 os.path.join(options.output_dir, "selected_operators.yaml"), "wb" 173 ) as out_file: 174 out_file.write( 175 yaml.safe_dump( 176 selective_builder.to_dict(), default_flow_style=False 177 ).encode("utf-8"), 178 ) 179 180 write_selected_mobile_ops( 181 os.path.join(options.output_dir, "selected_mobile_ops.h"), 182 selective_builder, 183 ) 184 185 186if __name__ == "__main__": 187 main(sys.argv[1:]) 188