xref: /aosp_15_r20/external/executorch/exir/dialects/edge/spec/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from collections import defaultdict
8from functools import partial
9from typing import Any, Dict, List, Optional, Set, Tuple
10
11import torch
12from executorch.exir.dialects.edge.arg.model import BaseArg, BaseKwarg, Return
13from executorch.exir.dialects.edge.op.sample_input import SAMPLE_INPUT
14
15from torch._ops import OpOverloadPacket
16
17
def is_tensor_arg(arg: torch._C.Argument) -> bool:
    """Return True when the argument's type is any tensor flavour.

    The decision is made on the string form of ``arg.type``; the accepted
    spellings are ``Tensor``, ``Optional[Tensor]`` and ``List[Tensor]``.
    """
    tensor_type_names = ("Tensor", "Optional[Tensor]", "List[Tensor]")
    return str(arg.type) in tensor_type_names
22
23
def is_tensor_val(arg: Any) -> bool:
    """Return True when a runtime value is tensor-like.

    A value counts as tensor-like when it is a ``torch.Tensor``, or a
    non-empty list/tuple whose elements are all ``torch.Tensor``.
    Keep the accepted values in sync with the types accepted by is_tensor_arg.
    """
    if isinstance(arg, (list, tuple)):
        # An empty sequence is deliberately not treated as a tensor list.
        return len(arg) > 0 and all(isinstance(item, torch.Tensor) for item in arg)
    return isinstance(arg, torch.Tensor)
32
33
def is_essential_tensor_arg(arg: torch._C.Argument) -> bool:
    """Return True when the argument is a non-optional tensor argument.

    Only the type spellings ``Tensor`` and ``List[Tensor]`` qualify;
    ``Optional[Tensor]`` is excluded.
    """
    return str(arg.type) in ("Tensor", "List[Tensor]")
38
39
def is_optional_tensor_arg(arg: torch._C.Argument) -> bool:
    """Return True when the argument's type is exactly ``Optional[Tensor]``."""
    return str(arg.type) == "Optional[Tensor]"
44
45
def get_tensor_variable_names(
    func_schema: torch._C.FunctionSchema,
) -> Tuple[List[str], List[str], List[str]]:
    """Collect the tensor variable names declared by a function schema.

    Returns a 3-tuple of name lists:
      1. essential tensor arguments followed by tensor returns,
      2. optional tensor arguments,
      3. all tensor arguments followed by tensor returns.

    Unnamed tensor returns are given the name ``__ret_<k>``, where ``k`` counts
    the tensor returns seen before it (non-tensor returns do not advance it).
    """
    essential_names: List[str] = []
    optional_names: List[str] = []
    all_arg_names: List[str] = []
    for schema_arg in func_schema.arguments:
        if is_essential_tensor_arg(schema_arg):
            essential_names.append(schema_arg.name)
        if is_optional_tensor_arg(schema_arg):
            optional_names.append(schema_arg.name)
        if is_tensor_arg(schema_arg):
            all_arg_names.append(schema_arg.name)

    ret_name_base = "__ret_"
    tensor_ret_names: List[str] = []
    for schema_ret in func_schema.returns:
        if not is_tensor_arg(schema_ret):
            continue
        # The fallback id is the number of tensor returns collected so far.
        fallback = f"{ret_name_base}{len(tensor_ret_names)}"
        tensor_ret_names.append(schema_ret.name or fallback)

    return (
        essential_names + tensor_ret_names,
        optional_names,
        all_arg_names + tensor_ret_names,
    )
77
78
def get_args_rets(op_name: str) -> List[BaseArg]:
    """Return the sample "args" followed by the sample "returns" for an op.

    Both entries are read from ``SAMPLE_INPUT[op_name]``; a missing entry
    contributes nothing. A fresh list is returned on every call.
    """
    sample_spec = SAMPLE_INPUT[op_name]
    return [*sample_spec.get("args", []), *sample_spec.get("returns", [])]
84
85
def get_names_for_args_with_dtype(
    op_name: str, func_schema: torch._C.FunctionSchema
) -> List[str]:
    """Name the arguments and returns whose dtypes matter to edge dialect.

    Dtype runner is returning dtypes for more arguments than edge dialect cares
    about. The sample inputs of ``op_name`` decide which names are kept:

    * positional samples are paired in order with the schema arguments, and the
      schema name is kept when the sample's type reports ``has_dtype()``;
    * keyword samples contribute their own ``argname`` under the same condition;
    * every sample return paired with a schema return contributes the schema
      return name, or ``__ret_<index>`` when the schema return is unnamed.
    """
    samples: List[BaseArg] = get_args_rets(op_name)
    schema_args, schema_rets = func_schema.arguments, func_schema.returns

    # Split the samples by role; Return is checked before BaseKwarg on purpose,
    # mirroring the precedence used to classify them.
    positional, keyword, returned = [], [], []
    for sample in samples:
        if isinstance(sample, Return):
            bucket = returned
        elif isinstance(sample, BaseKwarg):
            bucket = keyword
        else:
            bucket = positional
        bucket.append(sample)

    selected = []
    for sample, schema_arg in zip(positional, schema_args):
        if sample.type.has_dtype():
            selected.append(schema_arg.name)
    for sample in keyword:
        if sample.type.has_dtype():
            selected.append(sample.argname)

    ret_name_base = "__ret_"
    for ret_id, (_, schema_ret) in enumerate(zip(returned, schema_rets)):
        selected.append(schema_ret.name or f"{ret_name_base}{ret_id}")
    return selected
115
116
def get_torch_op_overload(
    namespace: str, opname: str, overload: Optional[str]
) -> torch._ops.OpOverload:
    """Look up ``torch.ops.<namespace>.<opname>.<overload>``.

    When ``overload`` is empty or None the packet's ``default`` overload is
    returned instead.
    """
    op_namespace = getattr(torch.ops, namespace)
    packet: OpOverloadPacket = getattr(op_namespace, opname)
    return getattr(packet, overload or "default")
125
126
def group_by_format(
    all_combinations: Set[Tuple[str]],
) -> List[Tuple[int, Tuple[Tuple[str]]]]:
    """Group type combinations that share the same format.

    Two combinations share a format when they differ at exactly one position
    and agree everywhere else, e.g. (A, A, B) and (A, A, A) share a format
    (they differ only at index 2), but (A, A, B) and (A, B, A) do not.

    Args:
        all_combinations: set of type-name tuples to group.

    Returns:
        A list of ``(index, combinations)`` entries. ``index`` is the position
        at which the grouped combinations differ and ``combinations`` is the
        sorted tuple of group members. A combination that cannot be grouped
        at any position is emitted alone as ``(-1, (combination,))``. A
        combination may appear in several groups (one per differing index).
        The list is sorted so the result does not depend on set iteration
        order, which varies between runs because of string hash randomization.
    """

    def differs_only_at(
        candidate: Tuple[str], reference: Tuple[str], index: int
    ) -> bool:
        """True when candidate differs from reference at index and nowhere else."""
        for i, (ref_type, cand_type) in enumerate(zip(reference, candidate)):
            # Must differ at `index` and must match at every other position.
            if (i == index) == (ref_type == cand_type):
                return False
        return True

    grouped_combinations: Set[Tuple[int, Tuple[Tuple[str]]]] = set()

    for combination in all_combinations:
        has_same_comb = False
        for index in range(len(combination)):
            # The combination itself plus everything differing only at `index`.
            same_format: Set[Tuple[str]] = {combination}
            same_format.update(
                other
                for other in all_combinations
                if differs_only_at(other, combination, index)
            )
            if len(same_format) > 1:
                has_same_comb = True
                # Sorting the members makes the group identical no matter
                # which member discovered it, so the set de-duplicates it.
                grouped_combinations.add((index, tuple(sorted(same_format))))
        if not has_same_comb:
            grouped_combinations.add((-1, (combination,)))

    # Deterministic order: downstream alias numbering follows this order.
    return sorted(grouped_combinations)
162
163
def update_type_alias(type_alias: Dict[Tuple[str], int], new_key: Tuple[str]) -> None:
    """Register ``new_key`` in ``type_alias`` if it is not there yet.

    A new key receives the current size of the mapping as its index; an
    existing key keeps the index it already has. The mapping is modified in
    place.
    """
    type_alias.setdefault(new_key, len(type_alias))
168
169
def gen_index_pairs_to_types_mapping(
    type_alias: Dict[Tuple[str], int], type_constraint: List[List[int]]
) -> Dict[Tuple[int], List[str]]:
    """Map each pair of positions that share an alias to the types involved.

    For every constraint row, each pair of positions ``(i, j)`` with ``i < j``
    holding the same alias index collects the type names of that alias. The
    names are accumulated over all rows and returned as a sorted list without
    duplicates. For example, type_constraint [[0, 0], [1, 1]] with type_alias
    {('Double',): 0, ('Int',): 1} gives {(0, 1): ['Double', 'Int']}.
    """
    types_by_pair: Dict[Tuple[int], Set[str]] = defaultdict(set)
    for row in type_constraint:
        # alias index -> positions in this row that hold it
        slots_by_alias: Dict[int, List[int]] = defaultdict(list)
        for slot, alias_idx in enumerate(row):
            slots_by_alias[alias_idx].append(slot)

        for alias_idx, slots in slots_by_alias.items():
            # Reverse lookup: first alias key registered under this index.
            alias_types = next(
                names for names, idx in type_alias.items() if idx == alias_idx
            )
            # Only pairs of positions are tracked for now.
            for offset, first in enumerate(slots):
                for second in slots[offset + 1 :]:
                    types_by_pair[(first, second)].update(alias_types)

    return {pair: sorted(names) for pair, names in types_by_pair.items()}
196
197
def check_new_alias_fit_constraints(
    type_alias: Dict[Tuple[str], int],
    type_constraint: List[List[int]],
    new_alias: Tuple[str],
) -> bool:
    """Check whether a new type alias fits the existing constraints.

    For each type in ``new_alias`` the single-type alias ``(type,)`` is looked
    up in ``type_alias``, and a row repeating that index across the full
    constraint width must already be present in ``type_constraint``.
    For example, with aliases ('Float',): 0 and ('Int',): 1, the new alias
    ('Float', 'Int') does not fit type_constraint [[0, 0]] because both
    [0, 0] and [1, 1] are required.
    """
    existing_rows: Set[Tuple[int]] = set(map(tuple, type_constraint))
    # The width is taken from the first constraint row.
    width = len(type_constraint[0])

    required_rows: Set[Tuple[int]] = set()
    for type_name in new_alias:
        required_rows.add((type_alias[(type_name,)],) * width)
    return required_rows <= existing_rows
215
216
def aggregate_if_two_types_being_the_same(
    type_alias: Dict[Tuple[str], int], type_constraint: List[List[int]]
) -> Tuple[List[Tuple[str]], List[Tuple[int]]]:
    """Merge aliases that occupy the same pair of positions across constraints.

    For example, constraints [0, 0] and [1, 1] with ('Double',): 0 and
    ('Int',): 1 are merged into the single alias ('Double', 'Int') and the
    single constraint (0, 0).

    Both ``type_alias`` (new aliases are added) and ``type_constraint`` (rows
    are rewritten) are modified in place while merging. The returned values are
    the sorted list of surviving alias keys and the sorted, de-duplicated
    constraints re-indexed against that list.
    """
    types_by_pair: Dict[Tuple[int], List[str]] = gen_index_pairs_to_types_mapping(
        type_alias, type_constraint
    )

    # Alias indices that were folded into a merged alias.
    retired: Set[int] = set()
    for type_names in types_by_pair.values():
        candidate = tuple(type_names)
        if candidate in type_alias:
            continue
        if not check_new_alias_fit_constraints(type_alias, type_constraint, candidate):
            continue

        # Every alias fully contained in the candidate is replaced by it.
        members = set(candidate)
        retired.update(
            idx for names, idx in type_alias.items() if set(names) <= members
        )
        merged_idx = len(type_alias)
        type_alias[candidate] = merged_idx
        for row in type_constraint:
            for pos, idx in enumerate(row):
                if idx in retired:
                    row[pos] = merged_idx

    # Drop the retired aliases, then renumber the survivors contiguously from
    # 0 following the sorted key order.
    surviving = {names: idx for names, idx in type_alias.items() if idx not in retired}
    ordered_keys = sorted(surviving)
    renumber = {surviving[names]: new_idx for new_idx, names in enumerate(ordered_keys)}
    # A set removes constraints that became identical after merging.
    unique_rows: Set[Tuple[int]] = {
        tuple(renumber[idx] for idx in row) for row in type_constraint
    }

    return ordered_keys, sorted(unique_rows)
258
259
def aggregate_grouped_type_combinations(
    grouped_combinations: List[Tuple[int, Tuple[Tuple[str]]]],
) -> Tuple[Dict[Tuple[str], int], List[List[int]]]:
    """Turn grouped combinations into a type alias table and constraints.

    Each group yields one constraint row, built from the group's first
    combination. A group with a single combination maps every position to a
    single-type alias. A larger group maps the differing position to one alias
    holding all distinct types seen there (sorted), and every other position
    to a single-type alias. Aliases are numbered in order of first use.
    """
    type_alias: Dict[Tuple[str], int] = {}
    type_constraint: List[List[int]] = []

    for distinct_id, same_format_combinations in grouped_combinations:
        merged_alias: Optional[Tuple[str]] = None
        if len(same_format_combinations) != 1:
            # Gather the types found at the differing position into one alias.
            # It is registered before the per-position aliases of this group.
            merged_alias = tuple(
                sorted({comb[distinct_id] for comb in same_format_combinations})
            )
            update_type_alias(type_alias, merged_alias)

        representative: Tuple[str] = next(iter(same_format_combinations))
        row: List[int] = []
        for position, type_str in enumerate(representative):
            if merged_alias is not None and position == distinct_id:
                key = merged_alias
            else:
                # Cannot be combined with others: its own single-type alias.
                key = (type_str,)
                update_type_alias(type_alias, key)
            row.append(type_alias[key])
        type_constraint.append(row)

    return type_alias, type_constraint
297
298
def type_aggregrate(
    allow_types: Set[Tuple[str]],
) -> Tuple[List[Tuple[str]], List[Tuple[int]]]:
    """
    Aggregate the enumerated combinations of supported types into type alias format.

    E.g. input: {("Float", "Float", "Float"), ("Half", "Half", "Half"), ("Char", "Char", "Int")}
         output: [("Char",), ("Float", "Half"), ("Int",)], [(0, 0, 2), (1, 1, 1)]

    For the i-th entry of the returned constraints, the j-th tensor variable may
    take any of the types in alias[constraint[i][j]]; the k-th and l-th tensor
    variables should have the same type if constraint[i][k] == constraint[i][l].

    The work is done in three steps: group combinations that share a format,
    turn the groups into aliases and constraints, then merge aliases that
    occupy the same positions.

    NOTE: This is not the optimum way to aggregate types. It generates correct but not the optimum representation.
    TODO(gasoonjia): continue update aggregrate algorithm.
    """
    initial_alias, initial_constraint = aggregate_grouped_type_combinations(
        group_by_format(allow_types)
    )
    final_alias, final_constraint = aggregate_if_two_types_being_the_same(
        initial_alias, initial_constraint
    )
    return final_alias, final_constraint
329