xref: /aosp_15_r20/external/pytorch/torchgen/executorch/parse.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3from collections import defaultdict, namedtuple
4from typing import Any
5
6import yaml
7
8from torchgen.executorch.model import ETKernelIndex, ETKernelKey
9from torchgen.gen import LineLoader, parse_native_yaml
10from torchgen.model import (
11    BackendMetadata,
12    DispatchKey,
13    FunctionSchema,
14    NativeFunction,
15    OperatorName,
16)
17from torchgen.utils import NamespaceHelper
18
19
# Result of parsing native_functions.yaml: the NativeFunctions together with
# the ExecuTorch (ET) backend kernel indices.
ETParsedYaml = namedtuple("ETParsedYaml", "native_functions et_kernel_indices")

# Per-operator fields in native_functions.yaml that decide which kernels are used.
ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"]
25
26
def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]:
    """Given a loaded yaml representing kernel assignment information, extract the
    mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance)

    The input is left unmodified. ``__line__`` bookkeeping entries are filtered
    out of fresh copies rather than popped in place, so the same loaded yaml
    can safely be parsed more than once.

    Args:
        ei: Dict keys {kernels, type_alias, dim_order_alias}
            See ETKernelKey for description of arguments

    Returns:
        Mapping from each generated kernel key to the BackendMetadata of the
        kernel that serves it. Empty if ``ei`` has no (or a null) ``kernels`` entry.
    """
    e = ei.copy()
    if (kernels := e.pop("kernels", None)) is None:
        return {}

    type_alias: dict[str, list[str]] = e.pop("type_alias", {})  # type: ignore[assignment]
    # `e` is only a shallow copy, so popping from the nested dict would mutate the
    # caller's yaml. Build a filtered copy instead.
    dim_order_alias: dict[str, list[str]] = {
        k: v
        for k, v in e.pop("dim_order_alias", {}).items()  # type: ignore[attr-defined]
        if k != "__line__"
    }

    kernel_mapping: dict[ETKernelKey, BackendMetadata] = {}

    for entry in kernels:  # type: ignore[attr-defined]
        arg_meta = entry.get("arg_meta")
        if arg_meta is not None:
            # Filter instead of popping in place. `pop("__line__")` mutated the
            # caller's yaml and raised KeyError whenever "__line__" was absent,
            # e.g. on a second parse of the same data.
            arg_meta = {k: v for k, v in arg_meta.items() if k != "__line__"}

        kernel_name = entry.get("kernel_name")
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            kernel_name, max_level=3
        )
        kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
        backend_metadata = BackendMetadata(
            kernel=namespace_helper.entity_name,
            structured=False,
            cpp_namespace=(kernel_namespace + "::native"),
        )

        # An entry without arg_meta is the op's default (fallback) kernel.
        kernel_keys = (
            [ETKernelKey((), default=True)]
            if arg_meta is None
            else ETKernelKey.gen_from_yaml(arg_meta, type_alias, dim_order_alias)  # type: ignore[arg-type]
        )

        for kernel_key in kernel_keys:
            assert kernel_key not in kernel_mapping, (
                "Duplicate kernel key: " + str(kernel_key) + " " + str(e)
            )
            kernel_mapping[kernel_key] = backend_metadata

    return kernel_mapping
74
75
def parse_et_yaml_struct(es: object) -> ETKernelIndex:
    """Given a loaded yaml representing a list of operators, for each op extract the mapping
    of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance
    that should be used by the kernel key).

    Operators whose entry yields no kernel keys are omitted from the index.
    Every operator may appear at most once in ``es``.
    """
    indices: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = {}
    # Record every op name, not only the ones that produced kernels. Otherwise a
    # duplicate whose first occurrence had no `kernels` entry slips through.
    seen_opnames: set[OperatorName] = set()
    for ei in es:  # type: ignore[attr-defined]
        e = ei.copy()

        funcs = e.pop("func")
        assert isinstance(funcs, str), f"not a str: {funcs}"
        namespace_helper = NamespaceHelper.from_namespaced_entity(
            namespaced_entity=funcs, max_level=1
        )
        opname = FunctionSchema.parse(namespace_helper.entity_name).name

        assert opname not in seen_opnames, f"Duplicate func found in yaml: {opname} already"
        seen_opnames.add(opname)

        if len(index := parse_from_yaml(e)) != 0:
            indices[opname] = index

    return ETKernelIndex(indices)
98
99
def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]:
    """Collect the ET kernel-selection fields of each operator entry.

    For every entry in the loaded yaml list ``es``, any field from ``ET_FIELDS``
    whose value is not None is recorded under that entry's operator name.
    Operators that carry none of these fields get no key in the result.
    The returned object is a ``defaultdict(dict)``.
    """
    kernel_fields: dict[OperatorName, dict[str, Any]] = defaultdict(dict)
    for entry in es:  # type: ignore[attr-defined]
        func_decl = entry.get("func")
        assert isinstance(func_decl, str), f"not a str: {func_decl}"
        schema_text = NamespaceHelper.from_namespaced_entity(
            namespaced_entity=func_decl, max_level=1
        ).entity_name
        op_name = FunctionSchema.parse(schema_text).name

        for key in ET_FIELDS:
            field_value = entry.get(key)
            if field_value is None:
                continue
            # Only touch the defaultdict when a value exists, so that empty
            # operators do not get an empty-dict entry.
            kernel_fields[op_name][key] = field_value

    return kernel_fields
118
119
def parse_et_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: set[DispatchKey] | None = None,
    skip_native_fns_gen: bool = False,
) -> tuple[list[NativeFunction], dict[OperatorName, dict[str, Any]]]:
    """Load the yaml at ``path`` and split it into NativeFunctions and ET fields.

    The ET-specific per-operator fields are captured first. They are then removed
    from the loaded entries, and the remaining entries are handed to
    ``parse_native_yaml``.

    Returns:
        A pair: the parsed NativeFunctions, and an operator-indexed dict of the
        fields to persist from native_functions.yaml to functions.yaml.
    """
    # NOTE(review): yaml.load is only safe for untrusted files if LineLoader is
    # a safe loader subclass. Confirm in torchgen.gen.
    with open(path) as yaml_file:
        loaded_entries = yaml.load(yaml_file, Loader=LineLoader)

    # Grab the ET fields before the in-place strip below removes them.
    kernel_fields = extract_kernel_fields(loaded_entries)
    # Drop ET-only fields from the entries for BC compatibility with parse_native_yaml.
    strip_et_fields(loaded_entries)

    parsed = parse_native_yaml(
        path,
        tags_yaml_path,
        ignore_keys,
        skip_native_fns_gen=skip_native_fns_gen,
        loaded_yaml=loaded_entries,
    )
    return parsed.native_functions, kernel_fields
145
146
def strip_et_fields(es: object) -> None:
    """Delete every ``ET_FIELDS`` key, in place, from each entry of ``es``.

    ``es`` is a loaded yaml list of operator entries. Entries that lack a given
    field are left alone. This keeps the entries BC-compatible for consumers
    that do not know about the ET-specific fields.
    """
    for entry in es:  # type: ignore[attr-defined]
        for key in ET_FIELDS:
            if key in entry:
                del entry[key]
154