xref: /aosp_15_r20/external/executorch/exir/dialects/edge/_ops.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from typing import Any, Dict, List, Optional, Set, Union
8
9import pkg_resources
10
11import torch
12
13from executorch.exir.dialects.edge.dtype.supported import regular_tensor_str_to_dtypes
14from executorch.exir.dialects.edge.op.api import to_variant
15from executorch.exir.dialects.edge.spec.utils import get_tensor_variable_names
16
17# pyre-ignore
18from ruamel.yaml import YAML
19from torchgen.model import SchemaKind
20
21
22class AllowedDtypeSet:
23    """All legal dtypes for current type alias.
24
25    This class is a wrapper of Set[torch.dtype]. Normally it is a set of all legal types listed in
26    edge/edge.yaml file for each type alias. If one of the argument under the type alias receiving
27    its actual type, AllowedDtypeSet will be degenerated to the set of only the actual type.
28
29    TODO(gasoonjia): Prevent users from misusing.
30
31    Public Attributes:
32        types: a set of all allowed dtypes listed in edge/edge.yaml.
33
34    Private Attributes:
35        _reduced_type: the actual type this type alias currently represents. 0 means unrestricted,
36                       each type in self.types is legal.
37
38    """
39
40    def __init__(self, types: Set[torch.dtype]):
41        self.types: Set[torch.dtype] = types
42        self._reduced_type: Union[torch.dtype, int] = 0
43
44    def reduce_to(self, t: torch.dtype) -> bool:
45        """Reduce the legal dtype to given t.
46        t must be a legal type for this type alias.
47
48        return True if reduction succeed; otherwise False.
49        """
50        if self.__contains__(t):
51            self._reduced_type = t
52            return True
53        else:
54            return False
55
56    def clear(self):
57        """Derestrict AllowedDtypeSet to all allowed dtypes in yaml."""
58        self._reduced_type = 0
59
60    def __contains__(self, key: torch.dtype):
61        """Check if key is a legal type of this type alias"""
62        if self._reduced_type:
63            return key == self._reduced_type
64        return key in self.types
65
66
67class FunctionDtypeConstraint:
68    """Dtype constraint for each EdgeDialect ops.
69
70    Arguments:
71        essential_tensor_io_names: All names of essential tensor inputs and outputs.
72        optional_tensor_io_names: All names of optional tensor inputs.
73        type_alias: Dict of type alias name to corresponding list of dtypes.
74        type_constraint: List of dict containing dtype constraint represented in type alias for each arg name.
75    """
76
77    def __init__(
78        self,
79        essential_tensor_io_names: List[str],
80        optional_tensor_io_names: List[str],
81        type_alias: Dict[str, List[torch.dtype]],
82        type_constraint: List[Dict[str, str]],
83    ):
84        self.essential_tensor_io_names: List[str] = essential_tensor_io_names
85        self.optional_tensor_io_names: List[str] = optional_tensor_io_names
86        self.type_alias: Dict[str, AllowedDtypeSet] = {
87            alias: AllowedDtypeSet(set(types)) for alias, types in type_alias.items()
88        }
89        self.type_constraint: List[Dict[str, str]] = type_constraint
90        # type_constraint's non return entries should include all tensor-like arguments.
91        for t_constraint in self.type_constraint:
92            type_constraint_names = set(t_constraint)
93            all_tensor_arg_names = set(
94                self.essential_tensor_io_names + self.optional_tensor_io_names
95            )
96            if not all_tensor_arg_names.issubset(type_constraint_names):
97                raise RuntimeError(
98                    "Input entries of type_constraint must contain all tensor-like arguments, "
99                    + f"but get {type_constraint_names} and {all_tensor_arg_names}"
100                )
101
102    def validate(self, types: Dict[str, Optional[torch.dtype]]) -> bool:
103        """Check if the given input type combination a legal one of current function.
104
105        Args:
106            types: A dict of arg name to its current dtype.
107
108        Returns:
109            True iff a. types are legal for current operator b. all arg name can be found
110            in current operator and c. input contains all essential tensor inputs; False otherwise.
111
112            The essential tensor inputs here mean non-optional inputs in tensor and tensor list.
113        """
114
115        # Every arg name in `types` should be one of the tensor ios in current function.
116        for arg_name in types:
117            if not self.__contains__(arg_name):
118                return False
119
120        # Any essential tensor input should exist in current `type` input.
121        for io_name in self.essential_tensor_io_names:
122            if io_name not in types:
123                return False
124
125        valid_type = False
126        for constraint in self.type_constraint:
127            if valid_type:
128                break
129
130            valid_type = True
131            # Narrow down the type_alias based on contraint and actual input
132            for arg_name, arg_type in types.items():
133                if arg_type is None:
134                    # None means the user didn't set dtype for this argment
135                    # (i.e. empty tensorlist), skipping the validation.
136                    continue
137                elif arg_type in self.type_alias[constraint[arg_name]]:
138                    self.type_alias[constraint[arg_name]].reduce_to(arg_type)
139                else:
140                    valid_type = False
141                    break
142
143            for alias in self.type_alias.values():
144                alias.clear()
145
146        return valid_type
147
148    def __contains__(self, key: str):
149        return key in self.type_constraint[0]
150
151    def __getitem__(self, arg_name: str) -> Set[torch.dtype]:
152        """Return all legal types for given arg name.
153        Return all its legal type in a set, or an empty set if can not find
154        the arg_name in current function."""
155
156        if arg_name not in self.type_constraint[0]:
157            return set()
158
159        valid_dtype: Set[torch.dtype] = set()
160        for constraint in self.type_constraint:
161            valid_dtype = self.type_alias[constraint[arg_name]].types | valid_dtype
162
163        return valid_dtype
164
165
166def _load_edge_dialect_info() -> Dict[str, Dict[str, Any]]:
167    # pyre-ignore
168    yaml = YAML(typ="safe")
169    edge_dialect_yaml_info = yaml.load(
170        pkg_resources.resource_string(__name__, "edge.yaml").decode("utf8")
171    )
172    if edge_dialect_yaml_info:
173        return {
174            edge_op_yaml_info["inherits"]: edge_op_yaml_info
175            for edge_op_yaml_info in edge_dialect_yaml_info
176        }
177    else:
178        return {}
179
180
181_edge_dialect_info: Dict[str, Dict[str, Any]] = _load_edge_dialect_info()
182
183
184class EdgeDialectArgument:
185    """Argument class for EdgeDialect ops.
186    Wraps around torch._C.Argument with dtype constraints.
187    Redirects all `getattr` calls to torch._C.Argument.
188    """
189
190    def __init__(self, argument: torch._C.Argument, allowed_types: Set[torch.dtype]):
191        self.argument = argument
192        self.allowed_types = allowed_types
193
194    def __getattr__(self, name):
195        if name == "allowed_types":  # arg.allowed_types
196            return self.allowed_types
197        return getattr(self.argument, name)
198
199
200class EdgeDialectFunctionSchema:
201    """FunctionSchema class for EdgeDialect ops.
202    Wraps around torch._C.FunctionSchema with Tensor dtype constraints.
203    In constructor, walk through all Tensor arguments and returns in the original schema
204    for ATen operator, replace the argument with EdgeDialectArgument.
205    """
206
207    def __init__(
208        self,
209        schema: torch._C.FunctionSchema,
210    ):
211        self.schema = schema
212        edge_op_full_name = schema.name + (
213            ".{}".format(schema.overload_name) if schema.overload_name else ""
214        )
215
216        (
217            essential_tensor_io_names,
218            optional_tensor_io_names,
219            all_tensor_io_names,
220        ) = get_tensor_variable_names(self.schema)
221
222        if edge_op_full_name in _edge_dialect_info:
223            # Directly use the information from edge.yaml if available.
224            _edge_op_info = _edge_dialect_info[edge_op_full_name]
225            type_alias = {
226                alias: [regular_tensor_str_to_dtypes[t] for t in types]
227                for alias, types in _edge_op_info["type_alias"].items()
228            }
229            type_constraint = _edge_op_info["type_constraint"]
230        else:
231            # Not get the info from edge.yaml
232            # Create a dtype constraint for this operator that allows any dtype
233            # combinations as long as any dtype is legal in ExecuTorch.
234            type_alias = {
235                f"T{idx}": list(regular_tensor_str_to_dtypes.values())
236                for idx in range(len(all_tensor_io_names))
237            }
238            type_constraint = [
239                {io_name: f"T{idx}" for idx, io_name in enumerate(all_tensor_io_names)}
240            ]
241
242        self.dtype_constraint = FunctionDtypeConstraint(
243            essential_tensor_io_names=essential_tensor_io_names,
244            optional_tensor_io_names=optional_tensor_io_names,
245            type_alias=type_alias,
246            type_constraint=type_constraint,
247        )
248
249        arg_list: List[Union[torch._C.Argument, EdgeDialectArgument]] = []
250        for argument in self.schema.arguments:
251            if argument.name in self.dtype_constraint:
252                arg_list.append(
253                    EdgeDialectArgument(
254                        argument,
255                        self.dtype_constraint[argument.name],
256                    )
257                )
258            else:
259                arg_list.append(argument)
260        self.arguments = arg_list
261        return_names = sorted(
262            n
263            for n in self.dtype_constraint.type_constraint[0].keys()
264            if n.startswith("__ret")
265        )
266        ret_list: List[Union[torch._C.Argument, EdgeDialectArgument]] = []
267        ret_iter = iter(return_names)
268        for ret in self.schema.returns:
269            if isinstance(ret.type, torch.TensorType):
270                name = next(ret_iter, None)
271                if name:
272                    ret_list.append(
273                        EdgeDialectArgument(ret, self.dtype_constraint[name])
274                    )
275                    continue
276            ret_list.append(ret)
277        self.returns = ret_list
278
279    def __getattr__(self, name):
280        if name == "arguments":
281            return self.arguments
282        if name == "returns":
283            return self.returns
284        if name == "dtype_constraint":
285            return self.dtype_constraint
286        return getattr(self.schema, name)
287
288    def __str__(self):
289        return str(self.schema)
290
291
292class EdgeOpOverload:
293    """OpOverload for edge ops.
294    Contains API to find the out variant of this operator overload.
295    """
296
297    def __init__(
298        self,
299        op: torch._ops.OpOverload,
300        schema: EdgeDialectFunctionSchema,
301    ):
302        self._schema = schema
303        self._op = op
304        self.__name__ = f"{self.namespace}.{self._op.__name__}"
305
306    def to_out_variant(self) -> torch._ops.OpOverload:
307        """Find out the out-variant of this operator and return it.
308        TODO (larryliu): Implement execution dialect class and let this function return that.
309        This implementation assumes the out variant is available in torch.ops.*.
310
311        Raises:
312            RuntimeError: if we could't find the out variant, raise an exception.
313            TODO (larryliu): Catch this in BackendDialect and generate an operator definition
314            for missing out variant.
315        Returns:
316            torch._ops.OpOverload: The out-variant operator of self.
317        """
318
319        # return if already found
320        if "_out_variant" in self.__dict__ and self._out_variant:
321            return self._out_variant
322        out_variant = to_variant(self._op, SchemaKind.out)
323        self._out_variant = out_variant
324        return out_variant
325
326    def __getattr__(self, name):
327        if name == "_schema":
328            return self._schema
329        else:
330            return getattr(self._op, name)
331
332    def __call__(self, *args, **kwargs):
333        return self._op(*args, **kwargs)
334
335    def __repr__(self):
336        return "<EdgeOpOverload: {}>: schema = {}".format(
337            self.__name__, self._schema.schema
338        )
339
340    __str__ = __repr__
341
342
343class EdgeOpOverloadPacket:
344    """OpOverloadPacket for edge ops.
345    Wraps torch._ops.OpOverloadPacket and overrides __getattr__ to return OpOverload
346    for Edge ops. The main difference between an Edge op and its corresponding ATen op
347    is that Edge op contains a different schema (see EdgeDialectFunctionSchema).
348    """
349
350    def __init__(
351        self,
352        qualified_op_name: str,  # e.g., edge::aten::add
353        op_name: str,
354        parent_overload_packet: torch._ops.OpOverloadPacket,
355    ):
356        self._parent_overload_packet = parent_overload_packet
357        self._parent_qualified_op_name = parent_overload_packet._qualified_op_name
358        self._qualified_op_name = qualified_op_name
359        self.__name__ = self._qualified_op_name.replace("::", ".")
360        self._op = parent_overload_packet._op
361        self._overload_names = parent_overload_packet._overload_names
362        self._dir = []
363
364    def __repr__(self):
365        return "<EdgeOpOverloadPacket(op='{}', parent_op='{}')>".format(
366            self._qualified_op_name.replace("::", "."),
367            self._parent_qualified_op_name.replace("::", "."),
368        )
369
370    def __hash__(self):
371        return hash(self._op)
372
373    def __str__(self):
374        return "{}".format(self._qualified_op_name.replace("::", "."))
375
376    @property
377    def op(self):
378        return self._op
379
380    def __getattr__(self, key):
381        # It is not a valid op_name when __file__ is passed in
382        if key == "__file__":
383            return "exir.ops.edge"
384        try:
385            parent_overload = getattr(self._parent_overload_packet, key)
386        except AttributeError:
387            raise AttributeError(
388                "The underlying op of '{}' has no overload name '{}'".format(
389                    str(self), key
390                )
391            ) from None
392
393        edge_schema = EdgeDialectFunctionSchema(
394            parent_overload._schema,
395        )  # create a new schema based on parent op schema
396        overload = EdgeOpOverload(
397            parent_overload,
398            edge_schema,
399        )
400        # cache the overload object
401        setattr(self, key, overload)
402        self._dir.append(key)
403        return overload
404
405    def __call__(self, *args, **kwargs):
406        return self._parent_overload_packet(*args, **kwargs or {})
407