#!/usr/bin/env python3

from __future__ import annotations

import argparse
import json
import sys
from typing import Any

import yaml
from gen_op_registration_allowlist import (
    canonical_name,
    gen_transitive_closure,
    load_op_dep_graph,
)

from torchgen.selective_build.operator import (
    merge_operator_dicts,
    SelectiveBuildOperator,
)
from torchgen.selective_build.selector import merge_kernel_metadata


# Generate YAML file containing the operators used for a specific PyTorch model.
# ------------------------------------------------------------------------------
#
# This binary is responsible for generating the model_operators.yaml file for
# each model from a pt_operator_library() BUCK macro invocation.
#
# Output YAML file format:
# ------------------------
#
# <BEGIN FILE CONTENTS>
# include_all_non_op_selectives: False
# include_all_operators: False
# debug_info:
#   - model1@v100
#   - model2@v50
# operators:
#   aten::add:
#     is_root_operator: Yes
#     is_used_for_training: Yes
#     include_all_overloads: No
#     debug_info:
#       - model1@v100
#       - model2@v50
#   aten::add.int:
#     is_root_operator: No
#     is_used_for_training: No
#     include_all_overloads: Yes
# kernel_metadata:
#   add_kernel:
#     - Int8
#     - UInt32
#   sub_kernel:
#     - Int16
#     - Float
# <END FILE CONTENTS>
#
# There are a few main inputs to this application
# -----------------------------------------------
#
# 1. Inference Root Operators (--root-ops): Root operators (called directly
#    from TorchScript) used by inference use-cases.
#
# 2. Training Root Operators (--training-root-ops): Root operators used
#    by training use-cases. Currently, this list is the list of all operators
#    used by training, and not just the root operators. All training ops are
#    also considered for inference, so they are merged into the inference ops.
#
# 3. Operator Dependency Graph (--dep-graph-yaml-path): A path to the
#    operator dependency graph used to determine which operators depend on
#    which other operators for correct functioning. This is used for
#    generating the transitive closure of all the operators used by the
#    model based on the root operators when static selective build is used.
#    For tracing based selective build, we don't need to perform this
#    transitive closure.
#
# 4. Model Metadata (--model-name, --model-versions, --model-assets,
#    --model-backends): Self-descriptive. These are used to tell this
#    script which model operator lists to fetch from the Model
#    Build Metadata YAML files.
#
# 5. Model YAML files (--models-yaml-path): These YAML files contain
#    (for each model/version/asset/backend) the set of used root and traced
#    operators. This is used to extract the actual set of operators
#    needed to be included in the build.
#
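# Example invocation (paths and names below are hypothetical, shown only for
# illustration):
#
#   python gen_operators_yaml.py \
#     --rule-name //mobile/example:ops \
#     --model-name model1 \
#     --model-versions 100,101 \
#     --dep-graph-yaml-path /path/to/dep_graph.yaml \
#     --models-yaml-path /path/to/model1_v100.yaml /path/to/model1_v101.yaml \
#     --output-path model_operators.yaml
#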


def canonical_opnames(opnames: list[str]) -> list[str]:
    return [canonical_name(opname) for opname in opnames]
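# For illustration: canonical_name() (imported from gen_op_registration_allowlist)
# is expected to strip the overload suffix from an operator name; assuming that
# behavior, canonical_opnames(["aten::add.int"]) would yield ["aten::add"].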


def make_filter_from_options(
    model_name: str,
    model_versions: list[str],
    model_assets: list[str] | None,
    model_backends: list[str] | None,
):
    def is_model_included(model_info: dict[str, Any]) -> bool:
        model = model_info["model"]
        if model["name"] != model_name:
            return False
        if str(model["version"]) not in model_versions:
            return False
        if model_assets is not None and model["asset"] not in model_assets:
            return False
        # TODO: Handle backend later
        return True

    return is_model_included
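# Example usage with hypothetical data: the returned predicate keeps only the
# matching model configs.
#
#   is_included = make_filter_from_options("model1", ["100"], None, None)
#   is_included({"model": {"name": "model1", "version": 100, "asset": "a"}})  # True
#   is_included({"model": {"name": "model2", "version": 100, "asset": "a"}})  # False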


# Returns True if the specified pt_operator_library rule is new-style
# (i.e. it specifies both a model name and model versions), False otherwise.
def is_new_style_rule(model_name: str, model_versions: list[str] | None) -> bool:
    return model_name is not None and model_versions is not None


# Verifies that the specified model_name, and all specified versions and
# assets, appear in at least one model YAML. Raises a RuntimeError if
# verification fails; returns None on success.
def verify_all_specified_present(
    model_assets: list[str] | None,
    model_versions: list[str],
    selected_models_yaml: list[dict[str, Any]],
    rule_name: str,
    model_name: str,
    new_style_rule: bool,
) -> None:
    def find_missing_items(model_items, key, selected_models_yaml):
        missing_items = []
        if not new_style_rule or not model_items:
            return missing_items
        for item in model_items:
            found = False
            for model in selected_models_yaml:
                if str(model["model"][key]) == item:
                    found = True
            if not found:
                missing_items.append(item)
        return missing_items

    missing_assets = find_missing_items(model_assets, "asset", selected_models_yaml)
    missing_versions = find_missing_items(
        model_versions, "version", selected_models_yaml
    )

    if len(missing_versions) > 0 or len(missing_assets) > 0:  # at least one is missing
        name_warning = ""
        if len(selected_models_yaml) == 0:
            name_warning = (
                "WARNING: 0 YAMLs were found for the target rule. This could be because the "
                + "provided model name: {name} is incorrect. Please check that field as well as "
                + "the assets and versions."
            ).format(name=model_name)
        raise RuntimeError(
            (
                "Error: From the pt_operator_library rule for Rule: {name}, at least one entry for the "
                + "following fields was expected -- Model: {model_name} Expected Assets: {expected_assets}, Expected Versions: "
                + "{expected_versions}. {name_warning} In all_mobile_models.yaml either no assets were on one of the "
                + "specified versions, one of the specified assets was not present on any of the specified "
                + "versions, or both. Assets not found: {missing_assets}, Versions not found: {missing_versions} "
                + "For questions please ask in https://fb.workplace.com/groups/2148543255442743/"
            ).format(
                name=rule_name,
                model_name=model_name,
                expected_versions=model_versions,
                expected_assets=model_assets
                if model_assets
                else "<All model assets present on specified versions>",
                name_warning=name_warning,
                missing_versions=missing_versions
                if len(missing_versions) > 0
                else "<All specified versions had at least one asset>",
                missing_assets=missing_assets
                if len(missing_assets) > 0
                else "<All specified assets are present on at least 1 version>",
            )
        )


# Combines the selected models' configs into one dictionary, formats it as a
# JSON string, and places that string into output as a top-level debug_info entry.
def create_debug_info_from_selected_models(
    output: dict[str, object],
    selected_models: list[dict],
    new_style_rule: bool,
) -> None:
    model_dict = {
        "asset_info": {},  # maps asset name -> dict of asset metadata like hashes
        "is_new_style_rule": new_style_rule,
    }

    for model in selected_models:
        model_info = model["model"]
        asset = model_info["asset"]
        md5_hash = model_info["md5_hash"]

        asset_info = model_dict["asset_info"].setdefault(asset, {})

        asset_info.setdefault("md5_hash", []).append(md5_hash)

    # Will later be used in gen_oplist to generate the model/version/asset checking
    output["debug_info"] = [json.dumps(model_dict)]
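# For illustration, a single selected model with asset "asset_a" and a
# hypothetical hash "abc123" would produce:
#
#   output["debug_info"] == ['{"asset_info": {"asset_a": {"md5_hash": ["abc123"]}}, "is_new_style_rule": true}']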


def fill_output(output: dict[str, object], options: argparse.Namespace) -> None:
    """Populate the output dict with the information required to serialize
    the YAML file used for selective build.
    """
    dep_graph = load_op_dep_graph(options.dep_graph_yaml_path)

    model_versions = (
        options.model_versions.split(",") if options.model_versions is not None else []
    )
    model_assets = (
        options.model_assets.split(",") if options.model_assets is not None else None
    )

    all_models_yaml = []
    if options.models_yaml_path:
        for yaml_path in options.models_yaml_path:
            with open(yaml_path, "rb") as f:
                all_models_yaml.append(yaml.safe_load(f))

    model_filter_func = make_filter_from_options(
        options.model_name, model_versions, model_assets, options.model_backends
    )

    selected_models_yaml = list(filter(model_filter_func, all_models_yaml))

    verify_all_specified_present(
        model_assets=model_assets,
        model_versions=model_versions,
        selected_models_yaml=selected_models_yaml,
        rule_name=options.rule_name,
        model_name=options.model_name,
        new_style_rule=is_new_style_rule(options.model_name, options.model_versions),
    )

    create_debug_info_from_selected_models(
        output,
        selected_models_yaml,
        is_new_style_rule(options.model_name, options.model_versions),
    )

    # Initialize variables for the static build from the pt_operator_library rule.
    static_root_ops = set(
        filter(lambda x: len(x) > 0, (options.root_ops or "").split(","))
    )

    static_training_root_ops = set(
        filter(
            lambda x: len(x) > 0,
            (options.training_root_ops or "").split(","),
        )
    )
    if len(static_training_root_ops) > 0:
        static_root_ops = static_root_ops | static_training_root_ops

    root_ops_unexpand = set()
    traced_ops = set()
    training_root_ops_unexpand = set()
    traced_training_ops = set()
    all_kernel_metadata = []
    all_custom_classes = set()
    all_build_features = set()

    # Go through each YAML file and retrieve operator information.
    for model_info in selected_models_yaml:
        if "traced_operators" not in model_info:
            # If this YAML file doesn't specify any traced operators, then it is using
            # the static analysis selective build approach of finding transitively
            # used operators, and we should update root_ops with the set of root
            # operators, all of whose overloads must be included. In addition, these
            # root_ops will be further expanded using the transitive closure of
            # operator dependencies.
            static_root_ops = static_root_ops | set(model_info["root_operators"])
        else:
            # If this YAML file specifies traced operators, then it is using
            # the tracing based selective build approach of finding used
            # operators, and we should update root_ops_unexpand with the set of root
            # operators whose overloads don't need to be included. In addition, these
            # root_ops_unexpand will NOT be further expanded. If the train flag is
            # set, then the ops will be used for training, so we put them in a
            # separate set.
            if model_info["train"]:
                training_root_ops_unexpand = training_root_ops_unexpand | set(
                    model_info["root_operators"]
                )
                traced_training_ops = traced_training_ops | set(
                    model_info["traced_operators"]
                )
            else:
                root_ops_unexpand = root_ops_unexpand | set(
                    model_info["root_operators"]
                )
                traced_ops = traced_ops | set(model_info["traced_operators"])

        if "kernel_metadata" in model_info:
            all_kernel_metadata.append(model_info["kernel_metadata"])

        if "custom_classes" in model_info:
            all_custom_classes = all_custom_classes | set(model_info["custom_classes"])

        if "build_features" in model_info:
            all_build_features = all_build_features | set(model_info["build_features"])
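
    # For reference, a sketch of one (hypothetical) model YAML entry consumed by
    # the loop above; field names are taken from the accesses in this file:
    #
    #   model:
    #     name: model1
    #     version: 100
    #     asset: asset_a
    #     md5_hash: abc123
    #   train: false
    #   root_operators:
    #     - aten::add
    #   traced_operators:
    #     - aten::add
    #   kernel_metadata:
    #     add_kernel:
    #       - Int8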

    # The following section on transitive closure is relevant to static builds only.
    canonical_root_ops = canonical_opnames(static_root_ops)
    # If no canonical_root_ops exist, don't compute the transitive closure;
    # otherwise, we would include __BASE__ and __ROOT__ ops and mark them as
    # required for inference.
    if len(canonical_root_ops) > 0:
        closure_op_list = gen_transitive_closure(dep_graph, canonical_root_ops)
    else:
        closure_op_list = set()

    canonical_training_root_ops = canonical_opnames(static_training_root_ops)
    # If no canonical_training_root_ops exist, don't compute the transitive closure;
    # otherwise, we would include __BASE__ and __ROOT__ ops and mark them as
    # required for training.
    if len(canonical_training_root_ops) > 0:
        closure_training_op_list = gen_transitive_closure(
            dep_graph, canonical_training_root_ops, train=True
        )
    else:
        closure_training_op_list = set()

    # bucketed_ops holds sets of operators that correspond to specific semantic
    # buckets. For example:
    #
    # 1. Root Operators not used for training w/o full overload inclusion
    # 2. Root Operators not used for training w/ full overload inclusion
    # 3. Root Operators used for training w/o full overload inclusion
    # 4. Root Operators used for training w/ full overload inclusion
    # 5. Non-root Operators not used for training w/o full overload inclusion
    # etc...
    #
    # Basically, for each of the 3 boolean conditions there are 2 options
    # (True/False), yielding up to 8 distinct buckets.
    #
    bucketed_ops = []

    # START STATIC BUILD OPS
    static_root_ops_bucket = {}
    for op_name in static_root_ops:
        op = SelectiveBuildOperator.from_yaml_dict(
            op_name,
            {
                "is_root_operator": True,
                "is_used_for_training": False,
                "include_all_overloads": not options.not_include_all_overloads_static_root_ops,
                "debug_info": [options.model_name],
            },
        )
        static_root_ops_bucket[op_name] = op
    bucketed_ops.append(static_root_ops_bucket)

    closure_ops_bucket = {}
    for op_name in closure_op_list:
        op = SelectiveBuildOperator.from_yaml_dict(
            op_name,
            {
                "is_root_operator": False,
                "is_used_for_training": False,
                "include_all_overloads": not options.not_include_all_overloads_closure_ops,
                "debug_info": [options.model_name],
            },
        )
        closure_ops_bucket[op_name] = op
    bucketed_ops.append(closure_ops_bucket)

    static_training_root_ops_bucket = {}
    for op_name in static_training_root_ops:
        op = SelectiveBuildOperator.from_yaml_dict(
            op_name,
            {
                "is_root_operator": True,
                "is_used_for_training": True,
                "include_all_overloads": True,
                "debug_info": [options.model_name],
            },
        )
        static_training_root_ops_bucket[op_name] = op
    bucketed_ops.append(static_training_root_ops_bucket)

    closure_training_ops_bucket = {}
    for op_name in closure_training_op_list:
        op = SelectiveBuildOperator.from_yaml_dict(
            op_name,
            {
                "is_root_operator": False,
                "is_used_for_training": True,
                "include_all_overloads": True,
                "debug_info": [options.model_name],
            },
        )
        closure_training_ops_bucket[op_name] = op
    bucketed_ops.append(closure_training_ops_bucket)
    # END STATIC BUILD OPS

    # START TRACING BASED BUILD OPS
    root_ops_unexpand_bucket = {}
    for op_name in root_ops_unexpand:
        op = SelectiveBuildOperator.from_yaml_dict(
            op_name,
            {
                "is_root_operator": True,
                "is_used_for_training": False,
                "include_all_overloads": False,
                "debug_info": [options.model_name],
            },
        )
        root_ops_unexpand_bucket[op_name] = op
    bucketed_ops.append(root_ops_unexpand_bucket)

    traced_ops_bucket = {}
    for op_name in traced_ops:
        op = SelectiveBuildOperator.from_yaml_dict(
            op_name,
            {
                "is_root_operator": False,
                "is_used_for_training": False,
                "include_all_overloads": False,
                "debug_info": [options.model_name],
            },
        )
        traced_ops_bucket[op_name] = op
    bucketed_ops.append(traced_ops_bucket)

    training_root_ops_unexpand_bucket = {}
    for op_name in training_root_ops_unexpand:
        op = SelectiveBuildOperator.from_yaml_dict(
            op_name,
            {
                "is_root_operator": True,
                "is_used_for_training": True,
                "include_all_overloads": False,
                "debug_info": [options.model_name],
            },
        )
        training_root_ops_unexpand_bucket[op_name] = op
    bucketed_ops.append(training_root_ops_unexpand_bucket)

    traced_training_ops_bucket = {}
    for op_name in traced_training_ops:
        op = SelectiveBuildOperator.from_yaml_dict(
            op_name,
            {
                "is_root_operator": False,
                "is_used_for_training": True,
                "include_all_overloads": False,
                "debug_info": [options.model_name],
            },
        )
        traced_training_ops_bucket[op_name] = op
    bucketed_ops.append(traced_training_ops_bucket)
    # END TRACING BASED BUILD OPS

    # Merge the dictionaries together to remove op duplication.
    operators: dict[str, SelectiveBuildOperator] = {}
    for ops_dict in bucketed_ops:
        operators = merge_operator_dicts(operators, ops_dict)

    # If any operator specifies that all overloads need to be included, set
    # include_all_non_op_selectives to True, since that indicates this operator
    # list came from something other than a traced operator list.
    include_all_non_op_selectives = any(
        op_info.include_all_overloads for op_info in operators.values()
    )

    operators_as_dict = {k: v.to_dict() for k, v in operators.items()}

    output["operators"] = operators_as_dict

    output["custom_classes"] = all_custom_classes

    output["build_features"] = all_build_features

    output["include_all_non_op_selectives"] = include_all_non_op_selectives
    if len(all_kernel_metadata) > 0:
        kernel_metadata = {}
        for kt in all_kernel_metadata:
            kernel_metadata = merge_kernel_metadata(kernel_metadata, kt)
        output["kernel_metadata"] = kernel_metadata


def add_arguments_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    parser.add_argument(
        "--root-ops",
        "--root_ops",
        help="A comma separated list of root operators used by the model",
        required=False,
    )
    parser.add_argument(
        "--training-root-ops",
        "--training_root_ops",
        help="A comma separated list of root operators used for training",
        required=False,
    )
    parser.add_argument(
        "--output-path",
        "--output_path",
        help="The location of the output YAML file.",
        required=True,
    )
    parser.add_argument(
        "--dep-graph-yaml-path",
        "--dep_graph_yaml_path",
        type=str,
        help="A path to the Operator Dependency Graph YAML file.",
        required=True,
    )
    parser.add_argument(
        "--model-name",
        "--model_name",
        type=str,
        help="The name of the model that uses the specified root operators.",
        required=True,
    )
    parser.add_argument(
        "--model-versions",
        "--model_versions",
        type=str,
        help="A comma separated list of model versions.",
        required=False,
    )
    parser.add_argument(
        "--model-assets",
        "--model_assets",
        type=str,
        help="A comma separated list of model asset names (if absent, defaults to all assets for this model).",
        required=False,
    )
    parser.add_argument(
        "--model-backends",
        "--model_backends",
        type=str,
        default="CPU",
        help="A comma separated list of model backends.",
        required=False,
    )
    parser.add_argument(
        "--models-yaml-path",
        "--models_yaml_path",
        type=str,
        help="The paths to the mobile model config YAML files.",
        required=False,
        nargs="+",
    )
    parser.add_argument(
        "--include-all-operators",
        "--include_all_operators",
        action="store_true",
        default=False,
        help="Set this flag to request inclusion of all operators (i.e. the build is not selective).",
        required=False,
    )
    parser.add_argument(
        "--rule-name",
        "--rule_name",
        type=str,
        help="The name of the pt_operator_library rule resulting in this generation.",
        required=True,
    )
    parser.add_argument(
        "--not-include-all-overloads-static-root-ops",
        "--not_include_all_overloads_static_root_ops",
        action="store_true",
        default=False,
        help="Set this flag to not include all overloaded operators for the static root ops bucket in fill_output().",
        required=False,
    )
    parser.add_argument(
        "--not-include-all-overloads-closure-ops",
        "--not_include_all_overloads_closure_ops",
        action="store_true",
        default=False,
        help="Set this flag to not include all overloaded operators for the closure ops bucket in fill_output().",
        required=False,
    )
    return parser
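# Note: every flag above is registered under both its hyphenated and its
# underscored spelling. argparse derives the Namespace attribute from the
# first long option string (dashes become underscores), so, for example,
# both "--root-ops" and "--root_ops" populate options.root_ops.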


def parse_options(parser: argparse.ArgumentParser) -> argparse.Namespace:
    return parser.parse_args()


def get_parser_options(parser: argparse.ArgumentParser) -> argparse.Namespace:
    parser = add_arguments_parser(parser)
    return parse_options(parser)


def main(argv: list[str]) -> None:
    parser = argparse.ArgumentParser(description="Generate used operators YAML")
    options = get_parser_options(parser)

    model_dict = {
        "model_name": options.model_name,
        "asset_info": {},
        "is_new_style_rule": False,
    }
    output = {
        "debug_info": [json.dumps(model_dict)],
    }

    if options.include_all_operators:
        output["include_all_operators"] = True
        output["operators"] = {}
        output["kernel_metadata"] = {}
    else:
        fill_output(output, options)

    with open(options.output_path, "wb") as out_file:
        out_file.write(
            yaml.safe_dump(
                output,
                default_flow_style=False,
            ).encode("utf-8")
        )


if __name__ == "__main__":
    sys.exit(main(sys.argv))
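
# For illustration, when --include-all-operators is passed, the emitted YAML
# looks roughly like this (keys are sorted by yaml.safe_dump; the debug_info
# JSON payload below is hypothetical and abbreviated):
#
#   debug_info:
#   - '{"model_name": "model1", "asset_info": {}, "is_new_style_rule": false}'
#   include_all_operators: true
#   kernel_metadata: {}
#   operators: {}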